// egglog/sort/map.rs

1use super::*;
2use std::collections::BTreeMap;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash)]
5pub struct MapContainer {
6    do_rebuild_keys: bool,
7    do_rebuild_vals: bool,
8    pub data: BTreeMap<Value, Value>,
9}
10
11impl ContainerValue for MapContainer {
12    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
13        let mut changed = false;
14        if self.do_rebuild_keys {
15            self.data = self
16                .data
17                .iter()
18                .map(|(old, v)| {
19                    let new = rebuilder.rebuild_val(*old);
20                    changed |= *old != new;
21                    (new, *v)
22                })
23                .collect();
24        }
25        if self.do_rebuild_vals {
26            for old in self.data.values_mut() {
27                let new = rebuilder.rebuild_val(*old);
28                changed |= *old != new;
29                *old = new;
30            }
31        }
32        changed
33    }
34    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
35        self.data.iter().flat_map(|(k, v)| [k, v]).copied()
36    }
37}
38
39/// A map from a key type to a value type supporting these primitives:
40/// - `map-empty`
41/// - `map-insert`
42/// - `map-get`
43/// - `map-contains`
44/// - `map-not-contains`
45/// - `map-remove`
46/// - `map-length`
47#[derive(Clone, Debug)]
48pub struct MapSort {
49    name: String,
50    key: ArcSort,
51    value: ArcSort,
52}
53
54impl MapSort {
55    pub fn key(&self) -> ArcSort {
56        self.key.clone()
57    }
58
59    pub fn value(&self) -> ArcSort {
60        self.value.clone()
61    }
62}
63
64impl Presort for MapSort {
65    fn presort_name() -> &'static str {
66        "Map"
67    }
68
69    fn reserved_primitives() -> Vec<&'static str> {
70        vec![
71            "map-empty",
72            "map-insert",
73            "map-get",
74            "map-not-contains",
75            "map-contains",
76            "map-remove",
77            "map-length",
78        ]
79    }
80
81    fn make_sort(
82        typeinfo: &mut TypeInfo,
83        name: String,
84        args: &[Expr],
85    ) -> Result<ArcSort, TypeError> {
86        if let [Expr::Var(k_span, k), Expr::Var(v_span, v)] = args {
87            let k = typeinfo
88                .get_sort_by_name(k)
89                .ok_or(TypeError::UndefinedSort(k.clone(), k_span.clone()))?;
90            let v = typeinfo
91                .get_sort_by_name(v)
92                .ok_or(TypeError::UndefinedSort(v.clone(), v_span.clone()))?;
93
94            let out = Self {
95                name,
96                key: k.clone(),
97                value: v.clone(),
98            };
99            Ok(out.to_arcsort())
100        } else {
101            panic!()
102        }
103    }
104}
105
106impl ContainerSort for MapSort {
107    type Container = MapContainer;
108
109    fn name(&self) -> &str {
110        &self.name
111    }
112
113    fn inner_sorts(&self) -> Vec<ArcSort> {
114        vec![self.key.clone(), self.value.clone()]
115    }
116
117    fn is_eq_container_sort(&self) -> bool {
118        self.key.is_eq_sort()
119            || self.value.is_eq_sort()
120            || self.key.is_eq_container_sort()
121            || self.value.is_eq_container_sort()
122    }
123
124    fn inner_values(
125        &self,
126        container_values: &ContainerValues,
127        value: Value,
128    ) -> Vec<(ArcSort, Value)> {
129        let val = container_values
130            .get_val::<MapContainer>(value)
131            .unwrap()
132            .clone();
133        val.data
134            .iter()
135            .flat_map(|(k, v)| [(self.key.clone(), *k), (self.value.clone(), *v)])
136            .collect()
137    }
138
139    fn register_primitives(&self, eg: &mut EGraph) {
140        let arc = self.clone().to_arcsort();
141
142        add_primitive!(eg, "map-empty" = {self.clone(): MapSort} || -> @MapContainer (arc) { MapContainer {
143            do_rebuild_keys: self.ctx.key.is_eq_sort() || self.ctx.key.is_eq_container_sort(),
144            do_rebuild_vals: self.ctx.value.is_eq_sort() || self.ctx.value.is_eq_container_sort(),
145            data: BTreeMap::new()
146        } });
147
148        add_primitive!(eg, "map-get"    = |    xs: @MapContainer (arc), x: # (self.key())                     | -?> # (self.value()) { xs.data.get(&x).copied() });
149        add_primitive!(eg, "map-insert" = |mut xs: @MapContainer (arc), x: # (self.key()), y: # (self.value())| -> @MapContainer (arc) {{ xs.data.insert(x, y); xs }});
150        add_primitive!(eg, "map-remove" = |mut xs: @MapContainer (arc), x: # (self.key())                     | -> @MapContainer (arc) {{ xs.data.remove(&x);   xs }});
151
152        add_primitive!(eg, "map-length"       = |xs: @MapContainer (arc)| -> i64 { xs.data.len() as i64 });
153        add_primitive!(eg, "map-contains"     = |xs: @MapContainer (arc), x: # (self.key())| -?> () { ( xs.data.contains_key(&x)).then_some(()) });
154        add_primitive!(eg, "map-not-contains" = |xs: @MapContainer (arc), x: # (self.key())| -?> () { (!xs.data.contains_key(&x)).then_some(()) });
155    }
156
157    fn reconstruct_termdag(
158        &self,
159        _container_values: &ContainerValues,
160        _value: Value,
161        termdag: &mut TermDag,
162        element_terms: Vec<TermId>,
163    ) -> TermId {
164        let mut term = termdag.app("map-empty".into(), vec![]);
165
166        for x in element_terms.chunks(2) {
167            term = termdag.app("map-insert".into(), vec![term, x[0], x[1]])
168        }
169
170        term
171    }
172
173    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
174        self.name().to_owned()
175    }
176}