hax_lib_protocol_macros/
lib.rs

1use quote::quote;
2use syn::{parse, parse_macro_input};
3
4/// This macro takes an `fn` as the basis of an `InitialState` implementation
5/// for the state type that is returned by the `fn` (on success).
6///
7/// The `fn` is expected to build the state type specified as a `Path` attribute
8/// argument from a `Vec<u8>`, i.e. the signature should be compatible with
9/// `TryFrom<Vec<u8>>` for the state type given as argument to the macro.
10///
11/// Example:
12/// ```ignore
13/// pub struct A0 {
14///   data: u8,
15/// }
16///
17/// #[hax_lib_protocol_macros::init(A0)]
18/// fn init_a(prologue: Vec<u8>) -> ::hax_lib_protocol::ProtocolResult<A0> {
19///     if prologue.len() < 1 {
20///        return Err(::hax_lib_protocol::ProtocolError::InvalidPrologue);
21///     }
22///     Ok(A0 { data: prologue[0] })
23/// }
24///
25/// // The following is generated by the macro:
26/// #[hax_lib::exclude]
27/// impl TryFrom<Vec<u8>> for A0 {
28///     type Error = ::hax_lib_protocol::ProtocolError;
29///     fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
30///         init_a(value)
31///     }
32/// }
33/// #[hax_lib::exclude]
34/// impl InitialState for A0 {
35///     fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
36///         if let Some(prologue) = prologue {
37///             prologue.try_into()
38///         } else {
39///             Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
40///         }
41///     }
42/// }
43/// ```
44#[proc_macro_attribute]
45pub fn init(
46    attr: proc_macro::TokenStream,
47    item: proc_macro::TokenStream,
48) -> proc_macro::TokenStream {
49    let mut output = quote!(#[hax_lib::process_init]);
50    output.extend(proc_macro2::TokenStream::from(item.clone()));
51
52    let input: syn::ItemFn = parse_macro_input!(item);
53    let return_type: syn::Path = parse_macro_input!(attr);
54    let name = input.sig.ident;
55
56    let expanded = quote!(
57        #[hax_lib::exclude]
58        impl TryFrom<Vec<u8>> for #return_type {
59            type Error = ::hax_lib_protocol::ProtocolError;
60
61            fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
62                #name(value)
63            }
64        }
65
66        #[hax_lib::exclude]
67        impl InitialState for #return_type {
68            fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
69                if let Some(prologue) = prologue {
70                    prologue.try_into()
71                } else {
72                    Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
73                }
74            }
75        }
76    );
77    output.extend(expanded);
78
79    output.into()
80}
81
82/// This macro takes an `fn` as the basis of an `InitialState` implementation
83/// for the state type that is returned by the `fn` (on success).
84///
85/// The `fn` is expected to build the state type specified as a `Path` attribute
86/// argument without additional input.
87/// Example:
88/// ```ignore
89/// pub struct B0 {}
90///
91/// #[hax_lib_protocol_macros::init_empty(B0)]
92/// fn init_b() -> ::hax_lib_protocol::ProtocolResult<B0> {
93///    Ok(B0 {})
94/// }
95///
96/// // The following is generated by the macro:
97/// #[hax_lib::exclude]
98/// impl InitialState for B0 {
99///     fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
100///         if let Some(_) = prologue {
101///             Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
102///         } else {
103///             init_b()
104///         }
105///     }
106/// }
107/// ```
108#[proc_macro_error2::proc_macro_error]
109#[proc_macro_attribute]
110pub fn init_empty(
111    attr: proc_macro::TokenStream,
112    item: proc_macro::TokenStream,
113) -> proc_macro::TokenStream {
114    let mut output = quote!(#[hax_lib::process_init]);
115    output.extend(proc_macro2::TokenStream::from(item.clone()));
116
117    let input: syn::ItemFn = parse_macro_input!(item);
118    let return_type: syn::Path = parse_macro_input!(attr);
119    let name = input.sig.ident;
120
121    let expanded = quote!(
122        #[hax_lib::exclude]
123        impl InitialState for #return_type {
124            fn init(prologue: Option<Vec<u8>>) -> ::hax_lib_protocol::ProtocolResult<Self> {
125                if let Some(_) = prologue {
126                    Err(::hax_lib_protocol::ProtocolError::InvalidPrologue)
127                } else {
128                    #name()
129                }
130            }
131        }
132    );
133    output.extend(expanded);
134
135    return output.into();
136}
137
138/// A structure to parse transition tuples from `read` and `write` macros.
139struct Transition {
140    /// `Path` to the current state type of the transition.
141    pub current_state: syn::Path,
142    /// `Path` to the destination state type of the transition.
143    pub next_state: syn::Path,
144    /// `Path` to the message type this transition is based on.
145    pub message_type: syn::Path,
146}
147
148impl syn::parse::Parse for Transition {
149    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
150        use syn::spanned::Spanned;
151        let punctuated =
152            syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated(input)?;
153        if punctuated.len() != 3 {
154            Err(syn::Error::new(
155                punctuated.span(),
156                "Insufficient number of arguments",
157            ))
158        } else {
159            let mut args = punctuated.into_iter();
160            Ok(Self {
161                current_state: args.next().unwrap(),
162                next_state: args.next().unwrap(),
163                message_type: args.next().unwrap(),
164            })
165        }
166    }
167}
168
169/// Macro deriving a `WriteState` implementation for the origin state type,
170/// generating a message of `message_type` and a new state, as indicated by the
171/// transition tuple.
172///
173/// Example:
174/// ```ignore
175/// #[hax_lib_protocol_macros::write(A0, A1, Message)]
176/// fn write_ping(state: A0) -> ::hax_lib_protocol::ProtocolResult<(A1, Message)> {
177///    Ok((A1 {}, Message::Ping(state.data)))
178/// }
179///
180/// // The following is generated by the macro:
181/// #[hax_lib::exclude]
182/// impl TryFrom<A0> for (A1, Message) {
183///    type Error = ::hax_lib_protocol::ProtocolError;
184///
185///    fn try_from(value: A0) -> Result<Self, Self::Error> {
186///       write_ping(value)
187///    }
188/// }
189///
190/// #[hax_lib::exclude]
191/// impl WriteState for A0 {
192///    type NextState = A1;
193///    type Message = Message;
194///
195///    fn write(self) -> ::hax_lib_protocol::ProtocolResult<(Self::NextState, Message)> {
196///        self.try_into()
197///    }
198/// }
199/// ```
200#[proc_macro_attribute]
201pub fn write(
202    attr: proc_macro::TokenStream,
203    item: proc_macro::TokenStream,
204) -> proc_macro::TokenStream {
205    let mut output = quote!(#[hax_lib::process_write]);
206    output.extend(proc_macro2::TokenStream::from(item.clone()));
207
208    let input: syn::ItemFn = parse_macro_input!(item);
209    let Transition {
210        current_state,
211        next_state,
212        message_type,
213    } = parse_macro_input!(attr);
214
215    let name = input.sig.ident;
216
217    let expanded = quote!(
218        #[hax_lib::exclude]
219        impl TryFrom<#current_state> for (#next_state, #message_type) {
220            type Error = ::hax_lib_protocol::ProtocolError;
221
222            fn try_from(value: #current_state) -> Result<Self, Self::Error> {
223                #name(value)
224            }
225        }
226
227        #[hax_lib::exclude]
228        impl WriteState for #current_state {
229            type NextState = #next_state;
230            type Message = #message_type;
231
232            fn write(self) -> ::hax_lib_protocol::ProtocolResult<(Self::NextState, Self::Message)> {
233                self.try_into()
234            }
235        }
236    );
237    output.extend(expanded);
238
239    output.into()
240}
241
242/// Macro deriving a `ReadState` implementation for the destination state type,
243/// consuming a message of `message_type` and the current state, as indicated by
244/// the transition tuple.
245///
246/// Example:
247/// ```ignore
248/// #[hax_lib_protocol_macros::read(A1, A2, Message)]
249/// fn read_pong(_state: A1, msg: Message) -> ::hax_lib_protocol::ProtocolResult<A2> {
250///     match msg {
251///         Message::Ping(_) => Err(::hax_lib_protocol::ProtocolError::InvalidMessage),
252///         Message::Pong(received) => Ok(A2 { received }),
253///     }
254/// }
255/// // The following is generated by the macro:
256/// #[hax_lib::exclude]
257/// impl TryFrom<(A1, Message)> for A2 {
258///     type Error = ::hax_lib_protocol::ProtocolError;
259///     fn try_from((state, msg): (A1, Message)) -> Result<Self, Self::Error> {
260///         read_pong(state, msg)
261///     }
262/// }
263/// #[hax_lib::exclude]
264/// impl ReadState<A2> for A1 {
265///     type Message = Message;
266///     fn read(self, msg: Message) -> ::hax_lib_protocol::ProtocolResult<A2> {
267///         A2::try_from((self, msg))
268///     }
269/// }
270/// ```
271#[proc_macro_attribute]
272pub fn read(
273    attr: proc_macro::TokenStream,
274    item: proc_macro::TokenStream,
275) -> proc_macro::TokenStream {
276    let mut output = quote!(#[hax_lib::process_read]);
277    output.extend(proc_macro2::TokenStream::from(item.clone()));
278
279    let input: syn::ItemFn = parse_macro_input!(item);
280    let Transition {
281        current_state,
282        next_state,
283        message_type,
284    } = parse_macro_input!(attr);
285
286    let name = input.sig.ident;
287
288    let expanded = quote!(
289        #[hax_lib::exclude]
290        impl TryFrom<(#current_state, #message_type)> for #next_state {
291            type Error = ::hax_lib_protocol::ProtocolError;
292
293            fn try_from((state, msg): (#current_state, #message_type)) -> Result<Self, Self::Error> {
294                #name(state, msg)
295            }
296        }
297
298        #[hax_lib::exclude]
299        impl ReadState<#next_state> for #current_state {
300            type Message = #message_type;
301            fn read(self, msg: Self::Message) -> ::hax_lib_protocol::ProtocolResult<#next_state> {
302                #next_state::try_from((self, msg))
303            }
304        }
305    );
306    output.extend(expanded);
307
308    output.into()
309}