egglog/sort/
set.rs

1use super::*;
2use std::collections::BTreeSet;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash)]
5pub struct SetContainer {
6    pub do_rebuild: bool,
7    pub data: BTreeSet<Value>,
8}
9
10impl ContainerValue for SetContainer {
11    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
12        if self.do_rebuild {
13            let mut xs: Vec<_> = self.data.iter().copied().collect();
14            let changed = rebuilder.rebuild_slice(&mut xs);
15            self.data = xs.into_iter().collect();
16            changed
17        } else {
18            false
19        }
20    }
21    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
22        self.data.iter().copied()
23    }
24}
25
26#[derive(Clone, Debug)]
27pub struct SetSort {
28    name: String,
29    element: ArcSort,
30}
31
32impl SetSort {
33    pub fn element(&self) -> ArcSort {
34        self.element.clone()
35    }
36}
37
38impl Presort for SetSort {
39    fn presort_name() -> &'static str {
40        "Set"
41    }
42
43    fn reserved_primitives() -> Vec<&'static str> {
44        vec![
45            "set-of",
46            "set-empty",
47            "set-insert",
48            "set-not-contains",
49            "set-contains",
50            "set-remove",
51            "set-union",
52            "set-diff",
53            "set-intersect",
54            "set-get",
55            "set-length",
56        ]
57    }
58
59    fn make_sort(
60        typeinfo: &mut TypeInfo,
61        name: String,
62        args: &[Expr],
63    ) -> Result<ArcSort, TypeError> {
64        if let [Expr::Var(span, e)] = args {
65            let e = typeinfo
66                .get_sort_by_name(e)
67                .ok_or(TypeError::UndefinedSort(e.clone(), span.clone()))?;
68
69            if e.is_eq_container_sort() {
70                return Err(TypeError::DisallowedSort(
71                    name,
72                    "Sets nested with other EqSort containers are not allowed".into(),
73                    span.clone(),
74                ));
75            }
76
77            let out = Self {
78                name,
79                element: e.clone(),
80            };
81            Ok(out.to_arcsort())
82        } else {
83            panic!()
84        }
85    }
86}
87
88impl ContainerSort for SetSort {
89    type Container = SetContainer;
90
91    fn name(&self) -> &str {
92        &self.name
93    }
94
95    fn inner_sorts(&self) -> Vec<ArcSort> {
96        vec![self.element.clone()]
97    }
98
99    fn is_eq_container_sort(&self) -> bool {
100        self.element.is_eq_sort()
101    }
102
103    fn inner_values(
104        &self,
105        container_values: &ContainerValues,
106        value: Value,
107    ) -> Vec<(ArcSort, Value)> {
108        let val = container_values
109            .get_val::<SetContainer>(value)
110            .unwrap()
111            .clone();
112        val.data
113            .iter()
114            .map(|e| (self.element.clone(), *e))
115            .collect()
116    }
117
118    fn register_primitives(&self, eg: &mut EGraph) {
119        let arc = self.clone().to_arcsort();
120
121        add_primitive!(eg, "set-empty" = {self.clone(): SetSort} |                      | -> @SetContainer (arc) { SetContainer {
122            do_rebuild: self.ctx.element.is_eq_sort(),
123            data: BTreeSet::new()
124        } });
125        add_primitive!(eg, "set-of"    = {self.clone(): SetSort} [xs: # (self.element())] -> @SetContainer (arc) { SetContainer {
126            do_rebuild: self.ctx.element.is_eq_sort(),
127            data: xs.collect()
128        } });
129
130        add_primitive!(eg, "set-get" = |xs: @SetContainer (arc), i: i64| -?> # (self.element()) { xs.data.iter().nth(i as usize).copied() });
131        add_primitive!(eg, "set-insert" = |mut xs: @SetContainer (arc), x: # (self.element())| -> @SetContainer (arc) {{ xs.data.insert( x); xs }});
132        add_primitive!(eg, "set-remove" = |mut xs: @SetContainer (arc), x: # (self.element())| -> @SetContainer (arc) {{ xs.data.remove(&x); xs }});
133
134        add_primitive!(eg, "set-length"       = |xs: @SetContainer (arc)| -> i64 { xs.data.len() as i64 });
135        add_primitive!(eg, "set-contains"     = |xs: @SetContainer (arc), x: # (self.element())| -?> () { ( xs.data.contains(&x)).then_some(()) });
136        add_primitive!(eg, "set-not-contains" = |xs: @SetContainer (arc), x: # (self.element())| -?> () { (!xs.data.contains(&x)).then_some(()) });
137
138        add_primitive!(eg, "set-union"      = |mut xs: @SetContainer (arc), ys: @SetContainer (arc)| -> @SetContainer (arc) {{ xs.data.extend(ys.data);                  xs }});
139        add_primitive!(eg, "set-diff"       = |mut xs: @SetContainer (arc), ys: @SetContainer (arc)| -> @SetContainer (arc) {{ xs.data.retain(|k| !ys.data.contains(k)); xs }});
140        add_primitive!(eg, "set-intersect"  = |mut xs: @SetContainer (arc), ys: @SetContainer (arc)| -> @SetContainer (arc) {{ xs.data.retain(|k|  ys.data.contains(k)); xs }});
141    }
142
143    fn reconstruct_termdag(
144        &self,
145        _container_values: &ContainerValues,
146        _value: Value,
147        termdag: &mut TermDag,
148        element_terms: Vec<Term>,
149    ) -> Term {
150        termdag.app("set-of".into(), element_terms)
151    }
152
153    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
154        "set-of".to_owned()
155    }
156}