egglog/sort/
vec.rs

1use super::*;
2
3#[derive(Clone, Debug, PartialEq, Eq, Hash)]
4pub struct VecContainer {
5    pub do_rebuild: bool,
6    pub data: Vec<Value>,
7}
8
9impl ContainerValue for VecContainer {
10    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
11        if self.do_rebuild {
12            rebuilder.rebuild_slice(&mut self.data)
13        } else {
14            false
15        }
16    }
17    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
18        self.data.iter().copied()
19    }
20}
21
22#[derive(Clone, Debug)]
23pub struct VecSort {
24    name: String,
25    element: ArcSort,
26}
27
28impl VecSort {
29    pub fn element(&self) -> ArcSort {
30        self.element.clone()
31    }
32}
33
34impl Presort for VecSort {
35    fn presort_name() -> &'static str {
36        "Vec"
37    }
38
39    fn reserved_primitives() -> Vec<&'static str> {
40        vec![
41            "vec-of",
42            "vec-append",
43            "vec-empty",
44            "vec-push",
45            "vec-pop",
46            "vec-not-contains",
47            "vec-contains",
48            "vec-length",
49            "vec-get",
50            "vec-set",
51            "vec-remove",
52        ]
53    }
54
55    fn make_sort(
56        typeinfo: &mut TypeInfo,
57        name: String,
58        args: &[Expr],
59    ) -> Result<ArcSort, TypeError> {
60        if let [Expr::Var(span, e)] = args {
61            let e = typeinfo
62                .get_sort_by_name(e)
63                .ok_or(TypeError::UndefinedSort(e.clone(), span.clone()))?;
64
65            if e.is_eq_container_sort() {
66                return Err(TypeError::DisallowedSort(
67                    name,
68                    "Vec nested with other EqSort containers are not allowed".into(),
69                    span.clone(),
70                ));
71            }
72
73            let out = Self {
74                name,
75                element: e.clone(),
76            };
77            Ok(out.to_arcsort())
78        } else {
79            panic!("Vec sort must have sort as argument. Got {:?}", args)
80        }
81    }
82}
83
84impl ContainerSort for VecSort {
85    type Container = VecContainer;
86
87    fn name(&self) -> &str {
88        &self.name
89    }
90
91    fn inner_sorts(&self) -> Vec<ArcSort> {
92        vec![self.element.clone()]
93    }
94
95    fn is_eq_container_sort(&self) -> bool {
96        self.element.is_eq_sort()
97    }
98
99    fn inner_values(
100        &self,
101        container_values: &ContainerValues,
102        value: Value,
103    ) -> Vec<(ArcSort, Value)> {
104        let val = container_values
105            .get_val::<VecContainer>(value)
106            .unwrap()
107            .clone();
108        val.data
109            .iter()
110            .map(|e| (self.element.clone(), *e))
111            .collect()
112    }
113
114    fn register_primitives(&self, eg: &mut EGraph) {
115        let arc = self.clone().to_arcsort();
116
117        add_primitive!(eg, "vec-empty"  = {self.clone(): VecSort} |                                | -> @VecContainer (arc) { VecContainer {
118            do_rebuild: self.ctx.element.is_eq_sort(),
119            data: Vec::new()
120        } });
121        add_primitive!(eg, "vec-of"     = {self.clone(): VecSort} [xs: # (self.element())          ] -> @VecContainer (arc) { VecContainer {
122            do_rebuild: self.ctx.element.is_eq_sort(),
123            data: xs                     .collect()
124        } });
125        add_primitive!(eg, "vec-append" = {self.clone(): VecSort} [xs: @VecContainer (arc)] -> @VecContainer (arc) { VecContainer {
126            do_rebuild: self.ctx.element.is_eq_sort(),
127            data: xs.flat_map(|x| x.data).collect()
128        } });
129
130        add_primitive!(eg, "vec-push" = |mut xs: @VecContainer (arc), x: # (self.element())| -> @VecContainer (arc) {{ xs.data.push(x); xs }});
131        add_primitive!(eg, "vec-pop"  = |mut xs: @VecContainer (arc)                       | -> @VecContainer (arc) {{ xs.data.pop();   xs }});
132
133        add_primitive!(eg, "vec-length"       = |xs: @VecContainer (arc)| -> i64 { xs.data.len() as i64 });
134        add_primitive!(eg, "vec-contains"     = |xs: @VecContainer (arc), x: # (self.element())| -?> () { ( xs.data.contains(&x)).then_some(()) });
135        add_primitive!(eg, "vec-not-contains" = |xs: @VecContainer (arc), x: # (self.element())| -?> () { (!xs.data.contains(&x)).then_some(()) });
136
137        add_primitive!(eg, "vec-get"    = |    xs: @VecContainer (arc), i: i64                       | -?> # (self.element()) { xs.data.get(i as usize).copied() });
138        add_primitive!(eg, "vec-set"    = |mut xs: @VecContainer (arc), i: i64, x: # (self.element())| -> @VecContainer (arc) {{ xs.data[i as usize] = x;    xs }});
139        add_primitive!(eg, "vec-remove" = |mut xs: @VecContainer (arc), i: i64                       | -> @VecContainer (arc) {{ xs.data.remove(i as usize); xs }});
140    }
141
142    fn reconstruct_termdag(
143        &self,
144        _container_values: &ContainerValues,
145        _value: Value,
146        termdag: &mut TermDag,
147        element_terms: Vec<Term>,
148    ) -> Term {
149        if element_terms.is_empty() {
150            termdag.app("vec-empty".into(), vec![])
151        } else {
152            termdag.app("vec-of".into(), element_terms)
153        }
154    }
155
156    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
157        "vec-of".to_owned()
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Terms extracted for `Vec` values must parse back into expressions equal
    /// to the values they came from, for both an empty and a non-empty vector.
    #[test]
    fn test_vec_make_expr() {
        let mut egraph = EGraph::default();

        let setup = r#"
            (sort IVec (Vec i64))
            (let v0 (vec-empty))
            (let v1 (vec-of 1 2 3 4))
            (extract v0)
            (extract v1)
            "#;
        let outputs = egraph.parse_and_run_program(None, setup).unwrap();
        let (empty_term, full_term) = (&outputs[0], &outputs[1]);

        // Feed the extracted terms back in and check they denote the same values.
        let checks = format!(
            r#"
                (check (= v0 {empty_term}))
                (check (= v1 {full_term}))
                "#
        );
        egraph.parse_and_run_program(None, &checks).unwrap();
    }
}
195}