egglog/sort/
pair.rs

1use super::*;
2
/// Container value backing a pair: the two component values plus one flag per
/// component recording whether `rebuild_contents` should pass that component
/// through the rebuilder.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairContainer {
    /// When `true`, `rebuild_contents` replaces `first` with the rebuilder's
    /// result. The `pair` primitive sets this when the first component's sort
    /// is an eq sort or an eq container sort.
    do_rebuild_first: bool,
    /// Same as `do_rebuild_first`, but for `second`.
    do_rebuild_second: bool,
    /// First component of the pair.
    pub first: Value,
    /// Second component of the pair.
    pub second: Value,
}
10
11impl ContainerValue for PairContainer {
12    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
13        let mut changed = false;
14        if self.do_rebuild_first {
15            let new = rebuilder.rebuild_val(self.first);
16            changed |= self.first != new;
17            self.first = new;
18        }
19        if self.do_rebuild_second {
20            let new = rebuilder.rebuild_val(self.second);
21            changed |= self.second != new;
22            self.second = new;
23        }
24        changed
25    }
26    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
27        [self.first, self.second].into_iter()
28    }
29}
30
/// A pair of two values supporting these primitives:
/// - `pair`
/// - `pair-first`
/// - `pair-second`
#[derive(Clone, Debug)]
pub struct PairSort {
    /// Name of this sort, as returned by `name()`.
    name: String,
    /// Sort of the first component.
    first: ArcSort,
    /// Sort of the second component.
    second: ArcSort,
}
41
42impl PairSort {
43    pub fn first(&self) -> ArcSort {
44        self.first.clone()
45    }
46
47    pub fn second(&self) -> ArcSort {
48        self.second.clone()
49    }
50}
51
52impl Presort for PairSort {
53    fn presort_name() -> &'static str {
54        "Pair"
55    }
56
57    fn reserved_primitives() -> Vec<&'static str> {
58        vec!["pair", "pair-first", "pair-second"]
59    }
60
61    fn make_sort(
62        typeinfo: &mut TypeInfo,
63        name: String,
64        args: &[Expr],
65    ) -> Result<ArcSort, TypeError> {
66        if let [Expr::Var(a_span, a), Expr::Var(b_span, b)] = args {
67            let a = typeinfo
68                .get_sort_by_name(a)
69                .ok_or(TypeError::UndefinedSort(a.clone(), a_span.clone()))?;
70            let b = typeinfo
71                .get_sort_by_name(b)
72                .ok_or(TypeError::UndefinedSort(b.clone(), b_span.clone()))?;
73
74            let out = Self {
75                name,
76                first: a.clone(),
77                second: b.clone(),
78            };
79            Ok(out.to_arcsort())
80        } else {
81            panic!("Pair sort requires exactly two arguments")
82        }
83    }
84}
85
86impl ContainerSort for PairSort {
87    type Container = PairContainer;
88
89    fn name(&self) -> &str {
90        &self.name
91    }
92
93    fn inner_sorts(&self) -> Vec<ArcSort> {
94        vec![self.first.clone(), self.second.clone()]
95    }
96
97    fn is_eq_container_sort(&self) -> bool {
98        self.first.is_eq_sort()
99            || self.second.is_eq_sort()
100            || self.first.is_eq_container_sort()
101            || self.second.is_eq_container_sort()
102    }
103
104    fn inner_values(
105        &self,
106        container_values: &ContainerValues,
107        value: Value,
108    ) -> Vec<(ArcSort, Value)> {
109        let val = container_values
110            .get_val::<PairContainer>(value)
111            .unwrap()
112            .clone();
113        vec![
114            (self.first.clone(), val.first),
115            (self.second.clone(), val.second),
116        ]
117    }
118
119    fn register_primitives(&self, eg: &mut EGraph) {
120        let arc = self.clone().to_arcsort();
121
122        add_primitive!(eg, "pair" = {self.clone(): PairSort} |x: # (self.first()), y: # (self.second())| -> @PairContainer (arc) {
123            PairContainer {
124                do_rebuild_first: self.ctx.first.is_eq_sort() || self.ctx.first.is_eq_container_sort(),
125                do_rebuild_second: self.ctx.second.is_eq_sort() || self.ctx.second.is_eq_container_sort(),
126                first: x,
127                second: y,
128            }
129        });
130
131        add_primitive!(eg, "pair-first"  = |xs: @PairContainer (arc)| -> # (self.first())  { xs.first  });
132        add_primitive!(eg, "pair-second" = |xs: @PairContainer (arc)| -> # (self.second()) { xs.second });
133    }
134
135    fn reconstruct_termdag(
136        &self,
137        _container_values: &ContainerValues,
138        _value: Value,
139        termdag: &mut TermDag,
140        element_terms: Vec<TermId>,
141    ) -> TermId {
142        assert_eq!(element_terms.len(), 2);
143        termdag.app("pair".into(), vec![element_terms[0], element_terms[1]])
144    }
145
146    fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
147        self.name().to_owned()
148    }
149}