egglog_ast/
generic_ast_helpers.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3
4use ordered_float::OrderedFloat;
5
6use super::util::ListDisplay;
7use crate::generic_ast::*;
8use crate::span::Span;
9
/// Generates a pair of `From` conversions between `Literal` and a payload type.
///
/// `impl_from!(Ctor(Ty))` produces:
/// - `From<Literal> for Ty`, which unwraps `Literal::Ctor` and panics with
///   `"Expected Ctor, got <literal>"` for any other variant;
/// - `From<Ty> for Literal`, which wraps the value in `Literal::Ctor`.
macro_rules! impl_from {
    ($ctor:ident($t:ty)) => {
        impl From<Literal> for $t {
            fn from(literal: Literal) -> Self {
                if let Literal::$ctor(inner) = literal {
                    inner
                } else {
                    // Converting the wrong variant is a caller bug, not a recoverable error.
                    panic!("Expected {}, got {literal}", stringify!($ctor))
                }
            }
        }

        impl From<$t> for Literal {
            fn from(value: $t) -> Self {
                Self::$ctor(value)
            }
        }
    };
}
30
/// Prefix string (`"@"`) used to mark internal symbol names.
///
/// NOTE(review): presumably prepended to generated names so they cannot clash
/// with user-written identifiers — confirm against the use sites.
pub const INTERNAL_SYMBOL_PREFIX: &str = "@";
32
33impl<Head: Display, Leaf: Display> Display for GenericRule<Head, Leaf>
34where
35    Head: Clone + Display,
36    Leaf: Clone + PartialEq + Eq + Display + Hash,
37{
38    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
39        let indent = " ".repeat(7);
40        write!(f, "(rule (")?;
41        for (i, fact) in self.body.iter().enumerate() {
42            if i > 0 {
43                write!(f, "{indent}")?;
44            }
45
46            if i != self.body.len() - 1 {
47                writeln!(f, "{fact}")?;
48            } else {
49                write!(f, "{fact}")?;
50            }
51        }
52        write!(f, ")\n      (")?;
53        for (i, action) in self.head.0.iter().enumerate() {
54            if i > 0 {
55                write!(f, "{indent}")?;
56            }
57            if i != self.head.0.len() - 1 {
58                writeln!(f, "{action}")?;
59            } else {
60                write!(f, "{action}")?;
61            }
62        }
63        let ruleset = if !self.ruleset.is_empty() {
64            format!(":ruleset {}", &self.ruleset)
65        } else {
66            "".into()
67        };
68        let name = if !self.name.is_empty() {
69            format!(":name \"{}\"", &self.name)
70        } else {
71            "".into()
72        };
73        write!(f, ")\n{indent} {ruleset} {name})")
74    }
75}
76
77// Use the macro for Int, Float, and String conversions
78impl_from!(Int(i64));
79impl_from!(Float(OrderedFloat<f64>));
80impl_from!(String(String));
81
82impl<Head: Display, Leaf: Display> Display for GenericFact<Head, Leaf> {
83    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
84        match self {
85            GenericFact::Eq(_, e1, e2) => write!(f, "(= {e1} {e2})"),
86            GenericFact::Fact(expr) => write!(f, "{expr}"),
87        }
88    }
89}
90
91// Implement Display for GenericAction
92impl<Head: Display, Leaf: Display> Display for GenericAction<Head, Leaf>
93where
94    Head: Clone + Display,
95    Leaf: Clone + PartialEq + Eq + Display + Hash,
96{
97    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
98        match self {
99            GenericAction::Let(_, lhs, rhs) => write!(f, "(let {lhs} {rhs})"),
100            GenericAction::Set(_, lhs, args, rhs) => {
101                if args.is_empty() {
102                    write!(f, "(set ({lhs}) {rhs})")
103                } else {
104                    write!(
105                        f,
106                        "(set ({} {}) {})",
107                        lhs,
108                        args.iter()
109                            .map(|a| format!("{a}"))
110                            .collect::<Vec<_>>()
111                            .join(" "),
112                        rhs
113                    )
114                }
115            }
116            GenericAction::Union(_, lhs, rhs) => write!(f, "(union {lhs} {rhs})"),
117            GenericAction::Change(_, change, lhs, args) => {
118                let change_str = match change {
119                    Change::Delete => "delete",
120                    Change::Subsume => "subsume",
121                };
122                write!(
123                    f,
124                    "({} ({} {}))",
125                    change_str,
126                    lhs,
127                    args.iter()
128                        .map(|a| format!("{a}"))
129                        .collect::<Vec<_>>()
130                        .join(" ")
131                )
132            }
133            GenericAction::Panic(_, msg) => write!(f, "(panic \"{msg}\")"),
134            GenericAction::Expr(_, e) => write!(f, "{e}"),
135        }
136    }
137}
138
139impl<Head, Leaf> Display for GenericExpr<Head, Leaf>
140where
141    Head: Display,
142    Leaf: Display,
143{
144    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
145        match self {
146            GenericExpr::Lit(_ann, lit) => write!(f, "{lit}"),
147            GenericExpr::Var(_ann, var) => write!(f, "{var}"),
148            GenericExpr::Call(_ann, op, children) => {
149                write!(f, "({} {})", op, ListDisplay(children, " "))
150            }
151        }
152    }
153}
154
155impl<Head, Leaf> Default for GenericActions<Head, Leaf>
156where
157    Head: Clone + Display,
158    Leaf: Clone + PartialEq + Eq + Display + Hash,
159{
160    fn default() -> Self {
161        Self(vec![])
162    }
163}
164
165impl<Head, Leaf> GenericRule<Head, Leaf>
166where
167    Head: Clone + Display,
168    Leaf: Clone + PartialEq + Eq + Display + Hash,
169{
170    /// Applies `f` to every expression that appears in the rule body or head.
171    pub fn visit_exprs(
172        self,
173        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
174    ) -> Self {
175        Self {
176            span: self.span,
177            head: self.head.visit_exprs(f),
178            body: self
179                .body
180                .into_iter()
181                .map(|bexpr| bexpr.visit_exprs(f))
182                .collect(),
183            name: self.name.clone(),
184            ruleset: self.ruleset.clone(),
185        }
186    }
187
188    /// Applies `f` to each action in the rule head, leaving the body unchanged.
189    pub fn visit_actions(
190        self,
191        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
192    ) -> Self {
193        Self {
194            span: self.span,
195            head: self.head.visit_actions(f),
196            body: self.body,
197            name: self.name,
198            ruleset: self.ruleset,
199        }
200    }
201
202    /// Applies the provided `head` and `leaf` mappings to every symbol that appears in the rule.
203    pub fn map_symbols<Head2, Leaf2>(
204        self,
205        head: &mut impl FnMut(Head) -> Head2,
206        leaf: &mut impl FnMut(Leaf) -> Leaf2,
207    ) -> GenericRule<Head2, Leaf2>
208    where
209        Head2: Clone + Display,
210        Leaf2: Clone + PartialEq + Eq + Display + Hash,
211    {
212        GenericRule {
213            span: self.span,
214            head: self.head.map_symbols(head, leaf),
215            body: self
216                .body
217                .into_iter()
218                .map(|fact| fact.map_symbols(head, leaf))
219                .collect(),
220            name: self.name,
221            ruleset: self.ruleset,
222        }
223    }
224
225    /// Converts the rule into its unresolved representation by formatting heads and leaves.
226    pub fn make_unresolved(self) -> GenericRule<String, String> {
227        let mut map_head = |h: Head| h.to_string();
228        let mut map_leaf = |l: Leaf| l.to_string();
229        self.map_symbols(&mut map_head, &mut map_leaf)
230    }
231}
232
233impl<Head, Leaf> GenericActions<Head, Leaf>
234where
235    Head: Clone + Display,
236    Leaf: Clone + PartialEq + Eq + Display + Hash,
237{
238    pub fn len(&self) -> usize {
239        self.0.len()
240    }
241
242    pub fn is_empty(&self) -> bool {
243        self.0.is_empty()
244    }
245
246    pub fn iter(&self) -> impl Iterator<Item = &GenericAction<Head, Leaf>> {
247        self.0.iter()
248    }
249
250    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
251        for action in &self.0 {
252            action.visit_vars(f);
253        }
254    }
255
256    /// Transforms every expression appearing in the action list using `f`.
257    pub fn visit_exprs(
258        self,
259        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
260    ) -> Self {
261        Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
262    }
263
264    /// Rewrites each action in the collection with the provided closure.
265    pub fn visit_actions(
266        self,
267        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
268    ) -> Self {
269        Self(self.0.into_iter().map(f).collect())
270    }
271
272    pub fn new(actions: Vec<GenericAction<Head, Leaf>>) -> Self {
273        Self(actions)
274    }
275
276    pub fn singleton(action: GenericAction<Head, Leaf>) -> Self {
277        Self(vec![action])
278    }
279
280    /// Applies the provided `head` and `leaf` mappings to each action.
281    pub fn map_symbols<Head2, Leaf2>(
282        self,
283        head: &mut impl FnMut(Head) -> Head2,
284        leaf: &mut impl FnMut(Leaf) -> Leaf2,
285    ) -> GenericActions<Head2, Leaf2>
286    where
287        Head2: Clone + Display,
288        Leaf2: Clone + PartialEq + Eq + Display + Hash,
289    {
290        GenericActions(
291            self.0
292                .into_iter()
293                .map(|action| action.map_symbols(head, leaf))
294                .collect(),
295        )
296    }
297
298    /// Converts the actions into their unresolved representation by formatting heads and leaves.
299    pub fn make_unresolved(self) -> GenericActions<String, String> {
300        let mut map_head = |h: Head| h.to_string();
301        let mut map_leaf = |l: Leaf| l.to_string();
302        self.map_symbols(&mut map_head, &mut map_leaf)
303    }
304}
305
306impl<Head, Leaf> GenericAction<Head, Leaf>
307where
308    Head: Clone + Display,
309    Leaf: Clone + Eq + Display + Hash,
310{
311    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
312        if let GenericAction::Let(span, lhs, _) = self {
313            f(span, lhs);
314        }
315        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
316            GenericExpr::Var(span, var) => {
317                f(&span, &var);
318                GenericExpr::Var(span, var)
319            }
320            other => other,
321        };
322        let _ = self.clone().visit_exprs(&mut visit);
323    }
324
325    // Applys `f` to all expressions in the action.
326    pub fn map_exprs(
327        &self,
328        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
329    ) -> Self {
330        match self {
331            GenericAction::Let(span, lhs, rhs) => {
332                GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
333            }
334            GenericAction::Set(span, lhs, args, rhs) => {
335                let right = f(rhs);
336                GenericAction::Set(
337                    span.clone(),
338                    lhs.clone(),
339                    args.iter().map(f).collect(),
340                    right,
341                )
342            }
343            GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
344                span.clone(),
345                *change,
346                lhs.clone(),
347                args.iter().map(f).collect(),
348            ),
349            GenericAction::Union(span, lhs, rhs) => {
350                GenericAction::Union(span.clone(), f(lhs), f(rhs))
351            }
352            GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
353            GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
354        }
355    }
356
357    /// Applys `f` to all sub-expressions (including `self`)
358    /// bottom-up, collecting the results.
359    pub fn visit_exprs(
360        self,
361        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
362    ) -> Self {
363        match self {
364            GenericAction::Let(span, lhs, rhs) => {
365                GenericAction::Let(span, lhs.clone(), rhs.visit_exprs(f))
366            }
367            // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
368            // This seems more natural to oflatt
369            // Currently, visit_exprs does not apply f to the first argument of Set.
370            GenericAction::Set(span, lhs, args, rhs) => {
371                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
372                GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
373            }
374            GenericAction::Change(span, change, lhs, args) => {
375                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
376                GenericAction::Change(span, change, lhs.clone(), args)
377            }
378            GenericAction::Union(span, lhs, rhs) => {
379                GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
380            }
381            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
382            GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
383        }
384    }
385
386    pub fn subst(&self, subst: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf>) -> Self {
387        self.map_exprs(&mut |e| e.subst_leaf(subst))
388    }
389
390    pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
391        macro_rules! fvar_expr {
392            () => {
393                |span, s: _| GenericExpr::Var(span.clone(), fvar(s.clone(), false))
394            };
395        }
396        match self {
397            GenericAction::Let(span, lhs, rhs) => {
398                let lhs = fvar(lhs, true);
399                let rhs = rhs.subst_leaf(&mut fvar_expr!());
400                GenericAction::Let(span, lhs, rhs)
401            }
402            GenericAction::Set(span, lhs, args, rhs) => {
403                let args = args
404                    .into_iter()
405                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
406                    .collect();
407                let rhs = rhs.subst_leaf(&mut fvar_expr!());
408                GenericAction::Set(span, lhs.clone(), args, rhs)
409            }
410            GenericAction::Change(span, change, lhs, args) => {
411                let args = args
412                    .into_iter()
413                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
414                    .collect();
415                GenericAction::Change(span, change, lhs.clone(), args)
416            }
417            GenericAction::Union(span, lhs, rhs) => {
418                let lhs = lhs.subst_leaf(&mut fvar_expr!());
419                let rhs = rhs.subst_leaf(&mut fvar_expr!());
420                GenericAction::Union(span, lhs, rhs)
421            }
422            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
423            GenericAction::Expr(span, e) => {
424                GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
425            }
426        }
427    }
428
429    /// Applies the provided `head` and `leaf` mappings to the action and all nested expressions.
430    pub fn map_symbols<Head2, Leaf2>(
431        self,
432        head: &mut impl FnMut(Head) -> Head2,
433        leaf: &mut impl FnMut(Leaf) -> Leaf2,
434    ) -> GenericAction<Head2, Leaf2>
435    where
436        Head2: Clone + Display,
437        Leaf2: Clone + Eq + Display + Hash,
438    {
439        match self {
440            GenericAction::Let(span, lhs, rhs) => {
441                GenericAction::Let(span, leaf(lhs), rhs.map_symbols(head, leaf))
442            }
443            GenericAction::Set(span, head_sym, args, rhs) => {
444                let mut mapped_args = Vec::with_capacity(args.len());
445                for arg in args {
446                    mapped_args.push(arg.map_symbols(head, leaf));
447                }
448                GenericAction::Set(
449                    span,
450                    head(head_sym),
451                    mapped_args,
452                    rhs.map_symbols(head, leaf),
453                )
454            }
455            GenericAction::Change(span, change, head_sym, args) => {
456                let mut mapped_args = Vec::with_capacity(args.len());
457                for arg in args {
458                    mapped_args.push(arg.map_symbols(head, leaf));
459                }
460                GenericAction::Change(span, change, head(head_sym), mapped_args)
461            }
462            GenericAction::Union(span, lhs, rhs) => GenericAction::Union(
463                span,
464                lhs.map_symbols(head, leaf),
465                rhs.map_symbols(head, leaf),
466            ),
467            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg),
468            GenericAction::Expr(span, expr) => {
469                GenericAction::Expr(span, expr.map_symbols(head, leaf))
470            }
471        }
472    }
473
474    /// Converts the action into its unresolved representation using String by
475    /// formatting heads and leaves.
476    pub fn make_unresolved(self) -> GenericAction<String, String> {
477        let mut map_head = |h: Head| h.to_string();
478        let mut map_leaf = |l: Leaf| l.to_string();
479        self.map_symbols(&mut map_head, &mut map_leaf)
480    }
481}
482
483impl<Head, Leaf> GenericFact<Head, Leaf>
484where
485    Head: Clone + Display,
486    Leaf: Clone + PartialEq + Eq + Display + Hash,
487{
488    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
489        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
490            GenericExpr::Var(span, var) => {
491                f(&span, &var);
492                GenericExpr::Var(span, var)
493            }
494            other => other,
495        };
496        let _ = self.clone().visit_exprs(&mut visit);
497    }
498
499    pub fn visit_exprs(
500        self,
501        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
502    ) -> GenericFact<Head, Leaf> {
503        match self {
504            GenericFact::Eq(span, e1, e2) => {
505                GenericFact::Eq(span, e1.visit_exprs(f), e2.visit_exprs(f))
506            }
507            GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
508        }
509    }
510
511    pub fn map_exprs<Head2, Leaf2>(
512        &self,
513        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head2, Leaf2>,
514    ) -> GenericFact<Head2, Leaf2> {
515        match self {
516            GenericFact::Eq(span, e1, e2) => GenericFact::Eq(span.clone(), f(e1), f(e2)),
517            GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
518        }
519    }
520
521    pub fn subst<Leaf2, Head2>(
522        &self,
523        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
524        subst_head: &mut impl FnMut(&Head) -> Head2,
525    ) -> GenericFact<Head2, Leaf2> {
526        self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
527    }
528}
529
530impl<Head, Leaf> GenericFact<Head, Leaf>
531where
532    Leaf: Clone + PartialEq + Eq + Display + Hash,
533    Head: Clone + Display,
534{
535    /// Applies the provided `head` and `leaf` mappings to the fact.
536    pub fn map_symbols<Head2, Leaf2>(
537        self,
538        head: &mut impl FnMut(Head) -> Head2,
539        leaf: &mut impl FnMut(Leaf) -> Leaf2,
540    ) -> GenericFact<Head2, Leaf2>
541    where
542        Head2: Clone + Display,
543        Leaf2: Clone + PartialEq + Eq + Display + Hash,
544    {
545        match self {
546            GenericFact::Eq(span, e1, e2) => {
547                GenericFact::Eq(span, e1.map_symbols(head, leaf), e2.map_symbols(head, leaf))
548            }
549            GenericFact::Fact(expr) => GenericFact::Fact(expr.map_symbols(head, leaf)),
550        }
551    }
552
553    /// Converts all heads and leaves to strings.
554    pub fn make_unresolved(self) -> GenericFact<String, String> {
555        let mut map_head = |h: Head| h.to_string();
556        let mut map_leaf = |l: Leaf| l.to_string();
557        self.map_symbols(&mut map_head, &mut map_leaf)
558    }
559}
560
561impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq> GenericExpr<Head, Leaf> {
562    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
563        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
564            GenericExpr::Var(span, var) => {
565                f(&span, &var);
566                GenericExpr::Var(span, var)
567            }
568            other => other,
569        };
570        let _ = self.clone().visit_exprs(&mut visit);
571    }
572
573    pub fn span(&self) -> Span {
574        match self {
575            GenericExpr::Lit(span, _) => span.clone(),
576            GenericExpr::Var(span, _) => span.clone(),
577            GenericExpr::Call(span, _, _) => span.clone(),
578        }
579    }
580
581    pub fn is_var(&self) -> bool {
582        matches!(self, GenericExpr::Var(_, _))
583    }
584
585    pub fn get_var(&self) -> Option<Leaf> {
586        match self {
587            GenericExpr::Var(_ann, v) => Some(v.clone()),
588            _ => None,
589        }
590    }
591
592    fn children(&self) -> &[Self] {
593        match self {
594            GenericExpr::Var(_, _) | GenericExpr::Lit(_, _) => &[],
595            GenericExpr::Call(_, _, children) => children,
596        }
597    }
598
599    pub fn ast_size(&self) -> usize {
600        let mut size = 0;
601        self.walk(&mut |_e| size += 1, &mut |_| {});
602        size
603    }
604
605    /// Traverse the expression tree, calling `pre` before visiting children
606    /// and `post` after visiting children. Visits all nodes in the tree.
607    pub fn walk(&self, pre: &mut impl FnMut(&Self), post: &mut impl FnMut(&Self)) {
608        pre(self);
609        self.children()
610            .iter()
611            .for_each(|child| child.walk(pre, post));
612        post(self);
613    }
614
615    /// Fold over the expression tree bottom-up, collecting results from children.
616    /// The function `f` is called on each node with the node itself and the results
617    /// from folding over its children. Results are computed from leaves to root.
618    pub fn fold<Out>(&self, f: &mut impl FnMut(&Self, Vec<Out>) -> Out) -> Out {
619        let ts = self.children().iter().map(|child| child.fold(f)).collect();
620        f(self, ts)
621    }
622
623    /// Search for the first node matching a predicate, returning early once found.
624    /// Traverses the tree in pre-order (top-down).
625    pub fn find<Out>(&self, f: &mut impl FnMut(&Self) -> Option<Out>) -> Option<Out> {
626        // Check current node first
627        if let Some(result) = f(self) {
628            return Some(result);
629        }
630
631        // Then check children
632        for child in self.children().iter() {
633            if let Some(result) = child.find(f) {
634                return Some(result);
635            }
636        }
637
638        None
639    }
640
641    /// Applys `f` to all sub-expressions (including `self`)
642    /// bottom-up, collecting the results.
643    pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self {
644        match self {
645            GenericExpr::Lit(..) => f(self),
646            GenericExpr::Var(..) => f(self),
647            GenericExpr::Call(span, op, children) => {
648                let children = children.into_iter().map(|c| c.visit_exprs(f)).collect();
649                f(GenericExpr::Call(span, op.clone(), children))
650            }
651        }
652    }
653
654    /// `subst` replaces occurrences of variables and head symbols in the expression.
655    pub fn subst<Head2, Leaf2>(
656        &self,
657        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
658        subst_head: &mut impl FnMut(&Head) -> Head2,
659    ) -> GenericExpr<Head2, Leaf2> {
660        match self {
661            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span.clone(), lit.clone()),
662            GenericExpr::Var(span, v) => subst_leaf(span, v),
663            GenericExpr::Call(span, op, children) => {
664                let children = children
665                    .iter()
666                    .map(|c| c.subst(subst_leaf, subst_head))
667                    .collect();
668                GenericExpr::Call(span.clone(), subst_head(op), children)
669            }
670        }
671    }
672
673    pub fn subst_leaf<Leaf2>(
674        &self,
675        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf2>,
676    ) -> GenericExpr<Head, Leaf2> {
677        self.subst(subst_leaf, &mut |x| x.clone())
678    }
679
680    /// Applies the provided `head` and `leaf` mappings to every symbol within the expression.
681    pub fn map_symbols<Head2, Leaf2>(
682        self,
683        head: &mut impl FnMut(Head) -> Head2,
684        leaf: &mut impl FnMut(Leaf) -> Leaf2,
685    ) -> GenericExpr<Head2, Leaf2> {
686        match self {
687            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span, lit),
688            GenericExpr::Var(span, var) => GenericExpr::Var(span, leaf(var)),
689            GenericExpr::Call(span, op, children) => {
690                let mut mapped_children = Vec::with_capacity(children.len());
691                for child in children {
692                    mapped_children.push(child.map_symbols(head, leaf));
693                }
694                GenericExpr::Call(span, head(op), mapped_children)
695            }
696        }
697    }
698
699    /// Converts all heads and leaves to strings.
700    pub fn make_unresolved(self) -> GenericExpr<String, String> {
701        let mut map_head = |h: Head| h.to_string();
702        let mut map_leaf = |l: Leaf| l.to_string();
703        self.map_symbols(&mut map_head, &mut map_leaf)
704    }
705
706    pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
707        let iterator: Box<dyn Iterator<Item = Leaf>> = match self {
708            GenericExpr::Lit(_ann, _l) => Box::new(std::iter::empty()),
709            GenericExpr::Var(_ann, v) => Box::new(std::iter::once(v.clone())),
710            GenericExpr::Call(_ann, _head, exprs) => Box::new(exprs.iter().flat_map(|e| e.vars())),
711        };
712        iterator
713    }
714}
715
716impl Display for Literal {
717    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718        match &self {
719            Literal::Int(i) => Display::fmt(i, f),
720            Literal::Float(n) => {
721                // need to display with decimal if there is none
722                let str = n.to_string();
723                if let Ok(_num) = str.parse::<i64>() {
724                    write!(f, "{str}.0")
725                } else {
726                    write!(f, "{str}")
727                }
728            }
729            Literal::Bool(b) => Display::fmt(b, f),
730            Literal::String(s) => write!(f, "\"{s}\""),
731            Literal::Unit => write!(f, "()"),
732        }
733    }
734}