egglog/sort/
map.rs

1use super::*;
2use std::collections::BTreeMap;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash)]
5pub struct MapContainer {
6    do_rebuild_keys: bool,
7    do_rebuild_vals: bool,
8    pub data: BTreeMap<Value, Value>,
9}
10
11impl ContainerValue for MapContainer {
12    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
13        let mut changed = false;
14        if self.do_rebuild_keys {
15            self.data = self
16                .data
17                .iter()
18                .map(|(old, v)| {
19                    let new = rebuilder.rebuild_val(*old);
20                    changed |= *old != new;
21                    (new, *v)
22                })
23                .collect();
24        }
25        if self.do_rebuild_vals {
26            for old in self.data.values_mut() {
27                let new = rebuilder.rebuild_val(*old);
28                changed |= *old != new;
29                *old = new;
30            }
31        }
32        changed
33    }
34    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
35        self.data.iter().flat_map(|(k, v)| [k, v]).copied()
36    }
37}
38
39/// A map from a key type to a value type supporting these primitives:
40/// - `map-empty`
41/// - `map-insert`
42/// - `map-get`
43/// - `map-contains`
44/// - `map-not-contains`
45/// - `map-remove`
46/// - `map-length`
47#[derive(Clone, Debug)]
48pub struct MapSort {
49    name: String,
50    key: ArcSort,
51    value: ArcSort,
52}
53
54impl MapSort {
55    pub fn key(&self) -> ArcSort {
56        self.key.clone()
57    }
58
59    pub fn value(&self) -> ArcSort {
60        self.value.clone()
61    }
62}
63
64impl Presort for MapSort {
65    fn presort_name() -> &'static str {
66        "Map"
67    }
68
69    fn reserved_primitives() -> Vec<&'static str> {
70        vec![
71            "map-empty",
72            "map-insert",
73            "map-get",
74            "map-not-contains",
75            "map-contains",
76            "map-remove",
77            "map-length",
78        ]
79    }
80
81    fn make_sort(
82        typeinfo: &mut TypeInfo,
83        name: String,
84        args: &[Expr],
85    ) -> Result<ArcSort, TypeError> {
86        if let [Expr::Var(k_span, k), Expr::Var(v_span, v)] = args {
87            let k = typeinfo
88                .get_sort_by_name(k)
89                .ok_or(TypeError::UndefinedSort(k.clone(), k_span.clone()))?;
90            let v = typeinfo
91                .get_sort_by_name(v)
92                .ok_or(TypeError::UndefinedSort(v.clone(), v_span.clone()))?;
93
94            // TODO: specialize the error message
95            if k.is_eq_container_sort() {
96                return Err(TypeError::DisallowedSort(
97                    name,
98                    "Maps nested with other EqSort containers are not allowed".into(),
99                    k_span.clone(),
100                ));
101            }
102            if v.is_container_sort() {
103                return Err(TypeError::DisallowedSort(
104                    name,
105                    "Maps nested with other containers are not allowed".into(),
106                    v_span.clone(),
107                ));
108            }
109
110            let out = Self {
111                name,
112                key: k.clone(),
113                value: v.clone(),
114            };
115            Ok(out.to_arcsort())
116        } else {
117            panic!()
118        }
119    }
120}
121
122impl ContainerSort for MapSort {
123    type Container = MapContainer;
124
125    fn name(&self) -> &str {
126        &self.name
127    }
128
129    fn inner_sorts(&self) -> Vec<ArcSort> {
130        vec![self.key.clone(), self.value.clone()]
131    }
132
133    fn is_eq_container_sort(&self) -> bool {
134        self.key.is_eq_sort() || self.value.is_eq_sort()
135    }
136
137    fn inner_values(
138        &self,
139        container_values: &ContainerValues,
140        value: Value,
141    ) -> Vec<(ArcSort, Value)> {
142        let val = container_values
143            .get_val::<MapContainer>(value)
144            .unwrap()
145            .clone();
146        val.data
147            .iter()
148            .flat_map(|(k, v)| [(self.key.clone(), *k), (self.value.clone(), *v)])
149            .collect()
150    }
151
152    fn register_primitives(&self, eg: &mut EGraph) {
153        let arc = self.clone().to_arcsort();
154
155        add_primitive!(eg, "map-empty" = {self.clone(): MapSort} || -> @MapContainer (arc) { MapContainer {
156            do_rebuild_keys: self.ctx.key.is_eq_sort(),
157            do_rebuild_vals: self.ctx.value.is_eq_sort(),
158            data: BTreeMap::new()
159        } });
160
161        add_primitive!(eg, "map-get"    = |    xs: @MapContainer (arc), x: # (self.key())                     | -?> # (self.value()) { xs.data.get(&x).copied() });
162        add_primitive!(eg, "map-insert" = |mut xs: @MapContainer (arc), x: # (self.key()), y: # (self.value())| -> @MapContainer (arc) {{ xs.data.insert(x, y); xs }});
163        add_primitive!(eg, "map-remove" = |mut xs: @MapContainer (arc), x: # (self.key())                     | -> @MapContainer (arc) {{ xs.data.remove(&x);   xs }});
164
165        add_primitive!(eg, "map-length"       = |xs: @MapContainer (arc)| -> i64 { xs.data.len() as i64 });
166        add_primitive!(eg, "map-contains"     = |xs: @MapContainer (arc), x: # (self.key())| -?> () { ( xs.data.contains_key(&x)).then_some(()) });
167        add_primitive!(eg, "map-not-contains" = |xs: @MapContainer (arc), x: # (self.key())| -?> () { (!xs.data.contains_key(&x)).then_some(()) });
168    }
169
170    fn reconstruct_termdag(
171        &self,
172        _container_values: &ContainerValues,
173        _value: Value,
174        termdag: &mut TermDag,
175        element_terms: Vec<Term>,
176    ) -> Term {
177        let mut term = termdag.app("map-empty".into(), vec![]);
178
179        for x in element_terms.chunks(2) {
180            term = termdag.app("map-insert".into(), vec![term, x[0].clone(), x[1].clone()])
181        }
182
183        term
184    }
185
186    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
187        self.name().to_owned()
188    }
189}