1use crate::prelude::*;
20use std::{
21 hash::{Hash, Hasher},
22 sync::{Arc, LazyLock, Mutex, MutexGuard, atomic::Ordering},
23};
24
25#[derive_group(Serializers)]
27#[derive(Default, Clone, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
28#[serde(transparent)]
29pub struct Id {
30 id: u32,
31}
32
33#[derive(Default, Debug)]
35pub struct Session {
36 next_id: Id,
37 table: Table,
38}
39
40impl Session {
41 pub fn table(&self) -> &Table {
42 &self.table
43 }
44}
45
46#[derive(Debug, Clone, Deserialize, Serialize)]
48pub enum Value {
49 Ty(Arc<TyKind>),
50 DefId(Arc<DefIdContents>),
51 ItemRef(Arc<ItemRefContents>),
52}
53
54impl SupportedType<Value> for TyKind {
55 fn to_types(value: Arc<Self>) -> Value {
56 Value::Ty(value)
57 }
58 fn from_types(t: &Value) -> Option<Arc<Self>> {
59 match t {
60 Value::Ty(value) => Some(value.clone()),
61 _ => None,
62 }
63 }
64}
65
66impl SupportedType<Value> for DefIdContents {
67 fn to_types(value: Arc<Self>) -> Value {
68 Value::DefId(value)
69 }
70 fn from_types(t: &Value) -> Option<Arc<Self>> {
71 match t {
72 Value::DefId(value) => Some(value.clone()),
73 _ => None,
74 }
75 }
76}
77
78impl SupportedType<Value> for ItemRefContents {
79 fn to_types(value: Arc<Self>) -> Value {
80 Value::ItemRef(value)
81 }
82 fn from_types(t: &Value) -> Option<Arc<Self>> {
83 match t {
84 Value::ItemRef(value) => Some(value.clone()),
85 _ => None,
86 }
87 }
88}
89
90#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialEq, Eq, PartialOrd, Ord)]
92#[serde(into = "serde_repr::NodeRepr<T>")]
93#[serde(try_from = "serde_repr::NodeRepr<T>")]
94pub struct Node<T: 'static + SupportedType<Value>> {
95 id: Id,
96 value: Arc<T>,
97}
98
99impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
100 type Target = T;
101 fn deref(&self) -> &Self::Target {
102 self.value.as_ref()
103 }
104}
105
106impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.value.as_ref().hash(state);
112 }
113}
114
115impl<T: SupportedType<Value>> Clone for Node<T> {
118 fn clone(&self) -> Self {
119 Self {
120 id: self.id.clone(),
121 value: self.value.clone(),
122 }
123 }
124}
125
126#[derive(Default, Debug, Clone, Deserialize, Serialize)]
131#[serde(into = "serde_repr::SortedIdValuePairs")]
132#[serde(from = "serde_repr::SortedIdValuePairs")]
133pub struct Table(HeterogeneousMap<Id, Value>);
134
mod heterogeneous_map {
    //! A hash map whose stored values all share one `Value` type while
    //! callers insert and look up several concrete types, each convertible
    //! to and from `Value` through [`SupportedType`].
    use std::collections::HashMap;
    use std::fmt::Debug;
    use std::hash::Hash;
    use std::sync::Arc;

    /// Thin wrapper around `HashMap<Key, Value>` offering typed access.
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<Key, Value>(HashMap<Key, Value>);

    impl<Id, Value> Default for HeterogeneousMap<Id, Value> {
        /// Creates an empty map; no bounds on the parameters are required.
        fn default() -> Self {
            Self(HashMap::new())
        }
    }

    impl<Key, Value> HeterogeneousMap<Key, Value>
    where
        Key: Hash + Eq,
    {
        /// Converts `value` to the map's `Value` type and stores it under
        /// `key`, replacing any previous entry.
        pub(super) fn insert<T>(&mut self, key: Key, value: Arc<T>)
        where
            T: SupportedType<Value>,
        {
            let wrapped = T::to_types(value);
            self.0.insert(key, wrapped);
        }

        /// Stores an already-converted `value` under `key`, replacing any
        /// previous entry.
        pub(super) fn insert_raw_value(&mut self, key: Key, value: Value) {
            let _replaced = self.0.insert(key, value);
        }

        /// Builds a map from `(key, value)` pairs; for duplicate keys the
        /// last pair wins.
        pub(super) fn from_iter(it: impl Iterator<Item = (Key, Value)>) -> Self {
            Self(it.collect())
        }

        /// Consumes the map and yields its `(key, value)` pairs in the
        /// underlying hash map's (arbitrary) order.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (Key, Value)> {
            let Self(entries) = self;
            entries.into_iter()
        }

        /// Looks up `key` and tries to view the stored value as a `T`.
        ///
        /// * `None` — nothing is stored under `key`.
        /// * `Some(None)` — something is stored, but it is not a `T`.
        /// * `Some(Some(v))` — the stored value, as a shared `T`.
        pub(super) fn get<T>(&self, key: &Key) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<Value>,
        {
            let stored = self.0.get(key)?;
            Some(T::from_types(stored))
        }
    }

    /// A type that can be wrapped into, and recovered from, `Value`.
    pub trait SupportedType<Value>: Debug {
        /// Wraps a shared `Self` into a `Value`.
        fn to_types(value: Arc<Self>) -> Value;
        /// Recovers a shared `Self` from `t`, or `None` if `t` holds a
        /// different type.
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
185use heterogeneous_map::*;
186
187impl Session {
188 fn fresh_id(&mut self) -> Id {
189 let id = self.next_id.id;
190 self.next_id.id += 1;
191 Id { id }
192 }
193}
194
195impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
196 pub fn new(value: T, session: &mut Session) -> Self {
197 let id = session.fresh_id();
198 let value = Arc::new(value);
199 session.table.0.insert(id.clone(), value.clone());
200 Self { id, value }
201 }
202
203 pub fn inner(&self) -> &Arc<T> {
204 &self.value
205 }
206}
207
208#[derive(Debug)]
214pub struct WithTable<T> {
215 table: Table,
216 value: T,
217}
218
219static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
221 LazyLock::new(|| Mutex::new(Table::default()));
222static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
223
/// Process-wide switch read by `Node` serialization: while it is `true`, a
/// node is written with its id only and its value is omitted. It starts out
/// `false` and is raised for the duration of `WithTable::run`.
static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Reports whether the id-only serialization mode is currently switched on.
fn serialize_use_id() -> bool {
    let flag = &SERIALIZATION_MODE_USE_IDS;
    flag.load(Ordering::Relaxed)
}
231
232impl<T> WithTable<T> {
233 pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
237 if serialize_use_id() {
238 panic!(
239 "CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)"
240 )
241 }
242 SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
243 let result = f(&Self { table: map, value });
244 SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
245 result
246 }
247 pub fn destruct(self) -> (T, Table) {
248 let Self { value, table: map } = self;
249 (value, map)
250 }
251}
252
253impl<T: Serialize> Serialize for WithTable<T> {
254 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
255 let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
256 use serde::ser::SerializeTupleStruct;
257 ts.serialize_field(&self.table)?;
258 ts.serialize_field(&self.value)?;
259 ts.end()
260 }
261}
262
263impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
269 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270 where
271 D: serde::Deserializer<'de>,
272 {
273 let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
274 use serde_repr::WithTableRepr;
275 let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
276 let with_table_repr = WithTableRepr::deserialize(deserializer);
277 *DESERIALIZATION_STATE.lock().unwrap() = previous;
278 let WithTableRepr(table, value) = with_table_repr?;
279 Ok(Self { table, value })
280 }
281}
282
283mod serde_repr {
286 use super::*;
287
288 #[derive(Serialize, Deserialize, JsonSchema, Debug)]
289 pub(super) struct NodeRepr<T> {
290 id: Id,
291 value: Option<Arc<T>>,
292 }
293
294 #[derive(Serialize)]
295 pub(super) struct Pair(Id, Value);
296 pub(super) type SortedIdValuePairs = Vec<Pair>;
297
298 #[derive(Serialize, Deserialize)]
299 pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
300
301 impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
302 fn into(self) -> NodeRepr<T> {
303 let value = if serialize_use_id() {
304 None
305 } else {
306 Some(self.value.clone())
307 };
308 let id = self.id;
309 NodeRepr { value, id }
310 }
311 }
312
313 impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
314 type Error = serde::de::value::Error;
315
316 fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
317 use serde::de::Error;
318 let table = DESERIALIZATION_STATE.lock().unwrap();
319 let id = cached.id;
320 let kind = if let Some(kind) = cached.value {
321 kind
322 } else {
323 table
324 .0
325 .get(&id)
326 .ok_or_else(|| {
327 Self::Error::custom(&format!(
328 "Stateful deserialization failed for id {:?}: not found in cache",
329 id
330 ))
331 })?
332 .ok_or_else(|| {
333 Self::Error::custom(&format!(
334 "Stateful deserialization failed for id {:?}: wrong type",
335 id
336 ))
337 })?
338 };
339 Ok(Self { value: kind, id })
340 }
341 }
342
343 impl<'de> serde::Deserialize<'de> for Pair {
344 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
345 where
346 D: serde::Deserializer<'de>,
347 {
348 let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
349 DESERIALIZATION_STATE
350 .lock()
351 .unwrap()
352 .0
353 .insert_raw_value(id.clone(), v.clone());
354 Ok(Pair(id, v))
355 }
356 }
357
358 impl Into<SortedIdValuePairs> for Table {
359 fn into(self) -> SortedIdValuePairs {
360 let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
361 vec.sort_by_key(|o| o.0.clone());
362 vec
363 }
364 }
365
366 impl From<SortedIdValuePairs> for Table {
367 fn from(t: SortedIdValuePairs) -> Self {
368 Self(HeterogeneousMap::from_iter(
369 t.into_iter().map(|Pair(x, y)| (x, y)),
370 ))
371 }
372 }
373}