egglog/sort/
multiset.rs

1use super::*;
2use inner::MultiSet;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash)]
5pub struct MultiSetContainer {
6    pub do_rebuild: bool,
7    pub data: MultiSet<Value>,
8}
9
10impl ContainerValue for MultiSetContainer {
11    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
12        if self.do_rebuild {
13            let mut xs: Vec<_> = self.data.iter().copied().collect();
14            let changed = rebuilder.rebuild_slice(&mut xs);
15            self.data = xs.into_iter().collect();
16            changed
17        } else {
18            false
19        }
20    }
21    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
22        self.data.iter().copied()
23    }
24}
25
26#[derive(Clone, Debug)]
27pub struct MultiSetSort {
28    name: String,
29    element: ArcSort,
30}
31
32impl MultiSetSort {
33    pub fn element(&self) -> ArcSort {
34        self.element.clone()
35    }
36}
37
38impl Presort for MultiSetSort {
39    fn presort_name() -> &'static str {
40        "MultiSet"
41    }
42
43    fn reserved_primitives() -> Vec<&'static str> {
44        vec![
45            "multiset-of",
46            "multiset-insert",
47            "multiset-contains",
48            "multiset-not-contains",
49            "multiset-remove",
50            "multiset-length",
51            "multiset-sum",
52            "unstable-multiset-map",
53        ]
54    }
55
56    fn make_sort(
57        typeinfo: &mut TypeInfo,
58        name: String,
59        args: &[Expr],
60    ) -> Result<ArcSort, TypeError> {
61        if let [Expr::Var(span, e)] = args {
62            let e = typeinfo
63                .get_sort_by_name(e)
64                .ok_or(TypeError::UndefinedSort(e.clone(), span.clone()))?;
65
66            if e.is_eq_container_sort() {
67                return Err(TypeError::DisallowedSort(
68                    name,
69                    "Multisets nested with other EqSort containers are not allowed".into(),
70                    span.clone(),
71                ));
72            }
73
74            let out = Self {
75                name,
76                element: e.clone(),
77            };
78            Ok(out.to_arcsort())
79        } else {
80            panic!()
81        }
82    }
83}
84
85impl ContainerSort for MultiSetSort {
86    type Container = MultiSetContainer;
87
88    fn name(&self) -> &str {
89        &self.name
90    }
91
92    fn inner_sorts(&self) -> Vec<ArcSort> {
93        vec![self.element.clone()]
94    }
95
96    fn is_eq_container_sort(&self) -> bool {
97        self.element.is_eq_sort()
98    }
99
100    fn inner_values(
101        &self,
102        container_values: &ContainerValues,
103        value: Value,
104    ) -> Vec<(ArcSort, Value)> {
105        let val = container_values
106            .get_val::<MultiSetContainer>(value)
107            .unwrap()
108            .clone();
109        val.data
110            .iter()
111            .map(|k| (self.element.clone(), *k))
112            .collect()
113    }
114
115    fn register_primitives(&self, eg: &mut EGraph) {
116        let arc = self.clone().to_arcsort();
117
118        add_primitive!(eg, "multiset-of" = {self.clone(): MultiSetSort} [xs: # (self.element())] -> @MultiSetContainer (arc) { MultiSetContainer {
119            do_rebuild: self.ctx.element.is_eq_sort(),
120            data: xs.collect()
121        } });
122
123        add_primitive!(eg, "multiset-pick" = |xs: @MultiSetContainer (arc)| -> # (self.element()) { *xs.data.pick().expect("Cannot pick from an empty multiset") });
124        add_primitive!(eg, "multiset-insert" = |mut xs: @MultiSetContainer (arc), x: # (self.element())| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.insert( x) , ..xs } });
125        add_primitive!(eg, "multiset-remove" = |mut xs: @MultiSetContainer (arc), x: # (self.element())| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.remove(&x)?, ..xs } });
126
127        add_primitive!(eg, "multiset-length"       = |xs: @MultiSetContainer (arc)| -> i64 { xs.data.len() as i64 });
128        add_primitive!(eg, "multiset-contains"     = |xs: @MultiSetContainer (arc), x: # (self.element())| -?> () { ( xs.data.contains(&x)).then_some(()) });
129        add_primitive!(eg, "multiset-not-contains" = |xs: @MultiSetContainer (arc), x: # (self.element())| -?> () { (!xs.data.contains(&x)).then_some(()) });
130
131        add_primitive!(eg, "multiset-sum" = |xs: @MultiSetContainer (arc), ys: @MultiSetContainer (arc)| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.sum(ys.data), ..xs } });
132
133        // Only include map function if we already declared a function sort with the correct signature
134        let fn_sorts = eg.type_info.get_sorts_by(|s: &Arc<FunctionSort>| {
135            (s.inputs().len() == 1)
136                && (s.inputs()[0].name() == self.element.name())
137                && (s.output().name() == self.element.name())
138        });
139        match fn_sorts.len() {
140            0 => {}
141            1 => eg.add_primitive(Map {
142                name: "unstable-multiset-map".into(),
143                multiset: arc,
144                fn_: fn_sorts.into_iter().next().unwrap(),
145            }),
146            _ => panic!("too many applicable function sorts"),
147        }
148    }
149
150    fn reconstruct_termdag(
151        &self,
152        _container_values: &ContainerValues,
153        _value: Value,
154        termdag: &mut TermDag,
155        element_terms: Vec<Term>,
156    ) -> Term {
157        termdag.app("multiset-of".into(), element_terms)
158    }
159
160    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
161        "multiset-of".to_owned()
162    }
163}
164
165#[derive(Clone)]
166struct Map {
167    name: String,
168    multiset: ArcSort,
169    fn_: Arc<FunctionSort>,
170}
171
172impl Primitive for Map {
173    fn name(&self) -> &str {
174        &self.name
175    }
176
177    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
178        SimpleTypeConstraint::new(
179            self.name(),
180            vec![
181                self.fn_.clone(),
182                self.multiset.clone(),
183                self.multiset.clone(),
184            ],
185            span.clone(),
186        )
187        .into_box()
188    }
189
190    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
191        let fc = exec_state
192            .container_values()
193            .get_val::<FunctionContainer>(args[0])
194            .unwrap()
195            .clone();
196        let multiset = exec_state
197            .container_values()
198            .get_val::<MultiSetContainer>(args[1])
199            .unwrap()
200            .clone();
201        let multiset = MultiSetContainer {
202            data: multiset
203                .data
204                .iter()
205                .map(|e| fc.apply(exec_state, &[*e]))
206                .collect::<Option<_>>()?,
207            ..multiset
208        };
209        Some(
210            exec_state
211                .clone()
212                .container_values()
213                .register_val(multiset, exec_state),
214        )
215    }
216}
217
// Place multiset in its own module to keep implementation details private from sort
mod inner {
    use std::collections::BTreeMap;
    use std::hash::Hash;

    /// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order.
    ///
    /// All methods that return a new multiset take ownership of the old multiset.
    #[derive(Debug, Hash, Eq, PartialEq, Clone)]
    pub struct MultiSet<T: Clone + Hash + Ord>(
        /// Number of occurrences of each distinct element. All values should be > 0.
        BTreeMap<T, usize>,
        /// Cached length: the sum of all occurrence counts.
        usize,
    );

    // Implemented by hand rather than derived so that `T: Default` is not required.
    impl<T: Clone + Hash + Ord> Default for MultiSet<T> {
        /// Create a new empty multiset.
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T: Clone + Hash + Ord> MultiSet<T> {
        /// Create a new empty multiset.
        pub fn new() -> Self {
            MultiSet(BTreeMap::new(), 0)
        }

        /// Check if the multiset contains a key.
        pub fn contains(&self, value: &T) -> bool {
            self.0.contains_key(value)
        }

        /// Return the total number of elements in the multiset, counting duplicates.
        pub fn len(&self) -> usize {
            self.1
        }

        /// Return an iterator over all elements in the multiset.
        ///
        /// Elements are yielded in ascending order and repeated once per occurrence.
        pub fn iter(&self) -> impl Iterator<Item = &T> {
            self.0.iter().flat_map(|(k, v)| std::iter::repeat_n(k, *v))
        }

        /// Return an arbitrary element from the multiset, or `None` if it is empty.
        ///
        /// The current implementation returns the smallest element.
        pub fn pick(&self) -> Option<&T> {
            self.0.keys().next()
        }

        /// Insert a value into the multiset, taking ownership of it and returning a new multiset.
        pub fn insert(mut self, value: T) -> MultiSet<T> {
            self.insert_multiple_mut(value, 1);
            self
        }

        /// Remove one occurrence of a value from the multiset, taking ownership of it and
        /// returning a new multiset.
        ///
        /// Returns `None` if the value is not present.
        pub fn remove(mut self, value: &T) -> Option<MultiSet<T>> {
            let count = self.0.get_mut(value)?;
            if *count > 1 {
                *count -= 1;
            } else {
                // Last occurrence: drop the key so that every stored count stays > 0.
                self.0.remove(value);
            }
            self.1 -= 1;
            Some(self)
        }

        /// Add `n` occurrences of `value` in place.
        fn insert_multiple_mut(&mut self, value: T, n: usize) {
            // Adding zero occurrences must not create a key with a zero count.
            if n == 0 {
                return;
            }
            *self.0.entry(value).or_insert(0) += n;
            self.1 += n;
        }

        /// Compute the sum of two multisets.
        ///
        /// The count of each element in the result is the sum of its counts in both operands.
        pub fn sum(mut self, MultiSet(other_map, other_count): Self) -> Self {
            let target_count = self.1 + other_count;
            for (k, v) in other_map {
                self.insert_multiple_mut(k, v);
            }
            // Sanity check that the cached lengths agree with the merged counts.
            assert_eq!(self.1, target_count);
            self
        }
    }

    impl<T: Clone + Hash + Ord> FromIterator<T> for MultiSet<T> {
        /// Build a multiset containing one occurrence for every item yielded by `iter`.
        fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
            let mut multiset = MultiSet::new();
            for value in iter {
                multiset.insert_multiple_mut(value, 1);
            }
            multiset
        }
    }
}