// hax_lib_protocol_macros/lib.rs
1use quote::quote;
2use syn::{parse, parse_macro_input};
3
4/// This macro takes an `fn` as the basis of an `InitialState` implementation
5/// for the state type that is returned by the `fn` (on success).
6///
7/// The `fn` is expected to build the state type specified as a `Path` attribute
8/// argument from a `Vec<u8>`, i.e. the signature should be compatible with
9/// `TryFrom<Vec<u8>>` for the state type given as argument to the macro.
10///
11/// Example:
12/// ```ignore
13/// pub struct A0 {
14/// data: u8,
15/// }
16///
17/// #[hax_lib_protocol_macros::init(A0)]
18/// fn init_a(prologue: Vec<u8>) -> ::hax_lib_protocol::ProtocolResult<A0> {
19/// if prologue.len() < 1 {
20/// return Err(::hax_lib_protocol::ProtocolError::InvalidPrologue);
21/// }
22/// Ok(A0 { data: prologue[0] })
23/// }
24///
25/// // The following is generated by the macro:
26/// #[hax_lib::exclude]
27/// impl TryFrom<Vec<u8>> for A0 {
28/// type Error = ::hax_lib_protocol::ProtocolError;
29/// fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
30/// init_a(value)
31/// }
32/// }
33/// #[hax_lib::exclude]
34/// impl InitialState for A0 {
35/// fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
36/// if let Some(prologue) = prologue {
37/// prologue.try_into()
38/// } else {
39/// Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
40/// }
41/// }
42/// }
43/// ```
44#[proc_macro_attribute]
45pub fn init(
46 attr: proc_macro::TokenStream,
47 item: proc_macro::TokenStream,
48) -> proc_macro::TokenStream {
49 let mut output = quote!(#[hax_lib::process_init]);
50 output.extend(proc_macro2::TokenStream::from(item.clone()));
51
52 let input: syn::ItemFn = parse_macro_input!(item);
53 let return_type: syn::Path = parse_macro_input!(attr);
54 let name = input.sig.ident;
55
56 let expanded = quote!(
57 #[hax_lib::exclude]
58 impl TryFrom<Vec<u8>> for #return_type {
59 type Error = ::hax_lib_protocol::ProtocolError;
60
61 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
62 #name(value)
63 }
64 }
65
66 #[hax_lib::exclude]
67 impl InitialState for #return_type {
68 fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
69 if let Some(prologue) = prologue {
70 prologue.try_into()
71 } else {
72 Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
73 }
74 }
75 }
76 );
77 output.extend(expanded);
78
79 output.into()
80}
81
82/// This macro takes an `fn` as the basis of an `InitialState` implementation
83/// for the state type that is returned by the `fn` (on success).
84///
85/// The `fn` is expected to build the state type specified as a `Path` attribute
86/// argument without additional input.
87/// Example:
88/// ```ignore
89/// pub struct B0 {}
90///
91/// #[hax_lib_protocol_macros::init_empty(B0)]
92/// fn init_b() -> ::hax_lib_protocol::ProtocolResult<B0> {
93/// Ok(B0 {})
94/// }
95///
96/// // The following is generated by the macro:
97/// #[hax_lib::exclude]
98/// impl InitialState for B0 {
99/// fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
100/// if let Some(_) = prologue {
101/// Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
102/// } else {
103/// init_b()
104/// }
105/// }
106/// }
107/// ```
108#[proc_macro_error2::proc_macro_error]
109#[proc_macro_attribute]
110pub fn init_empty(
111 attr: proc_macro::TokenStream,
112 item: proc_macro::TokenStream,
113) -> proc_macro::TokenStream {
114 let mut output = quote!(#[hax_lib::process_init]);
115 output.extend(proc_macro2::TokenStream::from(item.clone()));
116
117 let input: syn::ItemFn = parse_macro_input!(item);
118 let return_type: syn::Path = parse_macro_input!(attr);
119 let name = input.sig.ident;
120
121 let expanded = quote!(
122 #[hax_lib::exclude]
123 impl InitialState for #return_type {
124 fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
125 if let Some(_) = prologue {
126 Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
127 } else {
128 #name()
129 }
130 }
131 }
132 );
133 output.extend(expanded);
134
135 return output.into();
136}
137
138/// A structure to parse transition tuples from `read` and `write` macros.
139struct Transition {
140 /// `Path` to the current state type of the transition.
141 pub current_state: syn::Path,
142 /// `Path` to the destination state type of the transition.
143 pub next_state: syn::Path,
144 /// `Path` to the message type this transition is based on.
145 pub message_type: syn::Path,
146}
147
148impl syn::parse::Parse for Transition {
149 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
150 use syn::spanned::Spanned;
151 let punctuated =
152 syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated(input)?;
153 if punctuated.len() != 3 {
154 Err(syn::Error::new(
155 punctuated.span(),
156 "Insufficient number of arguments",
157 ))
158 } else {
159 let mut args = punctuated.into_iter();
160 Ok(Self {
161 current_state: args.next().unwrap(),
162 next_state: args.next().unwrap(),
163 message_type: args.next().unwrap(),
164 })
165 }
166 }
167}
168
169/// Macro deriving a `WriteState` implementation for the origin state type,
170/// generating a message of `message_type` and a new state, as indicated by the
171/// transition tuple.
172///
173/// Example:
174/// ```ignore
175/// #[hax_lib_protocol_macros::write(A0, A1, Message)]
176/// fn write_ping(state: A0) -> ::hax_lib_protocol::ProtocolResult<(A1, Message)> {
177/// Ok((A1 {}, Message::Ping(state.data)))
178/// }
179///
180/// // The following is generated by the macro:
181/// #[hax_lib::exclude]
182/// impl TryFrom<A0> for (A1, Message) {
183/// type Error = ::hax_lib_protocol::ProtocolError;
184///
185/// fn try_from(value: A0) -> Result<Self, Self::Error> {
186/// write_ping(value)
187/// }
188/// }
189///
190/// #[hax_lib::exclude]
191/// impl WriteState for A0 {
192/// type NextState = A1;
193/// type Message = Message;
194///
195/// fn write(self) -> ::hax_lib_protocol::ProtocolResult<(Self::NextState, Message)> {
196/// self.try_into()
197/// }
198/// }
199/// ```
200#[proc_macro_attribute]
201pub fn write(
202 attr: proc_macro::TokenStream,
203 item: proc_macro::TokenStream,
204) -> proc_macro::TokenStream {
205 let mut output = quote!(#[hax_lib::process_write]);
206 output.extend(proc_macro2::TokenStream::from(item.clone()));
207
208 let input: syn::ItemFn = parse_macro_input!(item);
209 let Transition {
210 current_state,
211 next_state,
212 message_type,
213 } = parse_macro_input!(attr);
214
215 let name = input.sig.ident;
216
217 let expanded = quote!(
218 #[hax_lib::exclude]
219 impl TryFrom<#current_state> for (#next_state, #message_type) {
220 type Error = ::hax_lib_protocol::ProtocolError;
221
222 fn try_from(value: #current_state) -> Result<Self, Self::Error> {
223 #name(value)
224 }
225 }
226
227 #[hax_lib::exclude]
228 impl WriteState for #current_state {
229 type NextState = #next_state;
230 type Message = #message_type;
231
232 fn write(self) -> ::hax_lib_protocol::ProtocolResult<(Self::NextState, Self::Message)> {
233 self.try_into()
234 }
235 }
236 );
237 output.extend(expanded);
238
239 output.into()
240}
241
242/// Macro deriving a `ReadState` implementation for the destination state type,
243/// consuming a message of `message_type` and the current state, as indicated by
244/// the transition tuple.
245///
246/// Example:
247/// ```ignore
248/// #[hax_lib_protocol_macros::read(A1, A2, Message)]
249/// fn read_pong(_state: A1, msg: Message) -> ::hax_lib_protocol::ProtocolResult<A2> {
250/// match msg {
251/// Message::Ping(_) => Err(::hax_lib_protocol::ProtocolError::InvalidMessage),
252/// Message::Pong(received) => Ok(A2 { received }),
253/// }
254/// }
255/// // The following is generated by the macro:
256/// #[hax_lib::exclude]
257/// impl TryFrom<(A1, Message)> for A2 {
258/// type Error = ::hax_lib_protocol::ProtocolError;
259/// fn try_from((state, msg): (A1, Message)) -> Result<Self, Self::Error> {
260/// read_pong(state, msg)
261/// }
262/// }
263/// #[hax_lib::exclude]
264/// impl ReadState<A2> for A1 {
265/// type Message = Message;
266/// fn read(self, msg: Message) -> ::hax_lib_protocol::ProtocolResult<A2> {
267/// A2::try_from((self, msg))
268/// }
269/// }
270/// ```
271#[proc_macro_attribute]
272pub fn read(
273 attr: proc_macro::TokenStream,
274 item: proc_macro::TokenStream,
275) -> proc_macro::TokenStream {
276 let mut output = quote!(#[hax_lib::process_read]);
277 output.extend(proc_macro2::TokenStream::from(item.clone()));
278
279 let input: syn::ItemFn = parse_macro_input!(item);
280 let Transition {
281 current_state,
282 next_state,
283 message_type,
284 } = parse_macro_input!(attr);
285
286 let name = input.sig.ident;
287
288 let expanded = quote!(
289 #[hax_lib::exclude]
290 impl TryFrom<(#current_state, #message_type)> for #next_state {
291 type Error = ::hax_lib_protocol::ProtocolError;
292
293 fn try_from((state, msg): (#current_state, #message_type)) -> Result<Self, Self::Error> {
294 #name(state, msg)
295 }
296 }
297
298 #[hax_lib::exclude]
299 impl ReadState<#next_state> for #current_state {
300 type Message = #message_type;
301 fn read(self, msg: Self::Message) -> ::hax_lib_protocol::ProtocolResult<#next_state> {
302 #next_state::try_from((self, msg))
303 }
304 }
305 );
306 output.extend(expanded);
307
308 output.into()
309}