1use crate::prelude::*;
20use std::{
21 hash::{Hash, Hasher},
22 sync::{Arc, LazyLock, Mutex, MutexGuard, atomic::Ordering},
23};
24
/// Identifier of an entry in a [`Table`].
///
/// With `#[serde(transparent)]` it is (de)serialized exactly like its inner `u32`.
#[derive_group(Serializers)]
#[derive(Default, Clone, Copy, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(transparent)]
pub struct Id {
    // Raw index; `Session::fresh_id` hands these out sequentially starting from the default (0).
    id: u32,
}
32
/// Interning session: allocates fresh [`Id`]s and records every value created through
/// `Node::new` in its [`Table`].
#[derive(Default, Debug)]
pub struct Session {
    // The id that the next call to `fresh_id` will return.
    next_id: Id,
    // Every value registered so far, keyed by the id it was given.
    table: Table,
}
39
40impl Session {
41 pub fn table(&self) -> &Table {
42 &self.table
43 }
44}
45
/// One entry of a [`Table`]: a shared value of any of the kinds the table can store.
///
/// Cloning a `Value` only clones the inner `Arc`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum Value {
    /// Holds a `TyKind`.
    Ty(Arc<TyKind>),
    /// Holds a `DefIdContents`.
    DefId(Arc<DefIdContents>),
    /// Holds an `ItemRefContents`.
    ItemRef(Arc<ItemRefContents>),
}
53
54impl SupportedType<Value> for TyKind {
55 fn to_types(value: Arc<Self>) -> Value {
56 Value::Ty(value)
57 }
58 fn from_types(t: &Value) -> Option<Arc<Self>> {
59 match t {
60 Value::Ty(value) => Some(value.clone()),
61 _ => None,
62 }
63 }
64}
65
66impl SupportedType<Value> for DefIdContents {
67 fn to_types(value: Arc<Self>) -> Value {
68 Value::DefId(value)
69 }
70 fn from_types(t: &Value) -> Option<Arc<Self>> {
71 match t {
72 Value::DefId(value) => Some(value.clone()),
73 _ => None,
74 }
75 }
76}
77
78impl SupportedType<Value> for ItemRefContents {
79 fn to_types(value: Arc<Self>) -> Value {
80 Value::ItemRef(value)
81 }
82 fn from_types(t: &Value) -> Option<Arc<Self>> {
83 match t {
84 Value::ItemRef(value) => Some(value.clone()),
85 _ => None,
86 }
87 }
88}
89
90#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialOrd, Ord)]
92#[serde(into = "serde_repr::NodeRepr<T>")]
93#[serde(try_from = "serde_repr::NodeRepr<T>")]
94pub struct Node<T: 'static + SupportedType<Value>> {
95 id: Id,
96 value: Arc<T>,
97}
98
99impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
100 type Target = T;
101 fn deref(&self) -> &Self::Target {
102 self.value.as_ref()
103 }
104}
105
106impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.value.as_ref().hash(state);
112 }
113}
114impl<T: SupportedType<Value> + Eq> Eq for Node<T> {}
115impl<T: SupportedType<Value> + PartialEq> PartialEq for Node<T> {
116 fn eq(&self, other: &Self) -> bool {
117 self.value == other.value
118 }
119}
120
121impl<T: SupportedType<Value>> Clone for Node<T> {
124 fn clone(&self) -> Self {
125 Self {
126 id: self.id.clone(),
127 value: self.value.clone(),
128 }
129 }
130}
131
/// Maps each [`Id`] to the [`Value`] registered under it.
///
/// Serialized as a list of `(Id, Value)` pairs sorted by id. While deserializing, each pair is
/// also recorded in the global deserialization state, so `Node`s read afterwards that carry
/// only an id can be resolved against it.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
#[serde(into = "serde_repr::SortedIdValuePairs")]
#[serde(from = "serde_repr::SortedIdValuePairs")]
pub struct Table(HeterogeneousMap<Id, Value>);
140
mod heterogeneous_map {
    //! A `HashMap` whose values all belong to one sum type, with typed accessors that convert
    //! between `Arc<T>` and that sum type through [`SupportedType`].
    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::Arc;

    /// Map from `K` to `V`, where `V` is a sum type able to hold several `Arc<T>` kinds.
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<K, V>(HashMap<K, V>);

    impl<K, V> Default for HeterogeneousMap<K, V> {
        /// An empty map.
        fn default() -> Self {
            HeterogeneousMap(HashMap::new())
        }
    }

    impl<K: Hash + Eq, V> HeterogeneousMap<K, V> {
        /// Wraps `value` into `V` with [`SupportedType::to_types`] and stores it under `key`,
        /// replacing any previous entry.
        pub(super) fn insert<T>(&mut self, key: K, value: Arc<T>)
        where
            T: SupportedType<V>,
        {
            let wrapped = T::to_types(value);
            self.insert_raw_value(key, wrapped);
        }

        /// Stores an already-wrapped `value` under `key`, replacing any previous entry.
        pub(super) fn insert_raw_value(&mut self, key: K, value: V) {
            let HeterogeneousMap(entries) = self;
            entries.insert(key, value);
        }

        /// Builds a map from `(key, value)` pairs.
        pub(super) fn from_iter(it: impl Iterator<Item = (K, V)>) -> Self {
            HeterogeneousMap(it.collect::<HashMap<K, V>>())
        }

        /// Consumes the map and yields its `(key, value)` pairs in unspecified order.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (K, V)> {
            let HeterogeneousMap(entries) = self;
            entries.into_iter()
        }

        /// Looks up `key` and tries to view its value as an `Arc<T>`.
        ///
        /// Returns `None` when `key` is absent, `Some(None)` when `T::from_types` rejects the
        /// stored value, and `Some(Some(_))` otherwise.
        pub(super) fn get<T>(&self, key: &K) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<V>,
        {
            let raw = self.0.get(key)?;
            Some(T::from_types(raw))
        }
    }

    /// Types that can be stored in a [`HeterogeneousMap`] whose value type is `Value`.
    pub trait SupportedType<Value>: std::fmt::Debug {
        /// Wraps a shared value into the map's value type.
        fn to_types(value: Arc<Self>) -> Value;
        /// Extracts a shared value of this kind from `t`, or `None` if `t` holds something else.
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
191use heterogeneous_map::*;
192
193impl Session {
194 fn fresh_id(&mut self) -> Id {
195 let id = self.next_id.id;
196 self.next_id.id += 1;
197 Id { id }
198 }
199}
200
201impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
202 pub fn new(value: T, session: &mut Session) -> Self {
203 let id = session.fresh_id();
204 let value = Arc::new(value);
205 session.table.0.insert(id.clone(), value.clone());
206 Self { id, value }
207 }
208
209 pub fn inner(&self) -> &Arc<T> {
210 &self.value
211 }
212
213 pub fn id(&self) -> Id {
214 self.id
215 }
216}
217
/// A value bundled with the [`Table`] that its `Node`s refer to.
///
/// Serialized as the pair `(table, value)`. Deserializing it rebuilds the table first, so that
/// `Node`s inside `value` that were written as bare ids can be resolved.
#[derive(Debug)]
pub struct WithTable<T> {
    table: Table,
    value: T,
}
228
/// Id-to-value table consulted while a `WithTable` is being deserialized.
///
/// Entries are added as table pairs are read. `Node`s deserialized without a value are looked
/// up here by id.
static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
    LazyLock::new(|| Mutex::new(Table::default()));
/// Held for the whole of a `WithTable` deserialization.
///
/// It is acquired with `try_lock`, so a nested or concurrent deserialization panics instead of
/// waiting.
static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));

/// Set while `WithTable::run` executes. While it is set, `Node`s serialize only their id and
/// omit their value.
static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
237
238fn serialize_use_id() -> bool {
239 SERIALIZATION_MODE_USE_IDS.load(Ordering::Relaxed)
240}
241
242impl<T> WithTable<T> {
243 pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
247 if serialize_use_id() {
248 panic!(
249 "CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)"
250 )
251 }
252 SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
253 let result = f(&Self { table: map, value });
254 SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
255 result
256 }
257 pub fn destruct(self) -> (T, Table) {
258 let Self { value, table: map } = self;
259 (value, map)
260 }
261}
262
263impl<T: Serialize> Serialize for WithTable<T> {
264 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
265 let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
266 use serde::ser::SerializeTupleStruct;
267 ts.serialize_field(&self.table)?;
268 ts.serialize_field(&self.value)?;
269 ts.end()
270 }
271}
272
273impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
279 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
280 where
281 D: serde::Deserializer<'de>,
282 {
283 let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
284 use serde_repr::WithTableRepr;
285 let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
286 let with_table_repr = WithTableRepr::deserialize(deserializer);
287 *DESERIALIZATION_STATE.lock().unwrap() = previous;
288 let WithTableRepr(table, value) = with_table_repr?;
289 Ok(Self { table, value })
290 }
291}
292
293mod serde_repr {
296 use super::*;
297
298 #[derive(Serialize, Deserialize, JsonSchema, Debug)]
299 pub(super) struct NodeRepr<T> {
300 id: Id,
301 value: Option<Arc<T>>,
302 }
303
304 #[derive(Serialize)]
305 pub(super) struct Pair(Id, Value);
306 pub(super) type SortedIdValuePairs = Vec<Pair>;
307
308 #[derive(Serialize, Deserialize)]
309 pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
310
311 impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
312 fn into(self) -> NodeRepr<T> {
313 let value = if serialize_use_id() {
314 None
315 } else {
316 Some(self.value.clone())
317 };
318 let id = self.id;
319 NodeRepr { value, id }
320 }
321 }
322
323 impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
324 type Error = serde::de::value::Error;
325
326 fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
327 use serde::de::Error;
328 let table = DESERIALIZATION_STATE.lock().unwrap();
329 let id = cached.id;
330 let kind = if let Some(kind) = cached.value {
331 kind
332 } else {
333 table
334 .0
335 .get(&id)
336 .ok_or_else(|| {
337 Self::Error::custom(&format!(
338 "Stateful deserialization failed for id {:?}: not found in cache",
339 id
340 ))
341 })?
342 .ok_or_else(|| {
343 Self::Error::custom(&format!(
344 "Stateful deserialization failed for id {:?}: wrong type",
345 id
346 ))
347 })?
348 };
349 Ok(Self { value: kind, id })
350 }
351 }
352
353 impl<'de> serde::Deserialize<'de> for Pair {
354 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
355 where
356 D: serde::Deserializer<'de>,
357 {
358 let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
359 DESERIALIZATION_STATE
360 .lock()
361 .unwrap()
362 .0
363 .insert_raw_value(id.clone(), v.clone());
364 Ok(Pair(id, v))
365 }
366 }
367
368 impl Into<SortedIdValuePairs> for Table {
369 fn into(self) -> SortedIdValuePairs {
370 let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
371 vec.sort_by_key(|o| o.0.clone());
372 vec
373 }
374 }
375
376 impl From<SortedIdValuePairs> for Table {
377 fn from(t: SortedIdValuePairs) -> Self {
378 Self(HeterogeneousMap::from_iter(
379 t.into_iter().map(|Pair(x, y)| (x, y)),
380 ))
381 }
382 }
383}