egglog/
termdag.rs

1use crate::*;
2use std::fmt::Write;
3
/// Identifier of a [`Term`] stored in a [`TermDag`].
///
/// An id is the index of the term inside the DAG's node set, so it is only
/// meaningful together with the [`TermDag`] that produced it.
pub type TermId = usize;
5
#[allow(rustdoc::private_intra_doc_links)]
/// Like [`Expr`]s but with sharing and deduplication.
///
/// Terms refer to their children indirectly via opaque [TermId]s (internally
/// these are just `usize`s) that map into an ambient [`TermDag`].
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Term {
    /// A literal value.
    Lit(Literal),
    /// A variable, identified by its name.
    Var(String),
    /// An application of a head symbol to children, each child being
    /// referenced by its [`TermId`] in the ambient [`TermDag`].
    App(String, Vec<TermId>),
}
17
/// A hashconsing arena for [`Term`]s.
///
/// Every structurally distinct term is stored at most once; the position of a
/// term inside the arena is its [`TermId`].
#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct TermDag {
    /// A bidirectional map between deduplicated `Term`s and indices.
    ///
    /// `lookup` goes from a term to its index, `get` goes from an index back
    /// to the term.
    nodes: IndexSet<Term>,
}
24
/// Matches on the head symbol and the children of a [`Term::App`].
///
/// `$e` is an expression evaluating to a `Term`. `$body` is a single
/// brace-delimited group of match arms; each arm's pattern is matched against
/// the tuple `(head.as_str(), args.as_slice())`, i.e. the head symbol as a
/// `&str` and the children as a slice of ids.
///
/// Panics with `"not an app"` if `$e` is not a `Term::App`.
///
/// The expansion refers to `Term` unqualified, so `Term` must be in scope at
/// the call site.
///
/// ```ignore
/// match_term_app!(term; {
///     ("f", [_, x, _, _]) => { /* use `x` */ }
///     (head, _) => panic!("unexpected head {head}"),
/// })
/// ```
#[macro_export]
macro_rules! match_term_app {
    ($e:expr; $body:tt) => {
        match $e {
            Term::App(head, args) => {
                match (head.as_str(), args.as_slice())
                    $body
            }
            _ => panic!("not an app")
        }
    }
}
37
38impl TermDag {
39    /// Returns the number of nodes in this DAG.
40    pub fn size(&self) -> usize {
41        self.nodes.len()
42    }
43
44    /// Convert the given term to its id.
45    ///
46    /// Panics if the term does not already exist in this [TermDag].
47    pub fn lookup(&self, node: &Term) -> TermId {
48        self.nodes.get_index_of(node).unwrap()
49    }
50
51    /// Convert the given id to the corresponding term.
52    ///
53    /// Panics if the id is not valid.
54    pub fn get(&self, id: TermId) -> &Term {
55        self.nodes.get_index(id).unwrap()
56    }
57
58    /// Make and return a [`Term::App`] with the given head symbol and children,
59    /// and insert into the DAG if it is not already present.
60    ///
61    /// Panics if any of the children are not already in the DAG.
62    pub fn app(&mut self, sym: String, children: Vec<Term>) -> Term {
63        let node = Term::App(sym, children.iter().map(|c| self.lookup(c)).collect());
64
65        self.add_node(&node);
66
67        node
68    }
69
70    /// Make and return a [`Term::Lit`] with the given literal, and insert into
71    /// the DAG if it is not already present.
72    pub fn lit(&mut self, lit: Literal) -> Term {
73        let node = Term::Lit(lit);
74
75        self.add_node(&node);
76
77        node
78    }
79
80    /// Make and return a [`Term::Var`] with the given symbol, and insert into
81    /// the DAG if it is not already present.
82    pub fn var(&mut self, sym: String) -> Term {
83        let node = Term::Var(sym);
84
85        self.add_node(&node);
86
87        node
88    }
89
90    fn add_node(&mut self, node: &Term) {
91        if self.nodes.get(node).is_none() {
92            self.nodes.insert(node.clone());
93        }
94    }
95
96    /// Recursively converts the given expression to a term.
97    ///
98    /// This involves inserting every subexpression into this DAG. Because
99    /// TermDags are hashconsed, the resulting term is guaranteed to maximally
100    /// share subterms.
101    pub fn expr_to_term(&mut self, expr: &GenericExpr<String, String>) -> Term {
102        let res = match expr {
103            GenericExpr::Lit(_, lit) => Term::Lit(lit.clone()),
104            GenericExpr::Var(_, v) => Term::Var(v.to_owned()),
105            GenericExpr::Call(_, op, args) => {
106                let args = args
107                    .iter()
108                    .map(|a| {
109                        let term = self.expr_to_term(a);
110                        self.lookup(&term)
111                    })
112                    .collect();
113                Term::App(op.clone(), args)
114            }
115        };
116        self.add_node(&res);
117        res
118    }
119
120    /// Recursively converts the given term to an expression.
121    ///
122    /// Panics if the term contains subterms that are not in the DAG.
123    pub fn term_to_expr(&self, term: &Term, span: Span) -> Expr {
124        match term {
125            Term::Lit(lit) => Expr::Lit(span, lit.clone()),
126            Term::Var(v) => Expr::Var(span, v.clone()),
127            Term::App(op, args) => {
128                let args: Vec<_> = args
129                    .iter()
130                    .map(|a| self.term_to_expr(self.get(*a), span.clone()))
131                    .collect();
132                Expr::Call(span, op.clone(), args)
133            }
134        }
135    }
136
137    /// Converts the given term to a string.
138    ///
139    /// Panics if the term or any of its subterms are not in the DAG.
140    pub fn to_string(&self, term: &Term) -> String {
141        let mut result = String::new();
142        // subranges of the `result` string containing already stringified subterms
143        let mut ranges = HashMap::<TermId, (usize, usize)>::default();
144        let id = self.lookup(term);
145        // use a stack to avoid stack overflow
146
147        let mut stack = vec![(id, false, None)];
148        while let Some((id, space_before, mut start_index)) = stack.pop() {
149            if space_before {
150                result.push(' ');
151            }
152
153            if let Some((start, end)) = ranges.get(&id) {
154                result.extend_from_within(*start..*end);
155                continue;
156            }
157
158            match self.nodes[id].clone() {
159                Term::App(name, children) => {
160                    if start_index.is_some() {
161                        result.push(')');
162                    } else {
163                        stack.push((id, false, Some(result.len())));
164                        write!(&mut result, "({}", name).unwrap();
165                        for c in children.iter().rev() {
166                            stack.push((*c, true, None));
167                        }
168                    }
169                }
170                Term::Lit(lit) => {
171                    start_index = Some(result.len());
172                    write!(&mut result, "{lit}").unwrap();
173                }
174                Term::Var(v) => {
175                    start_index = Some(result.len());
176                    write!(&mut result, "{v}").unwrap();
177                }
178            }
179
180            if let Some(start_index) = start_index {
181                ranges.insert(id, (start_index, result.len()));
182            }
183        }
184
185        result
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ast::*, span};

    /// Parses `s` as an expression and converts it into a fresh [`TermDag`],
    /// returning the DAG together with the root term.
    fn parse_term(s: &str) -> (TermDag, Term) {
        let e = Parser::default().get_expr_from_string(None, s).unwrap();
        let mut td = TermDag::default();
        let t = td.expr_to_term(&e);
        (td, t)
    }

    /// Expression -> term -> expression roundtrip, also checking the exact
    /// node layout produced by hashconsing.
    #[test]
    fn test_to_from_expr() {
        let s = r#"(f (g x y) x y (g x y))"#;
        let e = Parser::default().get_expr_from_string(None, s).unwrap();
        let mut td = TermDag::default();
        assert_eq!(td.size(), 0);
        let t = td.expr_to_term(&e);
        assert_eq!(td.size(), 4);
        // the expression above has 4 distinct subterms.
        // in left-to-right, depth-first order, they are:
        //     x, y, (g x y), and the root call to f
        // so we can compute expected answer by hand:
        assert_eq!(
            td.nodes.as_slice().iter().cloned().collect::<Vec<_>>(),
            vec![
                Term::Var("x".into()),
                Term::Var("y".into()),
                Term::App("g".into(), vec![0, 1]),
                Term::App("f".into(), vec![2, 0, 1, 2]),
            ]
        );
        let e2 = td.term_to_expr(&t, span!());
        // This is tested using string equality because `e` and `e2` have
        // different annotations. A better way to test this would be to
        // implement a map_ann function for GenericExpr.
        assert_eq!(format!("{e}"), format!("{e2}")); // roundtrip
    }

    /// `match_term_app!` exposes the head symbol and the child ids of an
    /// application.
    #[test]
    fn test_match_term_app() {
        let s = r#"(f (g x y) x y (g x y))"#;
        let (td, t) = parse_term(s);
        match_term_app!(t; {
            ("f", [_, x, _, _]) => {
                let span = span!();
                assert_eq!(
                    td.term_to_expr(td.get(*x), span.clone()),
                    crate::ast::GenericExpr::Var(span, "x".to_owned())
                )
            }
            (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!())
        })
    }

    /// Stringifying the root term reproduces the parsed source text.
    #[test]
    fn test_to_string() {
        let s = r#"(f (g x y) x y (g x y))"#;
        let (td, t) = parse_term(s);
        assert_eq!(td.to_string(&t), s);
    }

    /// The root term is inserted last, so its id is the last index.
    #[test]
    fn test_lookup() {
        let s = r#"(f (g x y) x y (g x y))"#;
        let (td, t) = parse_term(s);
        assert_eq!(td.lookup(&t), td.size() - 1);
    }

    /// Building a term by hand with `var`/`lit`/`app` yields the same term as
    /// converting the parsed expression.
    #[test]
    fn test_app_var_lit() {
        let s = r#"(f (g x y) x 7 (g x y))"#;
        let (mut td, t) = parse_term(s);
        let x = td.var("x".into());
        let y = td.var("y".into());
        let seven = td.lit(7.into());
        let g = td.app("g".into(), vec![x.clone(), y.clone()]);
        let t2 = td.app("f".into(), vec![g.clone(), x, seven, g]);
        assert_eq!(t, t2);
    }
}
270}