hax_adt_into/
lib.rs

1use quote::quote;
2use quote::quote_spanned;
3use syn::Token;
4use syn::parse::ParseStream;
5use syn::{Data, DeriveInput, Generics, parse_macro_input};
6use syn::{PathArguments, PathSegment, spanned::Spanned};
7
8fn strip_parenthesis(tokens: proc_macro::TokenStream) -> Option<proc_macro::TokenStream> {
9    match tokens.into_iter().collect::<Vec<_>>().as_slice() {
10        [proc_macro::TokenTree::Group(token)] => Some(token.stream()),
11        _ => None,
12    }
13}
14
15#[derive(Debug)]
16struct Options {
17    generics: Generics,
18    from: syn::TypePath,
19    state: syn::Ident,
20    state_type: syn::Type,
21    where_clause: Option<syn::WhereClause>,
22}
23mod option_parse {
24    use super::*;
25    mod kw {
26        syn::custom_keyword!(from);
27        syn::custom_keyword!(state);
28    }
29    impl syn::parse::Parse for Options {
30        fn parse(input: ParseStream) -> syn::Result<Self> {
31            let generics = input.parse()?;
32            input.parse::<Token![,]>()?;
33
34            input.parse::<kw::from>()?;
35            input.parse::<Token![:]>()?;
36            let from = input.parse()?;
37            input.parse::<Token![,]>()?;
38
39            input.parse::<kw::state>()?;
40            input.parse::<Token![:]>()?;
41            let state_type = input.parse()?;
42            input.parse::<Token![as]>()?;
43            let state = input.parse()?;
44
45            let mut where_clause = None;
46            if input.peek(Token![,]) && input.peek2(Token![where]) {
47                input.parse::<Token![,]>()?;
48                where_clause = Some(input.parse()?);
49            }
50
51            Ok(Options {
52                generics,
53                from,
54                state,
55                state_type,
56                where_clause,
57            })
58        }
59    }
60}
61
62/// Returns the token stream corresponding to an attribute (if it
63/// exists), stripping parenthesis already.
64fn tokens_of_attrs<'a>(
65    attr_name: &'a str,
66    attrs: &'a Vec<syn::Attribute>,
67) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
68    attrs
69        .iter()
70        .filter(|attr| attr.path.is_ident(attr_name))
71        .map(|attr| attr.clone().tokens.into())
72        .flat_map(strip_parenthesis)
73        .map(|x| x.into())
74}
75
76fn parse_attrs<'a, T: syn::parse::Parse>(
77    attr_name: &'a str,
78    attrs: &'a Vec<syn::Attribute>,
79) -> impl Iterator<Item = T> + 'a {
80    tokens_of_attrs(attr_name, attrs).map(move |x| {
81        syn::parse::<T>(x.clone().into())
82            .expect(format!("expected attribtue {}", attr_name).as_str())
83    })
84}
85
86/// Parse an attribute as a T if it exists.
87fn parse_attr<T: syn::parse::Parse>(attr_name: &str, attrs: &Vec<syn::Attribute>) -> Option<T> {
88    parse_attrs(attr_name, attrs).next()
89}
90
91/*
92TODO: add `ensure_no_attr` calls to forbid meaningless attributes
93fn ensure_no_attr(context: &str, attr: &str, attrs: &Vec<syn::Attribute>) {
94    if attrs.iter().any(|a| a.path.is_ident(attr)) {
95        panic!("Illegal attribute {} {}", attr, context)
96    }
97}
98*/
99
100/// Create a match arm that corresponds to a given set of fields.
101/// This can be used for named fields as well as unnamed ones.
102fn fields_to_arm(
103    from_record_name: proc_macro2::TokenStream,
104    to_record_name: proc_macro2::TokenStream,
105    fields: Vec<syn::Field>,
106    full_span: proc_macro2::Span,
107    prepend: proc_macro2::TokenStream,
108    used_fields: Vec<syn::Ident>,
109    state: syn::Ident,
110) -> proc_macro2::TokenStream {
111    if fields.is_empty() {
112        return quote_spanned! {full_span=> #from_record_name => #to_record_name, };
113    }
114
115    let is_struct = fields.iter().any(|f| f.ident.is_some());
116    let is_tuple = fields.iter().any(|f| f.ident.is_none());
117    if is_tuple && is_struct {
118        panic!("Impossibe: variant with both named and unamed fields")
119    }
120
121    let data = fields.iter().enumerate().map(|(i, field)| {
122        let attrs = &field.attrs;
123        let name_destination = field.ident.clone().unwrap_or(syn::Ident::new(
124            format!("value_{}", i).as_str(),
125            field.span(),
126        ));
127        let span = field.span();
128        let field_name_span = field.clone().ident.map(|x| x.span()).unwrap_or(span);
129        let name_source =
130            parse_attr::<syn::Ident>("from", attrs).unwrap_or(name_destination.clone());
131        let value = parse_attr::<syn::Expr>("value", attrs);
132        let not_in_source =
133            value.is_some() ||
134            attrs.iter().any(|attr| attr.path.is_ident("not_in_source"));
135        let typ = &field.ty;
136        let point = syn::Ident::new("x", field_name_span);
137
138        let translation = parse_attr::<syn::Expr>("map", attrs).or(value).unwrap_or(
139            syn::parse::<syn::Expr>((quote_spanned! {typ.span()=> #point.sinto(#state)}).into())
140                .expect("Could not default [translation]")
141        );
142        let mapped_value = if not_in_source {
143            quote_spanned! {span=> {#translation}}
144        } else {
145            quote_spanned! {span=> {#[allow(unused_variables)] let #point = #name_source; #translation}}
146        };
147
148        let prefix = if is_struct {
149            quote_spanned! {field_name_span=> #name_destination:}
150        } else {
151            quote! {}
152        };
153        (
154            if not_in_source {
155                quote! {}
156            } else {
157                quote_spanned! {span=> #name_source, }
158            },
159            quote_spanned! {span=> #prefix #mapped_value, },
160        )
161    });
162
163    let bindings: proc_macro2::TokenStream = data
164        .clone()
165        .map(|(x, _)| x)
166        .chain(used_fields.iter().map(|f| quote! {#f,}))
167        .collect();
168    let fields: proc_macro2::TokenStream = data.clone().map(|(_, x)| x).collect();
169
170    if is_struct {
171        quote_spanned! {full_span=> #from_record_name { #bindings .. } => {#prepend #to_record_name { #fields }}, }
172    } else {
173        quote_spanned! {full_span=> #from_record_name ( #bindings ) => {#prepend #to_record_name ( #fields )}, }
174    }
175}
176
177/// Extracts a vector of Field out of a Fields.
178/// This function discard the Unnamed / Named variants.
179fn field_vec_of_fields(fields: syn::Fields) -> Vec<syn::Field> {
180    match fields {
181        syn::Fields::Unit => vec![],
182        syn::Fields::Named(syn::FieldsNamed { named: fields, .. })
183        | syn::Fields::Unnamed(syn::FieldsUnnamed {
184            unnamed: fields, ..
185        }) => fields.into_iter().collect(),
186    }
187}
188
189/// Given a variant, produce a match arm.
190fn variant_to_arm(
191    typ_from: proc_macro2::TokenStream,
192    typ_to: proc_macro2::TokenStream,
193    variant: syn::Variant,
194    state: syn::Ident,
195) -> proc_macro2::TokenStream {
196    let attrs = &variant.attrs;
197    let to_variant = variant.clone().ident;
198    if attrs.iter().any(|attr| attr.path.is_ident("todo")) {
199        return quote!();
200    }
201
202    let disable_mapping = attrs
203        .iter()
204        .any(|attr| attr.path.is_ident("disable_mapping"));
205    let custom_arm = tokens_of_attrs("custom_arm", attrs).next();
206    // TODO: either complete map or drop it
207    let map = parse_attr::<syn::Expr>("map", attrs);
208    // ensure_no_attr(
209    //     format!("on the variant {}::{}", typ_to, to_variant).as_str(),
210    //     "map",
211    //     attrs,
212    // );
213    let from_variant = parse_attr::<syn::Ident>("from", attrs);
214
215    if disable_mapping && (map.is_some() || custom_arm.is_some() || from_variant.is_some()) {
216        println!("Warning: `disable_mapping` makes `map`, `custom_arm` and `from_variant` inert")
217    }
218    if custom_arm.is_some() && (map.is_some() || from_variant.is_some()) {
219        println!("Warning: `custom_arm` makes `map` and `from` inert")
220    }
221
222    if disable_mapping {
223        return quote! {};
224    }
225    if let Some(custom_arm) = custom_arm {
226        return custom_arm.into();
227    }
228
229    let from_variant = from_variant.unwrap_or(to_variant.clone());
230
231    let to_variant = quote! { #typ_to::#to_variant };
232    let from_variant = quote! { #typ_from::#from_variant };
233
234    let fields = field_vec_of_fields(variant.clone().fields);
235
236    if let Some(map) = map {
237        let names: proc_macro2::TokenStream = fields
238            .iter()
239            .filter(|f| {
240                let attrs = &f.attrs;
241                !(parse_attr::<syn::Expr>("value", attrs).is_some()
242                    || attrs.iter().any(|attr| attr.path.is_ident("not_in_source")))
243            })
244            .enumerate()
245            .map(|(nth, f)| {
246                f.clone()
247                    .ident
248                    .unwrap_or(syn::Ident::new(format!("x{}", nth).as_str(), f.span()))
249            })
250            .map(|name| quote! {#name, })
251            .collect();
252        if fields.iter().any(|f| f.ident.is_some()) {
253            quote_spanned!(variant.span()=> #from_variant {#names ..} => #map,)
254        } else {
255            quote_spanned!(variant.span()=> #from_variant (#names) => #map,)
256        }
257    } else {
258        fields_to_arm(
259            from_variant,
260            to_variant,
261            fields,
262            variant.span(),
263            tokens_of_attrs("prepend", attrs).collect(),
264            parse_attrs("use_field", attrs).collect(),
265            state,
266        )
267    }
268}
269
270/// [`AdtInto`] derives a
271/// [`SInto`](../hax_frontend_exporter/trait.SInto.html)
272/// instance. This helps at transporting a algebraic data type `A` to
273/// another ADT `B` when `A` and `B` shares a lot of structure.
274#[proc_macro_derive(
275    AdtInto,
276    attributes(
277        map,
278        from,
279        custom_arm,
280        disable_mapping,
281        use_field,
282        prepend,
283        append,
284        args,
285        todo,
286        not_in_source,
287        value,
288    )
289)]
290pub fn adt_into(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
291    let dinput = {
292        let input = input.clone();
293        parse_macro_input!(input as DeriveInput)
294    };
295    let attrs = &dinput.attrs;
296    let span = dinput.clone().span().clone();
297    let to = dinput.ident;
298    let to_generics = dinput.generics;
299
300    let Options {
301        generics,
302        from: from_with_generics,
303        state,
304        state_type,
305        where_clause,
306    } = parse_attr("args", attrs).expect("An [args] attribute was expected");
307
308    let generics = {
309        let mut generics = generics;
310        generics.params = merge_generic_params(
311            to_generics.params.clone().into_iter(),
312            generics.params.into_iter(),
313        )
314        .collect();
315        generics
316    };
317
318    trait DropBounds {
319        fn drop_bounds(&mut self);
320    }
321
322    impl DropBounds for syn::GenericParam {
323        fn drop_bounds(&mut self) {
324            use syn::GenericParam::*;
325            match self {
326                Lifetime(lf) => {
327                    lf.colon_token = None;
328                    lf.bounds.clear()
329                }
330                Type(t) => {
331                    t.colon_token = None;
332                    t.bounds.clear();
333                    t.eq_token = None;
334                    t.default = None;
335                }
336                Const(c) => {
337                    c.eq_token = None;
338                    c.default = None;
339                }
340            }
341        }
342    }
343    impl DropBounds for syn::Generics {
344        fn drop_bounds(&mut self) {
345            self.params.iter_mut().for_each(DropBounds::drop_bounds);
346        }
347    }
348    let to_generics = {
349        let mut to_generics = to_generics;
350        to_generics.drop_bounds();
351        to_generics
352    };
353
354    let from = drop_generics(from_with_generics.clone());
355
356    let append: proc_macro2::TokenStream = tokens_of_attrs("append", &dinput.attrs)
357        .next()
358        .unwrap_or((quote! {}).into())
359        .into();
360
361    let body = match &dinput.data {
362        Data::Union(..) => panic!("Union types are not supported"),
363        Data::Struct(syn::DataStruct { fields, .. }) => {
364            let arm = fields_to_arm(
365                quote! {#from},
366                quote! {#to},
367                field_vec_of_fields(fields.clone()),
368                span,
369                tokens_of_attrs("prepend", attrs).collect(),
370                parse_attrs("use_field", attrs).collect(),
371                state.clone(),
372            );
373            quote! { match self { #arm #append } }
374        }
375        Data::Enum(syn::DataEnum { variants, .. }) => {
376            let arms: proc_macro2::TokenStream = variants
377                .iter()
378                .cloned()
379                .map(|variant| variant_to_arm(quote! {#from}, quote! {#to}, variant, state.clone()))
380                .collect();
381            let todo = variants.iter().find_map(|variant| {
382                let attrs = &variant.attrs;
383                let to_variant = variant.clone().ident;
384                if attrs.iter().any(|attr| attr.path.is_ident("todo")) {
385                    Some (quote_spanned! {variant.span()=> x => TO_TYPE::#to_variant(format!("{:?}", x)),})
386                } else {
387                    None
388                }
389            }).unwrap_or(quote!{});
390            let append = quote! {
391                #append
392                #todo
393            };
394            quote! { match self { #arms #append } }
395        }
396    };
397
398    quote! {
399        #[cfg(feature = "rustc")]
400        const _ : () = {
401            use #from as FROM_TYPE;
402            use #to as TO_TYPE;
403            impl #generics SInto<#state_type, #to #to_generics> for #from_with_generics #where_clause {
404                #[tracing::instrument(level = "trace", skip(#state))]
405                fn sinto(&self, #state: &#state_type) -> #to #to_generics {
406                    tracing::trace!("Enters sinto ({})", stringify!(#from_with_generics));
407                    #body
408                }
409            }
410        };
411    }
412    .into()
413}
414
415/// Merge two collections of generic params, with params from [a]
416/// before the ones from [b]. This function ensures lifetimes
417/// appear before anything else.
418fn merge_generic_params(
419    a: impl Iterator<Item = syn::GenericParam>,
420    b: impl Iterator<Item = syn::GenericParam>,
421) -> impl Iterator<Item = syn::GenericParam> {
422    fn partition(
423        a: impl Iterator<Item = syn::GenericParam>,
424    ) -> (Vec<syn::GenericParam>, Vec<syn::GenericParam>) {
425        a.partition(|g| matches!(g, syn::GenericParam::Lifetime(_)))
426    }
427    let (a_lt, a_others) = partition(a);
428    let (b_lt, b_others) = partition(b);
429    let h = |x: Vec<_>, y: Vec<_>| x.into_iter().chain(y.into_iter());
430    h(a_lt, b_lt).chain(h(a_others, b_others))
431}
432
433fn drop_generics(type_path: syn::TypePath) -> syn::TypePath {
434    syn::TypePath {
435        path: syn::Path {
436            segments: type_path
437                .path
438                .segments
439                .into_iter()
440                .map(|s| PathSegment {
441                    ident: s.ident,
442                    arguments: match s.arguments {
443                        PathArguments::AngleBracketed(_) => PathArguments::None,
444                        _ => s.arguments,
445                    },
446                })
447                .collect(),
448            ..type_path.path
449        },
450        ..type_path
451    }
452}
453
454/// A proc macro unrelated to `adt-into`: it is useful in hax
455/// and we don't want a whole crate only for that helper.
456///
457/// This proc macro defines some groups of derive clauses that
458/// we reuse all the time.
459#[proc_macro_attribute]
460pub fn derive_group(
461    attr: proc_macro::TokenStream,
462    item: proc_macro::TokenStream,
463) -> proc_macro::TokenStream {
464    let item: proc_macro2::TokenStream = item.into();
465    let groups = format!("{attr}");
466    let groups = groups.split(",").map(|s| s.trim());
467    let mut errors = vec![];
468    let result: proc_macro2::TokenStream = groups
469        .map(|group| match group {
470            "Serializers" => quote! {
471                #[derive(::serde::Serialize, ::serde::Deserialize)]
472            },
473            _ => {
474                errors.push(quote! {
475                    const _: () = compile_error!(concat!(
476                        "derive_group: `",
477                        stringify!(#group),
478                        "` is not a recognized group name"
479                    ));
480                });
481                quote! {}
482            }
483        })
484        .collect();
485    quote! {#(#errors)* #result #item}.into()
486}