egglog_ast/
generic_ast_helpers.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3
4use ordered_float::OrderedFloat;
5
6use super::util::ListDisplay;
7use crate::generic_ast::*;
8use crate::span::Span;
9
// `impl_from!(Variant(Type))` generates both directions of conversion between
// `Literal` and the payload type of one of its variants:
//   * `From<Type> for Literal` wraps the value in `Literal::Variant`;
//   * `From<Literal> for Type` unwraps `Literal::Variant` and PANICS on any
//     other variant (with a message naming the expected and actual literal).
macro_rules! impl_from {
    ($ctor:ident($t:ty)) => {
        impl From<$t> for Literal {
            fn from(value: $t) -> Self {
                Literal::$ctor(value)
            }
        }

        impl From<Literal> for $t {
            fn from(literal: Literal) -> Self {
                if let Literal::$ctor(value) = literal {
                    value
                } else {
                    panic!("Expected {}, got {literal}", stringify!($ctor))
                }
            }
        }
    };
}

/// Prefix string (`"@"`) for internal symbol names.
///
/// NOTE(review): presumably prepended to generated names so they cannot clash
/// with user-written identifiers — confirm against callers.
pub const INTERNAL_SYMBOL_PREFIX: &str = "@";
32
33impl<Head: Display, Leaf: Display> Display for GenericRule<Head, Leaf>
34where
35    Head: Clone + Display,
36    Leaf: Clone + PartialEq + Eq + Display + Hash,
37{
38    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
39        let indent = " ".repeat(7);
40        write!(f, "(rule (")?;
41        for (i, fact) in self.body.iter().enumerate() {
42            if i > 0 {
43                write!(f, "{indent}")?;
44            }
45
46            if i != self.body.len() - 1 {
47                writeln!(f, "{fact}")?;
48            } else {
49                write!(f, "{fact}")?;
50            }
51        }
52        write!(f, ")\n      (")?;
53        for (i, action) in self.head.0.iter().enumerate() {
54            if i > 0 {
55                write!(f, "{indent}")?;
56            }
57            if i != self.head.0.len() - 1 {
58                writeln!(f, "{action}")?;
59            } else {
60                write!(f, "{action}")?;
61            }
62        }
63        let ruleset = if !self.ruleset.is_empty() {
64            format!(":ruleset {}", &self.ruleset)
65        } else {
66            "".into()
67        };
68        let name = if !self.name.is_empty() {
69            format!(":name \"{}\"", &self.name)
70        } else {
71            "".into()
72        };
73        let naive = if self.naive { " :naive" } else { "" };
74        let no_decomp = if self.no_decomp { " :no-decomp" } else { "" };
75        let include_subsumed = if self.include_subsumed {
76            " :internal-include-subsumed"
77        } else {
78            ""
79        };
80        write!(
81            f,
82            ")\n{indent} {ruleset} {name}{naive}{no_decomp}{include_subsumed})"
83        )
84    }
85}
86
87// Use the macro for Int, Float, and String conversions
88impl_from!(Int(i64));
89impl_from!(Float(OrderedFloat<f64>));
90impl_from!(String(String));
91
92impl<Head: Display, Leaf: Display> Display for GenericFact<Head, Leaf> {
93    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
94        match self {
95            GenericFact::Eq(_, e1, e2) => write!(f, "(= {e1} {e2})"),
96            GenericFact::Fact(expr) => write!(f, "{expr}"),
97        }
98    }
99}
100
101// Implement Display for GenericAction
102impl<Head: Display, Leaf: Display> Display for GenericAction<Head, Leaf>
103where
104    Head: Clone + Display,
105    Leaf: Clone + PartialEq + Eq + Display + Hash,
106{
107    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
108        match self {
109            GenericAction::Let(_, lhs, rhs) => write!(f, "(let {lhs} {rhs})"),
110            GenericAction::Set(_, lhs, args, rhs) => {
111                if args.is_empty() {
112                    write!(f, "(set ({lhs}) {rhs})")
113                } else {
114                    write!(
115                        f,
116                        "(set ({} {}) {})",
117                        lhs,
118                        args.iter()
119                            .map(|a| format!("{a}"))
120                            .collect::<Vec<_>>()
121                            .join(" "),
122                        rhs
123                    )
124                }
125            }
126            GenericAction::Union(_, lhs, rhs) => write!(f, "(union {lhs} {rhs})"),
127            GenericAction::Change(_, change, lhs, args) => {
128                let change_str = match change {
129                    Change::Delete => "delete",
130                    Change::Subsume => "subsume",
131                };
132                if args.is_empty() {
133                    write!(f, "({change_str} ({lhs}))")
134                } else {
135                    write!(
136                        f,
137                        "({} ({} {}))",
138                        change_str,
139                        lhs,
140                        args.iter()
141                            .map(|a| format!("{a}"))
142                            .collect::<Vec<_>>()
143                            .join(" ")
144                    )
145                }
146            }
147            GenericAction::Panic(_, msg) => write!(f, "(panic \"{msg}\")"),
148            GenericAction::Expr(_, e) => write!(f, "{e}"),
149        }
150    }
151}
152
153impl<Head, Leaf> Display for GenericExpr<Head, Leaf>
154where
155    Head: Display,
156    Leaf: Display,
157{
158    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
159        match self {
160            GenericExpr::Lit(_ann, lit) => write!(f, "{lit}"),
161            GenericExpr::Var(_ann, var) => write!(f, "{var}"),
162            GenericExpr::Call(_ann, op, children) => match children.is_empty() {
163                true => write!(f, "({op})"),
164                false => write!(f, "({} {})", op, ListDisplay(children, " ")),
165            },
166        }
167    }
168}
169
170impl<Head, Leaf> Default for GenericActions<Head, Leaf>
171where
172    Head: Clone + Display,
173    Leaf: Clone + PartialEq + Eq + Display + Hash,
174{
175    fn default() -> Self {
176        Self(vec![])
177    }
178}
179
180impl<Head, Leaf> GenericRule<Head, Leaf>
181where
182    Head: Clone + Display,
183    Leaf: Clone + PartialEq + Eq + Display + Hash,
184{
185    /// Applies `f` to every expression that appears in the rule body or head.
186    pub fn visit_exprs(
187        self,
188        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
189    ) -> Self {
190        Self {
191            span: self.span,
192            head: self.head.visit_exprs(f),
193            body: self
194                .body
195                .into_iter()
196                .map(|bexpr| bexpr.visit_exprs(f))
197                .collect(),
198            name: self.name.clone(),
199            ruleset: self.ruleset.clone(),
200            naive: self.naive,
201            no_decomp: self.no_decomp,
202            include_subsumed: self.include_subsumed,
203        }
204    }
205
206    /// Applies `f` to each action in the rule head, leaving the body unchanged.
207    pub fn visit_actions(
208        self,
209        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
210    ) -> Self {
211        Self {
212            span: self.span,
213            head: self.head.visit_actions(f),
214            body: self.body,
215            name: self.name,
216            ruleset: self.ruleset,
217            naive: self.naive,
218            no_decomp: self.no_decomp,
219            include_subsumed: self.include_subsumed,
220        }
221    }
222
223    /// Applies the provided `head` and `leaf` mappings to every symbol that appears in the rule.
224    pub fn map_symbols<Head2, Leaf2>(
225        self,
226        head: &mut impl FnMut(Head) -> Head2,
227        leaf: &mut impl FnMut(Leaf) -> Leaf2,
228    ) -> GenericRule<Head2, Leaf2>
229    where
230        Head2: Clone + Display,
231        Leaf2: Clone + PartialEq + Eq + Display + Hash,
232    {
233        GenericRule {
234            span: self.span,
235            head: self.head.map_symbols(head, leaf),
236            body: self
237                .body
238                .into_iter()
239                .map(|fact| fact.map_symbols(head, leaf))
240                .collect(),
241            name: self.name,
242            ruleset: self.ruleset,
243            naive: self.naive,
244            no_decomp: self.no_decomp,
245            include_subsumed: self.include_subsumed,
246        }
247    }
248
249    /// Converts the rule into its unresolved representation by formatting heads and leaves.
250    pub fn make_unresolved(self) -> GenericRule<String, String> {
251        let mut map_head = |h: Head| h.to_string();
252        let mut map_leaf = |l: Leaf| l.to_string();
253        self.map_symbols(&mut map_head, &mut map_leaf)
254    }
255}
256
257impl<Head, Leaf> GenericActions<Head, Leaf>
258where
259    Head: Clone + Display,
260    Leaf: Clone + PartialEq + Eq + Display + Hash,
261{
262    pub fn len(&self) -> usize {
263        self.0.len()
264    }
265
266    pub fn is_empty(&self) -> bool {
267        self.0.is_empty()
268    }
269
270    pub fn iter(&self) -> impl Iterator<Item = &GenericAction<Head, Leaf>> {
271        self.0.iter()
272    }
273
274    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
275        for action in &self.0 {
276            action.visit_vars(f);
277        }
278    }
279
280    /// Transforms every expression appearing in the action list using `f`.
281    pub fn visit_exprs(
282        self,
283        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
284    ) -> Self {
285        Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
286    }
287
288    /// Rewrites each action in the collection with the provided closure.
289    pub fn visit_actions(
290        self,
291        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
292    ) -> Self {
293        Self(self.0.into_iter().map(f).collect())
294    }
295
296    pub fn new(actions: Vec<GenericAction<Head, Leaf>>) -> Self {
297        Self(actions)
298    }
299
300    pub fn singleton(action: GenericAction<Head, Leaf>) -> Self {
301        Self(vec![action])
302    }
303
304    /// Applies the provided `head` and `leaf` mappings to each action.
305    pub fn map_symbols<Head2, Leaf2>(
306        self,
307        head: &mut impl FnMut(Head) -> Head2,
308        leaf: &mut impl FnMut(Leaf) -> Leaf2,
309    ) -> GenericActions<Head2, Leaf2>
310    where
311        Head2: Clone + Display,
312        Leaf2: Clone + PartialEq + Eq + Display + Hash,
313    {
314        GenericActions(
315            self.0
316                .into_iter()
317                .map(|action| action.map_symbols(head, leaf))
318                .collect(),
319        )
320    }
321
322    /// Converts the actions into their unresolved representation by formatting heads and leaves.
323    pub fn make_unresolved(self) -> GenericActions<String, String> {
324        let mut map_head = |h: Head| h.to_string();
325        let mut map_leaf = |l: Leaf| l.to_string();
326        self.map_symbols(&mut map_head, &mut map_leaf)
327    }
328}
329
330impl<Head, Leaf> GenericAction<Head, Leaf>
331where
332    Head: Clone + Display,
333    Leaf: Clone + Eq + Display + Hash,
334{
335    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
336        if let GenericAction::Let(span, lhs, _) = self {
337            f(span, lhs);
338        }
339        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
340            GenericExpr::Var(span, var) => {
341                f(&span, &var);
342                GenericExpr::Var(span, var)
343            }
344            other => other,
345        };
346        let _ = self.clone().visit_exprs(&mut visit);
347    }
348
349    // Applys `f` to all expressions in the action.
350    pub fn map_exprs(
351        &self,
352        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
353    ) -> Self {
354        match self {
355            GenericAction::Let(span, lhs, rhs) => {
356                GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
357            }
358            GenericAction::Set(span, lhs, args, rhs) => {
359                let right = f(rhs);
360                GenericAction::Set(
361                    span.clone(),
362                    lhs.clone(),
363                    args.iter().map(f).collect(),
364                    right,
365                )
366            }
367            GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
368                span.clone(),
369                *change,
370                lhs.clone(),
371                args.iter().map(f).collect(),
372            ),
373            GenericAction::Union(span, lhs, rhs) => {
374                GenericAction::Union(span.clone(), f(lhs), f(rhs))
375            }
376            GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
377            GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
378        }
379    }
380
381    /// Applys `f` to all sub-expressions (including `self`)
382    /// bottom-up, collecting the results.
383    pub fn visit_exprs(
384        self,
385        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
386    ) -> Self {
387        match self {
388            GenericAction::Let(span, lhs, rhs) => {
389                GenericAction::Let(span, lhs.clone(), rhs.visit_exprs(f))
390            }
391            // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
392            // This seems more natural to oflatt
393            // Currently, visit_exprs does not apply f to the first argument of Set.
394            GenericAction::Set(span, lhs, args, rhs) => {
395                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
396                GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
397            }
398            GenericAction::Change(span, change, lhs, args) => {
399                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
400                GenericAction::Change(span, change, lhs.clone(), args)
401            }
402            GenericAction::Union(span, lhs, rhs) => {
403                GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
404            }
405            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
406            GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
407        }
408    }
409
410    pub fn subst(&self, subst: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf>) -> Self {
411        self.map_exprs(&mut |e| e.subst_leaf(subst))
412    }
413
414    pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
415        macro_rules! fvar_expr {
416            () => {
417                |span, s: _| GenericExpr::Var(span.clone(), fvar(s.clone(), false))
418            };
419        }
420        match self {
421            GenericAction::Let(span, lhs, rhs) => {
422                let lhs = fvar(lhs, true);
423                let rhs = rhs.subst_leaf(&mut fvar_expr!());
424                GenericAction::Let(span, lhs, rhs)
425            }
426            GenericAction::Set(span, lhs, args, rhs) => {
427                let args = args
428                    .into_iter()
429                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
430                    .collect();
431                let rhs = rhs.subst_leaf(&mut fvar_expr!());
432                GenericAction::Set(span, lhs.clone(), args, rhs)
433            }
434            GenericAction::Change(span, change, lhs, args) => {
435                let args = args
436                    .into_iter()
437                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
438                    .collect();
439                GenericAction::Change(span, change, lhs.clone(), args)
440            }
441            GenericAction::Union(span, lhs, rhs) => {
442                let lhs = lhs.subst_leaf(&mut fvar_expr!());
443                let rhs = rhs.subst_leaf(&mut fvar_expr!());
444                GenericAction::Union(span, lhs, rhs)
445            }
446            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
447            GenericAction::Expr(span, e) => {
448                GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
449            }
450        }
451    }
452
453    /// Applies the provided `head` and `leaf` mappings to the action and all nested expressions.
454    pub fn map_symbols<Head2, Leaf2>(
455        self,
456        head: &mut impl FnMut(Head) -> Head2,
457        leaf: &mut impl FnMut(Leaf) -> Leaf2,
458    ) -> GenericAction<Head2, Leaf2>
459    where
460        Head2: Clone + Display,
461        Leaf2: Clone + Eq + Display + Hash,
462    {
463        match self {
464            GenericAction::Let(span, lhs, rhs) => {
465                GenericAction::Let(span, leaf(lhs), rhs.map_symbols(head, leaf))
466            }
467            GenericAction::Set(span, head_sym, args, rhs) => {
468                let mut mapped_args = Vec::with_capacity(args.len());
469                for arg in args {
470                    mapped_args.push(arg.map_symbols(head, leaf));
471                }
472                GenericAction::Set(
473                    span,
474                    head(head_sym),
475                    mapped_args,
476                    rhs.map_symbols(head, leaf),
477                )
478            }
479            GenericAction::Change(span, change, head_sym, args) => {
480                let mut mapped_args = Vec::with_capacity(args.len());
481                for arg in args {
482                    mapped_args.push(arg.map_symbols(head, leaf));
483                }
484                GenericAction::Change(span, change, head(head_sym), mapped_args)
485            }
486            GenericAction::Union(span, lhs, rhs) => GenericAction::Union(
487                span,
488                lhs.map_symbols(head, leaf),
489                rhs.map_symbols(head, leaf),
490            ),
491            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg),
492            GenericAction::Expr(span, expr) => {
493                GenericAction::Expr(span, expr.map_symbols(head, leaf))
494            }
495        }
496    }
497
498    /// Converts the action into its unresolved representation using String by
499    /// formatting heads and leaves.
500    pub fn make_unresolved(self) -> GenericAction<String, String> {
501        let mut map_head = |h: Head| h.to_string();
502        let mut map_leaf = |l: Leaf| l.to_string();
503        self.map_symbols(&mut map_head, &mut map_leaf)
504    }
505}
506
507impl<Head, Leaf> GenericFact<Head, Leaf>
508where
509    Head: Clone + Display,
510    Leaf: Clone + PartialEq + Eq + Display + Hash,
511{
512    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
513        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
514            GenericExpr::Var(span, var) => {
515                f(&span, &var);
516                GenericExpr::Var(span, var)
517            }
518            other => other,
519        };
520        let _ = self.clone().visit_exprs(&mut visit);
521    }
522
523    pub fn visit_exprs(
524        self,
525        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
526    ) -> GenericFact<Head, Leaf> {
527        match self {
528            GenericFact::Eq(span, e1, e2) => {
529                GenericFact::Eq(span, e1.visit_exprs(f), e2.visit_exprs(f))
530            }
531            GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
532        }
533    }
534
535    pub fn map_exprs<Head2, Leaf2>(
536        &self,
537        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head2, Leaf2>,
538    ) -> GenericFact<Head2, Leaf2> {
539        match self {
540            GenericFact::Eq(span, e1, e2) => GenericFact::Eq(span.clone(), f(e1), f(e2)),
541            GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
542        }
543    }
544
545    pub fn subst<Leaf2, Head2>(
546        &self,
547        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
548        subst_head: &mut impl FnMut(&Head) -> Head2,
549    ) -> GenericFact<Head2, Leaf2> {
550        self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
551    }
552}
553
554impl<Head, Leaf> GenericFact<Head, Leaf>
555where
556    Leaf: Clone + PartialEq + Eq + Display + Hash,
557    Head: Clone + Display,
558{
559    /// Applies the provided `head` and `leaf` mappings to the fact.
560    pub fn map_symbols<Head2, Leaf2>(
561        self,
562        head: &mut impl FnMut(Head) -> Head2,
563        leaf: &mut impl FnMut(Leaf) -> Leaf2,
564    ) -> GenericFact<Head2, Leaf2>
565    where
566        Head2: Clone + Display,
567        Leaf2: Clone + PartialEq + Eq + Display + Hash,
568    {
569        match self {
570            GenericFact::Eq(span, e1, e2) => {
571                GenericFact::Eq(span, e1.map_symbols(head, leaf), e2.map_symbols(head, leaf))
572            }
573            GenericFact::Fact(expr) => GenericFact::Fact(expr.map_symbols(head, leaf)),
574        }
575    }
576
577    /// Converts all heads and leaves to strings.
578    pub fn make_unresolved(self) -> GenericFact<String, String> {
579        let mut map_head = |h: Head| h.to_string();
580        let mut map_leaf = |l: Leaf| l.to_string();
581        self.map_symbols(&mut map_head, &mut map_leaf)
582    }
583}
584
585impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq> GenericExpr<Head, Leaf> {
586    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
587        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
588            GenericExpr::Var(span, var) => {
589                f(&span, &var);
590                GenericExpr::Var(span, var)
591            }
592            other => other,
593        };
594        let _ = self.clone().visit_exprs(&mut visit);
595    }
596
597    pub fn span(&self) -> Span {
598        match self {
599            GenericExpr::Lit(span, _) => span.clone(),
600            GenericExpr::Var(span, _) => span.clone(),
601            GenericExpr::Call(span, _, _) => span.clone(),
602        }
603    }
604
605    pub fn is_var(&self) -> bool {
606        matches!(self, GenericExpr::Var(_, _))
607    }
608
609    pub fn get_var(&self) -> Option<Leaf> {
610        match self {
611            GenericExpr::Var(_ann, v) => Some(v.clone()),
612            _ => None,
613        }
614    }
615
616    fn children(&self) -> &[Self] {
617        match self {
618            GenericExpr::Var(_, _) | GenericExpr::Lit(_, _) => &[],
619            GenericExpr::Call(_, _, children) => children,
620        }
621    }
622
623    pub fn ast_size(&self) -> usize {
624        let mut size = 0;
625        self.walk(&mut |_e| size += 1, &mut |_| {});
626        size
627    }
628
629    /// Traverse the expression tree, calling `pre` before visiting children
630    /// and `post` after visiting children. Visits all nodes in the tree.
631    pub fn walk(&self, pre: &mut impl FnMut(&Self), post: &mut impl FnMut(&Self)) {
632        pre(self);
633        self.children()
634            .iter()
635            .for_each(|child| child.walk(pre, post));
636        post(self);
637    }
638
639    /// Fold over the expression tree bottom-up, collecting results from children.
640    /// The function `f` is called on each node with the node itself and the results
641    /// from folding over its children. Results are computed from leaves to root.
642    pub fn fold<Out>(&self, f: &mut impl FnMut(&Self, Vec<Out>) -> Out) -> Out {
643        let ts = self.children().iter().map(|child| child.fold(f)).collect();
644        f(self, ts)
645    }
646
647    /// Search for the first node matching a predicate, returning early once found.
648    /// Traverses the tree in pre-order (top-down).
649    pub fn find<Out>(&self, f: &mut impl FnMut(&Self) -> Option<Out>) -> Option<Out> {
650        // Check current node first
651        if let Some(result) = f(self) {
652            return Some(result);
653        }
654
655        // Then check children
656        for child in self.children().iter() {
657            if let Some(result) = child.find(f) {
658                return Some(result);
659            }
660        }
661
662        None
663    }
664
665    /// Applys `f` to all sub-expressions (including `self`)
666    /// bottom-up, collecting the results.
667    pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self {
668        match self {
669            GenericExpr::Lit(..) => f(self),
670            GenericExpr::Var(..) => f(self),
671            GenericExpr::Call(span, op, children) => {
672                let children = children.into_iter().map(|c| c.visit_exprs(f)).collect();
673                f(GenericExpr::Call(span, op.clone(), children))
674            }
675        }
676    }
677
678    /// `subst` replaces occurrences of variables and head symbols in the expression.
679    pub fn subst<Head2, Leaf2>(
680        &self,
681        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
682        subst_head: &mut impl FnMut(&Head) -> Head2,
683    ) -> GenericExpr<Head2, Leaf2> {
684        match self {
685            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span.clone(), lit.clone()),
686            GenericExpr::Var(span, v) => subst_leaf(span, v),
687            GenericExpr::Call(span, op, children) => {
688                let children = children
689                    .iter()
690                    .map(|c| c.subst(subst_leaf, subst_head))
691                    .collect();
692                GenericExpr::Call(span.clone(), subst_head(op), children)
693            }
694        }
695    }
696
697    pub fn subst_leaf<Leaf2>(
698        &self,
699        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf2>,
700    ) -> GenericExpr<Head, Leaf2> {
701        self.subst(subst_leaf, &mut |x| x.clone())
702    }
703
704    /// Applies the provided `head` and `leaf` mappings to every symbol within the expression.
705    pub fn map_symbols<Head2, Leaf2>(
706        self,
707        head: &mut impl FnMut(Head) -> Head2,
708        leaf: &mut impl FnMut(Leaf) -> Leaf2,
709    ) -> GenericExpr<Head2, Leaf2> {
710        match self {
711            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span, lit),
712            GenericExpr::Var(span, var) => GenericExpr::Var(span, leaf(var)),
713            GenericExpr::Call(span, op, children) => {
714                let mut mapped_children = Vec::with_capacity(children.len());
715                for child in children {
716                    mapped_children.push(child.map_symbols(head, leaf));
717                }
718                GenericExpr::Call(span, head(op), mapped_children)
719            }
720        }
721    }
722
723    /// Converts all heads and leaves to strings.
724    pub fn make_unresolved(self) -> GenericExpr<String, String> {
725        let mut map_head = |h: Head| h.to_string();
726        let mut map_leaf = |l: Leaf| l.to_string();
727        self.map_symbols(&mut map_head, &mut map_leaf)
728    }
729
730    pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
731        let iterator: Box<dyn Iterator<Item = Leaf>> = match self {
732            GenericExpr::Lit(_ann, _l) => Box::new(std::iter::empty()),
733            GenericExpr::Var(_ann, v) => Box::new(std::iter::once(v.clone())),
734            GenericExpr::Call(_ann, _head, exprs) => Box::new(exprs.iter().flat_map(|e| e.vars())),
735        };
736        iterator
737    }
738}
739
740impl Display for Literal {
741    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
742        match &self {
743            Literal::Int(i) => Display::fmt(i, f),
744            Literal::Float(n) => {
745                // need to display with decimal if there is none
746                let str = n.to_string();
747                if let Ok(_num) = str.parse::<i64>() {
748                    write!(f, "{str}.0")
749                } else {
750                    write!(f, "{str}")
751                }
752            }
753            Literal::Bool(b) => Display::fmt(b, f),
754            Literal::String(s) => write!(f, "\"{s}\""),
755            Literal::Unit => write!(f, "()"),
756        }
757    }
758}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// A call with no children must print as `(foo)`, not `(foo )`.
    #[test]
    fn display_nullary_call_without_trailing_space() {
        let nullary: GenericExpr<String, String> =
            GenericExpr::Call(Span::Panic, String::from("foo"), Vec::new());

        assert_eq!(format!("{nullary}"), "(foo)");
    }

    /// `delete`/`subsume` of a nullary function must not emit a trailing space.
    #[test]
    fn display_nullary_change_without_trailing_space() {
        let cases = [
            (Change::Delete, "(delete (foo))"),
            (Change::Subsume, "(subsume (foo))"),
        ];
        for (change, expected) in cases {
            let action: GenericAction<String, String> =
                GenericAction::Change(Span::Panic, change, String::from("foo"), Vec::new());
            assert_eq!(format!("{action}"), expected);
        }
    }
}