hax_frontend_exporter/
id_table.rs

/// This module provides a notion of table, identifiers and nodes. A
/// `Node<T>` is an `Arc<T>` bundled with a unique identifier such that
/// there exists an entry in a table for that identifier.
4///
5/// The type `WithTable<T>` bundles a table with a value of type
6/// `T`. That value of type `T` may hold an arbitrary number of
7/// `Node<_>`s. In the context of a `WithTable<T>`, the type `Node<_>`
8/// serializes and deserializes using a table as a state. In this
9/// case, serializing a `Node<U>` produces only an identifier, without
10/// any data of type `U`. Deserializing a `Node<U>` under a
11/// `WithTable<T>` will recover `U` data from the table held by
12/// `WithTable`.
13///
/// Serde is not designed for stateful (de)serialization. There is no
/// way of deriving `serde::de::DeserializeSeed` systematically. This
/// module thus makes use of global state to achieve serialization and
/// deserialization. This module provides an API that hides this
/// global state.
19use crate::prelude::*;
20use std::{
21    hash::{Hash, Hasher},
22    sync::{Arc, LazyLock, Mutex, MutexGuard, atomic::Ordering},
23};
24
25/// Unique IDs in a ID table.
26#[derive_group(Serializers)]
27#[derive(Default, Clone, Copy, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
28#[serde(transparent)]
29pub struct Id {
30    id: u32,
31}
32
33/// A session providing fresh IDs for ID table.
34#[derive(Default, Debug)]
35pub struct Session {
36    next_id: Id,
37    table: Table,
38}
39
40impl Session {
41    pub fn table(&self) -> &Table {
42        &self.table
43    }
44}
45
46/// The different types of values one can store in an ID table.
47#[derive(Debug, Clone, Deserialize, Serialize)]
48pub enum Value {
49    Ty(Arc<TyKind>),
50    DefId(Arc<DefIdContents>),
51    ItemRef(Arc<ItemRefContents>),
52}
53
54impl SupportedType<Value> for TyKind {
55    fn to_types(value: Arc<Self>) -> Value {
56        Value::Ty(value)
57    }
58    fn from_types(t: &Value) -> Option<Arc<Self>> {
59        match t {
60            Value::Ty(value) => Some(value.clone()),
61            _ => None,
62        }
63    }
64}
65
66impl SupportedType<Value> for DefIdContents {
67    fn to_types(value: Arc<Self>) -> Value {
68        Value::DefId(value)
69    }
70    fn from_types(t: &Value) -> Option<Arc<Self>> {
71        match t {
72            Value::DefId(value) => Some(value.clone()),
73            _ => None,
74        }
75    }
76}
77
78impl SupportedType<Value> for ItemRefContents {
79    fn to_types(value: Arc<Self>) -> Value {
80        Value::ItemRef(value)
81    }
82    fn from_types(t: &Value) -> Option<Arc<Self>> {
83        match t {
84            Value::ItemRef(value) => Some(value.clone()),
85            _ => None,
86        }
87    }
88}
89
90/// A node is a bundle of an ID with a value.
91#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialOrd, Ord)]
92#[serde(into = "serde_repr::NodeRepr<T>")]
93#[serde(try_from = "serde_repr::NodeRepr<T>")]
94pub struct Node<T: 'static + SupportedType<Value>> {
95    id: Id,
96    value: Arc<T>,
97}
98
99impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
100    type Target = T;
101    fn deref(&self) -> &Self::Target {
102        self.value.as_ref()
103    }
104}
105
106/// Hax relies on hashes being deterministic for predicates
107/// ids. Identifiers are not deterministic: we implement hash for
108/// `Node` manually, discarding the field `id`.
109impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.value.as_ref().hash(state);
112    }
113}
114impl<T: SupportedType<Value> + Eq> Eq for Node<T> {}
115impl<T: SupportedType<Value> + PartialEq> PartialEq for Node<T> {
116    fn eq(&self, other: &Self) -> bool {
117        self.value == other.value
118    }
119}
120
121/// Manual implementation of `Clone` that doesn't require a `Clone`
122/// bound on `T`.
123impl<T: SupportedType<Value>> Clone for Node<T> {
124    fn clone(&self) -> Self {
125        Self {
126            id: self.id.clone(),
127            value: self.value.clone(),
128        }
129    }
130}
131
132/// A table is a map from IDs to `Value`s. When serialized, we
133/// represent a table as a *sorted* vector. Indeed, the values stored
134/// in the table might reference each other, without cycle, so the
135/// order matters.
136#[derive(Default, Debug, Clone, Deserialize, Serialize)]
137#[serde(into = "serde_repr::SortedIdValuePairs")]
138#[serde(from = "serde_repr::SortedIdValuePairs")]
139pub struct Table(HeterogeneousMap<Id, Value>);
140
mod heterogeneous_map {
    //! A map whose entries all share one `Value` type but which can be
    //! filled and queried with any type implementing
    //! `SupportedType<Value>`.

    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::Arc;

    /// A map from `Key` to `Value` offering typed `insert` and `get`
    /// for every `T: SupportedType<Value>`.
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<Key, Value>(HashMap<Key, Value>);

    impl<Id, Value> Default for HeterogeneousMap<Id, Value> {
        /// Creates an empty map.
        fn default() -> Self {
            Self(HashMap::new())
        }
    }

    impl<Key: Hash + Eq + PartialEq, Value> HeterogeneousMap<Key, Value> {
        /// Converts `value` to the map's `Value` type and stores it
        /// under `key`, replacing any previous entry.
        pub(super) fn insert<T>(&mut self, key: Key, value: Arc<T>)
        where
            T: SupportedType<Value>,
        {
            let stored = <T as SupportedType<Value>>::to_types(value);
            self.insert_raw_value(key, stored);
        }

        /// Stores an already converted `value` under `key`, replacing
        /// any previous entry.
        pub(super) fn insert_raw_value(&mut self, key: Key, value: Value) {
            let _previous = self.0.insert(key, value);
        }

        /// Builds a map out of `(key, value)` pairs.
        pub(super) fn from_iter(it: impl Iterator<Item = (Key, Value)>) -> Self {
            let entries: HashMap<Key, Value> = it.collect();
            Self(entries)
        }

        /// Consumes the map, yielding its `(key, value)` pairs.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (Key, Value)> {
            let Self(entries) = self;
            entries.into_iter()
        }

        /// Looks `key` up and tries to read the entry back as a `T`.
        ///
        /// The outer `Option` is `None` when `key` is absent; the
        /// inner one is `None` when the stored value is not a `T`.
        pub(super) fn get<T>(&self, key: &Key) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<Value>,
        {
            let stored = self.0.get(key)?;
            Some(T::from_types(stored))
        }
    }

    /// A type that can be turned into a `Value` and, when the `Value`
    /// has the right shape, recovered from it.
    pub trait SupportedType<Value>: std::fmt::Debug {
        /// Wraps a shared `Self` into a `Value`.
        fn to_types(value: Arc<Self>) -> Value;
        /// Extracts a shared `Self` out of `t`, if `t` holds one.
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
191use heterogeneous_map::*;
192
193impl Session {
194    fn fresh_id(&mut self) -> Id {
195        let id = self.next_id.id;
196        self.next_id.id += 1;
197        Id { id }
198    }
199}
200
201impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
202    pub fn new(value: T, session: &mut Session) -> Self {
203        let id = session.fresh_id();
204        let value = Arc::new(value);
205        session.table.0.insert(id.clone(), value.clone());
206        Self { id, value }
207    }
208
209    pub fn inner(&self) -> &Arc<T> {
210        &self.value
211    }
212
213    pub fn id(&self) -> Id {
214        self.id
215    }
216}
217
218/// Wrapper for a type `T` that creates a bundle containing both a ID
219/// table and a value `T`. That value may contains `Node` values
220/// inside it. Serializing `WithTable<T>` will serialize IDs only,
221/// skipping values. Deserialization of a `WithTable<T>` will
222/// automatically use the table and IDs to reconstruct skipped values.
223#[derive(Debug)]
224pub struct WithTable<T> {
225    table: Table,
226    value: T,
227}
228
229/// The state used for deserialization: a table.
230static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
231    LazyLock::new(|| Mutex::new(Table::default()));
232static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
233
234/// The mode of serialization: should `Node<T>` ship values of type `T` or not?
235static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
236    std::sync::atomic::AtomicBool::new(false);
237
238fn serialize_use_id() -> bool {
239    SERIALIZATION_MODE_USE_IDS.load(Ordering::Relaxed)
240}
241
242impl<T> WithTable<T> {
243    /// Runs `f` with a `WithTable<T>` created out of `map` and
244    /// `value`. Any serialization of values of type `Node<_>` will
245    /// skip the field `value`.
246    pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
247        if serialize_use_id() {
248            panic!(
249                "CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)"
250            )
251        }
252        SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
253        let result = f(&Self { table: map, value });
254        SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
255        result
256    }
257    pub fn destruct(self) -> (T, Table) {
258        let Self { value, table: map } = self;
259        (value, map)
260    }
261}
262
263impl<T: Serialize> Serialize for WithTable<T> {
264    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
265        let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
266        use serde::ser::SerializeTupleStruct;
267        ts.serialize_field(&self.table)?;
268        ts.serialize_field(&self.value)?;
269        ts.end()
270    }
271}
272
273/// The deserializer of `WithTable<T>` is special. We first decode the
274/// table in order: each `(Id, Value)` pair of the table populates the
275/// global table state found in `DESERIALIZATION_STATE`. Only then we
276/// can decode the value itself, knowing `DESERIALIZATION_STATE` is
277/// complete.
278impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
279    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
280    where
281        D: serde::Deserializer<'de>,
282    {
283        let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
284        use serde_repr::WithTableRepr;
285        let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
286        let with_table_repr = WithTableRepr::deserialize(deserializer);
287        *DESERIALIZATION_STATE.lock().unwrap() = previous;
288        let WithTableRepr(table, value) = with_table_repr?;
289        Ok(Self { table, value })
290    }
291}
292
293/// Defines representations for various types when serializing or/and
294/// deserializing via serde
295mod serde_repr {
296    use super::*;
297
298    #[derive(Serialize, Deserialize, JsonSchema, Debug)]
299    pub(super) struct NodeRepr<T> {
300        id: Id,
301        value: Option<Arc<T>>,
302    }
303
304    #[derive(Serialize)]
305    pub(super) struct Pair(Id, Value);
306    pub(super) type SortedIdValuePairs = Vec<Pair>;
307
308    #[derive(Serialize, Deserialize)]
309    pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
310
311    impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
312        fn into(self) -> NodeRepr<T> {
313            let value = if serialize_use_id() {
314                None
315            } else {
316                Some(self.value.clone())
317            };
318            let id = self.id;
319            NodeRepr { value, id }
320        }
321    }
322
323    impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
324        type Error = serde::de::value::Error;
325
326        fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
327            use serde::de::Error;
328            let table = DESERIALIZATION_STATE.lock().unwrap();
329            let id = cached.id;
330            let kind = if let Some(kind) = cached.value {
331                kind
332            } else {
333                table
334                    .0
335                    .get(&id)
336                    .ok_or_else(|| {
337                        Self::Error::custom(&format!(
338                            "Stateful deserialization failed for id {:?}: not found in cache",
339                            id
340                        ))
341                    })?
342                    .ok_or_else(|| {
343                        Self::Error::custom(&format!(
344                            "Stateful deserialization failed for id {:?}: wrong type",
345                            id
346                        ))
347                    })?
348            };
349            Ok(Self { value: kind, id })
350        }
351    }
352
353    impl<'de> serde::Deserialize<'de> for Pair {
354        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
355        where
356            D: serde::Deserializer<'de>,
357        {
358            let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
359            DESERIALIZATION_STATE
360                .lock()
361                .unwrap()
362                .0
363                .insert_raw_value(id.clone(), v.clone());
364            Ok(Pair(id, v))
365        }
366    }
367
368    impl Into<SortedIdValuePairs> for Table {
369        fn into(self) -> SortedIdValuePairs {
370            let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
371            vec.sort_by_key(|o| o.0.clone());
372            vec
373        }
374    }
375
376    impl From<SortedIdValuePairs> for Table {
377        fn from(t: SortedIdValuePairs) -> Self {
378            Self(HeterogeneousMap::from_iter(
379                t.into_iter().map(|Pair(x, y)| (x, y)),
380            ))
381        }
382    }
383}