hax_frontend_exporter/
id_table.rs

/// This module provides a notion of table, identifiers and nodes. A
/// `Node<T>` is an `Arc<T>` bundled with a unique identifier such that
/// there exists an entry in a table for that identifier.
4///
5/// The type `WithTable<T>` bundles a table with a value of type
6/// `T`. That value of type `T` may hold an arbitrary number of
7/// `Node<_>`s. In the context of a `WithTable<T>`, the type `Node<_>`
8/// serializes and deserializes using a table as a state. In this
9/// case, serializing a `Node<U>` produces only an identifier, without
10/// any data of type `U`. Deserializing a `Node<U>` under a
11/// `WithTable<T>` will recover `U` data from the table held by
12/// `WithTable`.
13///
/// Serde is not designed for stateful (de)serialization. There is no
/// way of deriving `serde::de::DeserializeSeed` systematically. This
/// module thus makes use of global state to achieve serialization and
/// deserialization. This module provides an API that hides this
/// global state.
19use crate::prelude::*;
20use std::{
21    hash::{Hash, Hasher},
22    sync::{Arc, LazyLock, Mutex, MutexGuard, atomic::Ordering},
23};
24
/// Unique IDs in a ID table.
// Serialized transparently, i.e. as the bare `u32` (see `#[serde(transparent)]`
// below). `Default` yields the id `0`, which is where a default `Session`
// starts counting.
#[derive_group(Serializers)]
#[derive(Default, Clone, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(transparent)]
pub struct Id {
    // Raw numeric identifier; `Session::fresh_id` produces these by
    // incrementing a counter.
    id: u32,
}
32
/// A session providing fresh IDs for ID table.
///
/// A session owns both the counter used to mint identifiers and the
/// [`Table`] into which `Node::new` registers every value it creates.
#[derive(Default, Debug)]
pub struct Session {
    /// The identifier that the next call to `fresh_id` will hand out.
    next_id: Id,
    /// Every value registered through this session, keyed by its identifier.
    table: Table,
}
39
impl Session {
    /// Read-only access to the table of values registered in this session.
    pub fn table(&self) -> &Table {
        &self.table
    }
}
45
/// The different types of values one can store in an ID table.
///
/// Every variant holds its payload behind an `Arc`, so cloning a `Value`
/// only clones the shared handle, not the payload.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum Value {
    /// A shared `TyKind`.
    Ty(Arc<TyKind>),
    /// A shared `DefIdContents`.
    DefId(Arc<DefIdContents>),
    /// A shared `ItemRefContents`.
    ItemRef(Arc<ItemRefContents>),
}
53
54impl SupportedType<Value> for TyKind {
55    fn to_types(value: Arc<Self>) -> Value {
56        Value::Ty(value)
57    }
58    fn from_types(t: &Value) -> Option<Arc<Self>> {
59        match t {
60            Value::Ty(value) => Some(value.clone()),
61            _ => None,
62        }
63    }
64}
65
66impl SupportedType<Value> for DefIdContents {
67    fn to_types(value: Arc<Self>) -> Value {
68        Value::DefId(value)
69    }
70    fn from_types(t: &Value) -> Option<Arc<Self>> {
71        match t {
72            Value::DefId(value) => Some(value.clone()),
73            _ => None,
74        }
75    }
76}
77
78impl SupportedType<Value> for ItemRefContents {
79    fn to_types(value: Arc<Self>) -> Value {
80        Value::ItemRef(value)
81    }
82    fn from_types(t: &Value) -> Option<Arc<Self>> {
83        match t {
84            Value::ItemRef(value) => Some(value.clone()),
85            _ => None,
86        }
87    }
88}
89
/// A node is a bundle of an ID with a value.
// Serde never sees this struct directly: both directions go through
// `serde_repr::NodeRepr<T>` (see the `into` / `try_from` attributes below).
// The derived comparison traits take both `id` and `value` into account,
// whereas the manual `Hash` impl further down only hashes `value`.
#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialEq, Eq, PartialOrd, Ord)]
#[serde(into = "serde_repr::NodeRepr<T>")]
#[serde(try_from = "serde_repr::NodeRepr<T>")]
pub struct Node<T: 'static + SupportedType<Value>> {
    // Identifier attached to `value`; `Node::new` registers `value` in the
    // session table under this identifier.
    id: Id,
    // Shared handle to the payload.
    value: Arc<T>,
}
98
99impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
100    type Target = T;
101    fn deref(&self) -> &Self::Target {
102        self.value.as_ref()
103    }
104}
105
106/// Hax relies on hashes being deterministic for predicates
107/// ids. Identifiers are not deterministic: we implement hash for
108/// `Node` manually, discarding the field `id`.
109impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.value.as_ref().hash(state);
112    }
113}
114
115/// Manual implementation of `Clone` that doesn't require a `Clone`
116/// bound on `T`.
117impl<T: SupportedType<Value>> Clone for Node<T> {
118    fn clone(&self) -> Self {
119        Self {
120            id: self.id.clone(),
121            value: self.value.clone(),
122        }
123    }
124}
125
/// A table is a map from IDs to `Value`s. When serialized, we
/// represent a table as a *sorted* vector. Indeed, the values stored
/// in the table might reference each other, without cycle, so the
/// order matters.
///
/// The `into` / `from` attributes below route serde through
/// `serde_repr::SortedIdValuePairs`, which is that vector representation.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
#[serde(into = "serde_repr::SortedIdValuePairs")]
#[serde(from = "serde_repr::SortedIdValuePairs")]
pub struct Table(HeterogeneousMap<Id, Value>);
134
mod heterogeneous_map {
    //! A heterogeneous map: a keyed store whose entries share one
    //! `Value` type, with typed access for every type implementing
    //! `SupportedType`.

    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::Arc;

    /// A map from `Key` to `Value` offering typed `insert` and `get`
    /// for any `T` that implements `SupportedType<Value>`.
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<Key, Value>(HashMap<Key, Value>);

    /// Written by hand so that neither `Id` nor `Value` needs `Default`.
    impl<Id, Value> Default for HeterogeneousMap<Id, Value> {
        fn default() -> Self {
            Self(HashMap::new())
        }
    }

    impl<Key: Hash + Eq + PartialEq, Value> HeterogeneousMap<Key, Value> {
        /// Converts `value` to a `Value` through `T::to_types` and stores it
        /// under `key`, replacing any previous entry.
        pub(super) fn insert<T>(&mut self, key: Key, value: Arc<T>)
        where
            T: SupportedType<Value>,
        {
            let wrapped = T::to_types(value);
            self.insert_raw_value(key, wrapped);
        }

        /// Stores an already-converted `Value` under `key`; the entry that
        /// was previously there, if any, is dropped.
        pub(super) fn insert_raw_value(&mut self, key: Key, value: Value) {
            let _replaced = self.0.insert(key, value);
        }

        /// Builds a map out of `(key, value)` pairs.
        pub(super) fn from_iter(it: impl Iterator<Item = (Key, Value)>) -> Self {
            Self(it.collect())
        }

        /// Consumes the map and yields its `(key, value)` pairs, in the
        /// (unspecified) order of the underlying `HashMap`.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (Key, Value)> {
            let Self(entries) = self;
            entries.into_iter()
        }

        /// Typed lookup. The outer `Option` is `None` when `key` is absent;
        /// the inner one is `None` when the stored value cannot be read
        /// back as a `T`.
        pub(super) fn get<T>(&self, key: &Key) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<Value>,
        {
            let stored = self.0.get(key)?;
            Some(T::from_types(stored))
        }
    }

    /// A type that can be turned into a `Value` and, when the `Value`
    /// has the matching shape, recovered from it.
    pub trait SupportedType<Value>: std::fmt::Debug {
        fn to_types(value: Arc<Self>) -> Value;
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
185use heterogeneous_map::*;
186
187impl Session {
188    fn fresh_id(&mut self) -> Id {
189        let id = self.next_id.id;
190        self.next_id.id += 1;
191        Id { id }
192    }
193}
194
195impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
196    pub fn new(value: T, session: &mut Session) -> Self {
197        let id = session.fresh_id();
198        let value = Arc::new(value);
199        session.table.0.insert(id.clone(), value.clone());
200        Self { id, value }
201    }
202
203    pub fn inner(&self) -> &Arc<T> {
204        &self.value
205    }
206}
207
/// Wrapper for a type `T` that creates a bundle containing both an ID
/// table and a value `T`. That value may contain `Node` values
/// inside it. Serializing `WithTable<T>` will serialize IDs only,
/// skipping values. Deserialization of a `WithTable<T>` will
/// automatically use the table and IDs to reconstruct skipped values.
#[derive(Debug)]
pub struct WithTable<T> {
    /// The table holding the values that the `Node`s inside `value` refer to.
    table: Table,
    /// The wrapped value.
    value: T,
}
218
/// The state used for deserialization: a table.
///
/// `serde_repr::Pair`'s `Deserialize` impl inserts every decoded entry
/// here, and the `TryFrom<NodeRepr<T>>` conversion of `Node<T>` looks
/// identifiers up in it.
static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
    LazyLock::new(|| Mutex::new(Table::default()));
/// Taken with `try_lock` for the whole duration of a `WithTable`
/// deserialization, so that an overlapping deserialization panics instead
/// of sharing `DESERIALIZATION_STATE`.
static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));

/// The mode of serialization: should `Node<T>` ship values of type `T` or not?
///
/// `false` (the initial value) means values are shipped; `WithTable::run`
/// switches it to `true` while its closure runs.
static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
227
/// Returns `true` when `Node<T>` values must currently be serialized as
/// bare identifiers, i.e. when `SERIALIZATION_MODE_USE_IDS` is set.
fn serialize_use_id() -> bool {
    SERIALIZATION_MODE_USE_IDS.load(Ordering::Relaxed)
}
231
232impl<T> WithTable<T> {
233    /// Runs `f` with a `WithTable<T>` created out of `map` and
234    /// `value`. Any serialization of values of type `Node<_>` will
235    /// skip the field `value`.
236    pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
237        if serialize_use_id() {
238            panic!(
239                "CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)"
240            )
241        }
242        SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
243        let result = f(&Self { table: map, value });
244        SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
245        result
246    }
247    pub fn destruct(self) -> (T, Table) {
248        let Self { value, table: map } = self;
249        (value, map)
250    }
251}
252
253impl<T: Serialize> Serialize for WithTable<T> {
254    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
255        let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
256        use serde::ser::SerializeTupleStruct;
257        ts.serialize_field(&self.table)?;
258        ts.serialize_field(&self.value)?;
259        ts.end()
260    }
261}
262
263/// The deserializer of `WithTable<T>` is special. We first decode the
264/// table in order: each `(Id, Value)` pair of the table populates the
265/// global table state found in `DESERIALIZATION_STATE`. Only then we
266/// can decode the value itself, knowing `DESERIALIZATION_STATE` is
267/// complete.
268impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
269    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270    where
271        D: serde::Deserializer<'de>,
272    {
273        let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
274        use serde_repr::WithTableRepr;
275        let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
276        let with_table_repr = WithTableRepr::deserialize(deserializer);
277        *DESERIALIZATION_STATE.lock().unwrap() = previous;
278        let WithTableRepr(table, value) = with_table_repr?;
279        Ok(Self { table, value })
280    }
281}
282
283/// Defines representations for various types when serializing or/and
284/// deserializing via serde
285mod serde_repr {
286    use super::*;
287
288    #[derive(Serialize, Deserialize, JsonSchema, Debug)]
289    pub(super) struct NodeRepr<T> {
290        id: Id,
291        value: Option<Arc<T>>,
292    }
293
294    #[derive(Serialize)]
295    pub(super) struct Pair(Id, Value);
296    pub(super) type SortedIdValuePairs = Vec<Pair>;
297
298    #[derive(Serialize, Deserialize)]
299    pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
300
301    impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
302        fn into(self) -> NodeRepr<T> {
303            let value = if serialize_use_id() {
304                None
305            } else {
306                Some(self.value.clone())
307            };
308            let id = self.id;
309            NodeRepr { value, id }
310        }
311    }
312
313    impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
314        type Error = serde::de::value::Error;
315
316        fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
317            use serde::de::Error;
318            let table = DESERIALIZATION_STATE.lock().unwrap();
319            let id = cached.id;
320            let kind = if let Some(kind) = cached.value {
321                kind
322            } else {
323                table
324                    .0
325                    .get(&id)
326                    .ok_or_else(|| {
327                        Self::Error::custom(&format!(
328                            "Stateful deserialization failed for id {:?}: not found in cache",
329                            id
330                        ))
331                    })?
332                    .ok_or_else(|| {
333                        Self::Error::custom(&format!(
334                            "Stateful deserialization failed for id {:?}: wrong type",
335                            id
336                        ))
337                    })?
338            };
339            Ok(Self { value: kind, id })
340        }
341    }
342
343    impl<'de> serde::Deserialize<'de> for Pair {
344        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
345        where
346            D: serde::Deserializer<'de>,
347        {
348            let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
349            DESERIALIZATION_STATE
350                .lock()
351                .unwrap()
352                .0
353                .insert_raw_value(id.clone(), v.clone());
354            Ok(Pair(id, v))
355        }
356    }
357
358    impl Into<SortedIdValuePairs> for Table {
359        fn into(self) -> SortedIdValuePairs {
360            let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
361            vec.sort_by_key(|o| o.0.clone());
362            vec
363        }
364    }
365
366    impl From<SortedIdValuePairs> for Table {
367        fn from(t: SortedIdValuePairs) -> Self {
368            Self(HeterogeneousMap::from_iter(
369                t.into_iter().map(|Pair(x, y)| (x, y)),
370            ))
371        }
372    }
373}