egglog_ast/
generic_ast_helpers.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3
4use ordered_float::OrderedFloat;
5
6use super::util::ListDisplay;
7use crate::generic_ast::*;
8use crate::span::Span;
9
// Generates both directions of conversion between `Literal` and the Rust type
// carried by one of its variants:
//   * `From<$t> for Literal` wraps a value in `Literal::$ctor`;
//   * `From<Literal> for $t` unwraps it again, panicking when the literal is a
//     different variant (the message names the expected variant and shows the
//     literal that was actually supplied).
macro_rules! impl_from {
    ($ctor:ident($t:ty)) => {
        impl From<$t> for Literal {
            fn from(value: $t) -> Self {
                Literal::$ctor(value)
            }
        }

        impl From<Literal> for $t {
            fn from(literal: Literal) -> Self {
                if let Literal::$ctor(inner) = literal {
                    inner
                } else {
                    panic!("Expected {}, got {literal}", stringify!($ctor))
                }
            }
        }
    };
}
30
/// The prefix string `"@"`. Judging by the name, it marks symbols generated
/// internally rather than written by the user.
/// NOTE(review): confirm this against the code that creates and checks such names.
pub const INTERNAL_SYMBOL_PREFIX: &str = "@";
32
33impl<Head: Display, Leaf: Display> Display for GenericRule<Head, Leaf>
34where
35    Head: Clone + Display,
36    Leaf: Clone + PartialEq + Eq + Display + Hash,
37{
38    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
39        let indent = " ".repeat(7);
40        write!(f, "(rule (")?;
41        for (i, fact) in self.body.iter().enumerate() {
42            if i > 0 {
43                write!(f, "{indent}")?;
44            }
45
46            if i != self.body.len() - 1 {
47                writeln!(f, "{fact}")?;
48            } else {
49                write!(f, "{fact}")?;
50            }
51        }
52        write!(f, ")\n      (")?;
53        for (i, action) in self.head.0.iter().enumerate() {
54            if i > 0 {
55                write!(f, "{indent}")?;
56            }
57            if i != self.head.0.len() - 1 {
58                writeln!(f, "{action}")?;
59            } else {
60                write!(f, "{action}")?;
61            }
62        }
63        let ruleset = if !self.ruleset.is_empty() {
64            format!(":ruleset {}", &self.ruleset)
65        } else {
66            "".into()
67        };
68        let name = if !self.name.is_empty() {
69            format!(":name \"{}\"", &self.name)
70        } else {
71            "".into()
72        };
73        let naive = if self.naive { " :naive" } else { "" };
74        write!(f, ")\n{indent} {ruleset} {name}{naive})")
75    }
76}
77
78// Use the macro for Int, Float, and String conversions
79impl_from!(Int(i64));
80impl_from!(Float(OrderedFloat<f64>));
81impl_from!(String(String));
82
83impl<Head: Display, Leaf: Display> Display for GenericFact<Head, Leaf> {
84    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
85        match self {
86            GenericFact::Eq(_, e1, e2) => write!(f, "(= {e1} {e2})"),
87            GenericFact::Fact(expr) => write!(f, "{expr}"),
88        }
89    }
90}
91
92// Implement Display for GenericAction
93impl<Head: Display, Leaf: Display> Display for GenericAction<Head, Leaf>
94where
95    Head: Clone + Display,
96    Leaf: Clone + PartialEq + Eq + Display + Hash,
97{
98    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
99        match self {
100            GenericAction::Let(_, lhs, rhs) => write!(f, "(let {lhs} {rhs})"),
101            GenericAction::Set(_, lhs, args, rhs) => {
102                if args.is_empty() {
103                    write!(f, "(set ({lhs}) {rhs})")
104                } else {
105                    write!(
106                        f,
107                        "(set ({} {}) {})",
108                        lhs,
109                        args.iter()
110                            .map(|a| format!("{a}"))
111                            .collect::<Vec<_>>()
112                            .join(" "),
113                        rhs
114                    )
115                }
116            }
117            GenericAction::Union(_, lhs, rhs) => write!(f, "(union {lhs} {rhs})"),
118            GenericAction::Change(_, change, lhs, args) => {
119                let change_str = match change {
120                    Change::Delete => "delete",
121                    Change::Subsume => "subsume",
122                };
123                write!(
124                    f,
125                    "({} ({} {}))",
126                    change_str,
127                    lhs,
128                    args.iter()
129                        .map(|a| format!("{a}"))
130                        .collect::<Vec<_>>()
131                        .join(" ")
132                )
133            }
134            GenericAction::Panic(_, msg) => write!(f, "(panic \"{msg}\")"),
135            GenericAction::Expr(_, e) => write!(f, "{e}"),
136        }
137    }
138}
139
140impl<Head, Leaf> Display for GenericExpr<Head, Leaf>
141where
142    Head: Display,
143    Leaf: Display,
144{
145    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
146        match self {
147            GenericExpr::Lit(_ann, lit) => write!(f, "{lit}"),
148            GenericExpr::Var(_ann, var) => write!(f, "{var}"),
149            GenericExpr::Call(_ann, op, children) => {
150                write!(f, "({} {})", op, ListDisplay(children, " "))
151            }
152        }
153    }
154}
155
156impl<Head, Leaf> Default for GenericActions<Head, Leaf>
157where
158    Head: Clone + Display,
159    Leaf: Clone + PartialEq + Eq + Display + Hash,
160{
161    fn default() -> Self {
162        Self(vec![])
163    }
164}
165
166impl<Head, Leaf> GenericRule<Head, Leaf>
167where
168    Head: Clone + Display,
169    Leaf: Clone + PartialEq + Eq + Display + Hash,
170{
171    /// Applies `f` to every expression that appears in the rule body or head.
172    pub fn visit_exprs(
173        self,
174        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
175    ) -> Self {
176        Self {
177            span: self.span,
178            head: self.head.visit_exprs(f),
179            body: self
180                .body
181                .into_iter()
182                .map(|bexpr| bexpr.visit_exprs(f))
183                .collect(),
184            name: self.name.clone(),
185            ruleset: self.ruleset.clone(),
186            naive: self.naive,
187        }
188    }
189
190    /// Applies `f` to each action in the rule head, leaving the body unchanged.
191    pub fn visit_actions(
192        self,
193        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
194    ) -> Self {
195        Self {
196            span: self.span,
197            head: self.head.visit_actions(f),
198            body: self.body,
199            name: self.name,
200            ruleset: self.ruleset,
201            naive: self.naive,
202        }
203    }
204
205    /// Applies the provided `head` and `leaf` mappings to every symbol that appears in the rule.
206    pub fn map_symbols<Head2, Leaf2>(
207        self,
208        head: &mut impl FnMut(Head) -> Head2,
209        leaf: &mut impl FnMut(Leaf) -> Leaf2,
210    ) -> GenericRule<Head2, Leaf2>
211    where
212        Head2: Clone + Display,
213        Leaf2: Clone + PartialEq + Eq + Display + Hash,
214    {
215        GenericRule {
216            span: self.span,
217            head: self.head.map_symbols(head, leaf),
218            body: self
219                .body
220                .into_iter()
221                .map(|fact| fact.map_symbols(head, leaf))
222                .collect(),
223            name: self.name,
224            ruleset: self.ruleset,
225            naive: self.naive,
226        }
227    }
228
229    /// Converts the rule into its unresolved representation by formatting heads and leaves.
230    pub fn make_unresolved(self) -> GenericRule<String, String> {
231        let mut map_head = |h: Head| h.to_string();
232        let mut map_leaf = |l: Leaf| l.to_string();
233        self.map_symbols(&mut map_head, &mut map_leaf)
234    }
235}
236
237impl<Head, Leaf> GenericActions<Head, Leaf>
238where
239    Head: Clone + Display,
240    Leaf: Clone + PartialEq + Eq + Display + Hash,
241{
242    pub fn len(&self) -> usize {
243        self.0.len()
244    }
245
246    pub fn is_empty(&self) -> bool {
247        self.0.is_empty()
248    }
249
250    pub fn iter(&self) -> impl Iterator<Item = &GenericAction<Head, Leaf>> {
251        self.0.iter()
252    }
253
254    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
255        for action in &self.0 {
256            action.visit_vars(f);
257        }
258    }
259
260    /// Transforms every expression appearing in the action list using `f`.
261    pub fn visit_exprs(
262        self,
263        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
264    ) -> Self {
265        Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
266    }
267
268    /// Rewrites each action in the collection with the provided closure.
269    pub fn visit_actions(
270        self,
271        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
272    ) -> Self {
273        Self(self.0.into_iter().map(f).collect())
274    }
275
276    pub fn new(actions: Vec<GenericAction<Head, Leaf>>) -> Self {
277        Self(actions)
278    }
279
280    pub fn singleton(action: GenericAction<Head, Leaf>) -> Self {
281        Self(vec![action])
282    }
283
284    /// Applies the provided `head` and `leaf` mappings to each action.
285    pub fn map_symbols<Head2, Leaf2>(
286        self,
287        head: &mut impl FnMut(Head) -> Head2,
288        leaf: &mut impl FnMut(Leaf) -> Leaf2,
289    ) -> GenericActions<Head2, Leaf2>
290    where
291        Head2: Clone + Display,
292        Leaf2: Clone + PartialEq + Eq + Display + Hash,
293    {
294        GenericActions(
295            self.0
296                .into_iter()
297                .map(|action| action.map_symbols(head, leaf))
298                .collect(),
299        )
300    }
301
302    /// Converts the actions into their unresolved representation by formatting heads and leaves.
303    pub fn make_unresolved(self) -> GenericActions<String, String> {
304        let mut map_head = |h: Head| h.to_string();
305        let mut map_leaf = |l: Leaf| l.to_string();
306        self.map_symbols(&mut map_head, &mut map_leaf)
307    }
308}
309
310impl<Head, Leaf> GenericAction<Head, Leaf>
311where
312    Head: Clone + Display,
313    Leaf: Clone + Eq + Display + Hash,
314{
315    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
316        if let GenericAction::Let(span, lhs, _) = self {
317            f(span, lhs);
318        }
319        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
320            GenericExpr::Var(span, var) => {
321                f(&span, &var);
322                GenericExpr::Var(span, var)
323            }
324            other => other,
325        };
326        let _ = self.clone().visit_exprs(&mut visit);
327    }
328
329    // Applys `f` to all expressions in the action.
330    pub fn map_exprs(
331        &self,
332        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
333    ) -> Self {
334        match self {
335            GenericAction::Let(span, lhs, rhs) => {
336                GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
337            }
338            GenericAction::Set(span, lhs, args, rhs) => {
339                let right = f(rhs);
340                GenericAction::Set(
341                    span.clone(),
342                    lhs.clone(),
343                    args.iter().map(f).collect(),
344                    right,
345                )
346            }
347            GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
348                span.clone(),
349                *change,
350                lhs.clone(),
351                args.iter().map(f).collect(),
352            ),
353            GenericAction::Union(span, lhs, rhs) => {
354                GenericAction::Union(span.clone(), f(lhs), f(rhs))
355            }
356            GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
357            GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
358        }
359    }
360
361    /// Applys `f` to all sub-expressions (including `self`)
362    /// bottom-up, collecting the results.
363    pub fn visit_exprs(
364        self,
365        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
366    ) -> Self {
367        match self {
368            GenericAction::Let(span, lhs, rhs) => {
369                GenericAction::Let(span, lhs.clone(), rhs.visit_exprs(f))
370            }
371            // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
372            // This seems more natural to oflatt
373            // Currently, visit_exprs does not apply f to the first argument of Set.
374            GenericAction::Set(span, lhs, args, rhs) => {
375                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
376                GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
377            }
378            GenericAction::Change(span, change, lhs, args) => {
379                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
380                GenericAction::Change(span, change, lhs.clone(), args)
381            }
382            GenericAction::Union(span, lhs, rhs) => {
383                GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
384            }
385            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
386            GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
387        }
388    }
389
390    pub fn subst(&self, subst: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf>) -> Self {
391        self.map_exprs(&mut |e| e.subst_leaf(subst))
392    }
393
394    pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
395        macro_rules! fvar_expr {
396            () => {
397                |span, s: _| GenericExpr::Var(span.clone(), fvar(s.clone(), false))
398            };
399        }
400        match self {
401            GenericAction::Let(span, lhs, rhs) => {
402                let lhs = fvar(lhs, true);
403                let rhs = rhs.subst_leaf(&mut fvar_expr!());
404                GenericAction::Let(span, lhs, rhs)
405            }
406            GenericAction::Set(span, lhs, args, rhs) => {
407                let args = args
408                    .into_iter()
409                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
410                    .collect();
411                let rhs = rhs.subst_leaf(&mut fvar_expr!());
412                GenericAction::Set(span, lhs.clone(), args, rhs)
413            }
414            GenericAction::Change(span, change, lhs, args) => {
415                let args = args
416                    .into_iter()
417                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
418                    .collect();
419                GenericAction::Change(span, change, lhs.clone(), args)
420            }
421            GenericAction::Union(span, lhs, rhs) => {
422                let lhs = lhs.subst_leaf(&mut fvar_expr!());
423                let rhs = rhs.subst_leaf(&mut fvar_expr!());
424                GenericAction::Union(span, lhs, rhs)
425            }
426            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
427            GenericAction::Expr(span, e) => {
428                GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
429            }
430        }
431    }
432
433    /// Applies the provided `head` and `leaf` mappings to the action and all nested expressions.
434    pub fn map_symbols<Head2, Leaf2>(
435        self,
436        head: &mut impl FnMut(Head) -> Head2,
437        leaf: &mut impl FnMut(Leaf) -> Leaf2,
438    ) -> GenericAction<Head2, Leaf2>
439    where
440        Head2: Clone + Display,
441        Leaf2: Clone + Eq + Display + Hash,
442    {
443        match self {
444            GenericAction::Let(span, lhs, rhs) => {
445                GenericAction::Let(span, leaf(lhs), rhs.map_symbols(head, leaf))
446            }
447            GenericAction::Set(span, head_sym, args, rhs) => {
448                let mut mapped_args = Vec::with_capacity(args.len());
449                for arg in args {
450                    mapped_args.push(arg.map_symbols(head, leaf));
451                }
452                GenericAction::Set(
453                    span,
454                    head(head_sym),
455                    mapped_args,
456                    rhs.map_symbols(head, leaf),
457                )
458            }
459            GenericAction::Change(span, change, head_sym, args) => {
460                let mut mapped_args = Vec::with_capacity(args.len());
461                for arg in args {
462                    mapped_args.push(arg.map_symbols(head, leaf));
463                }
464                GenericAction::Change(span, change, head(head_sym), mapped_args)
465            }
466            GenericAction::Union(span, lhs, rhs) => GenericAction::Union(
467                span,
468                lhs.map_symbols(head, leaf),
469                rhs.map_symbols(head, leaf),
470            ),
471            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg),
472            GenericAction::Expr(span, expr) => {
473                GenericAction::Expr(span, expr.map_symbols(head, leaf))
474            }
475        }
476    }
477
478    /// Converts the action into its unresolved representation using String by
479    /// formatting heads and leaves.
480    pub fn make_unresolved(self) -> GenericAction<String, String> {
481        let mut map_head = |h: Head| h.to_string();
482        let mut map_leaf = |l: Leaf| l.to_string();
483        self.map_symbols(&mut map_head, &mut map_leaf)
484    }
485}
486
487impl<Head, Leaf> GenericFact<Head, Leaf>
488where
489    Head: Clone + Display,
490    Leaf: Clone + PartialEq + Eq + Display + Hash,
491{
492    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
493        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
494            GenericExpr::Var(span, var) => {
495                f(&span, &var);
496                GenericExpr::Var(span, var)
497            }
498            other => other,
499        };
500        let _ = self.clone().visit_exprs(&mut visit);
501    }
502
503    pub fn visit_exprs(
504        self,
505        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
506    ) -> GenericFact<Head, Leaf> {
507        match self {
508            GenericFact::Eq(span, e1, e2) => {
509                GenericFact::Eq(span, e1.visit_exprs(f), e2.visit_exprs(f))
510            }
511            GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
512        }
513    }
514
515    pub fn map_exprs<Head2, Leaf2>(
516        &self,
517        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head2, Leaf2>,
518    ) -> GenericFact<Head2, Leaf2> {
519        match self {
520            GenericFact::Eq(span, e1, e2) => GenericFact::Eq(span.clone(), f(e1), f(e2)),
521            GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
522        }
523    }
524
525    pub fn subst<Leaf2, Head2>(
526        &self,
527        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
528        subst_head: &mut impl FnMut(&Head) -> Head2,
529    ) -> GenericFact<Head2, Leaf2> {
530        self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
531    }
532}
533
534impl<Head, Leaf> GenericFact<Head, Leaf>
535where
536    Leaf: Clone + PartialEq + Eq + Display + Hash,
537    Head: Clone + Display,
538{
539    /// Applies the provided `head` and `leaf` mappings to the fact.
540    pub fn map_symbols<Head2, Leaf2>(
541        self,
542        head: &mut impl FnMut(Head) -> Head2,
543        leaf: &mut impl FnMut(Leaf) -> Leaf2,
544    ) -> GenericFact<Head2, Leaf2>
545    where
546        Head2: Clone + Display,
547        Leaf2: Clone + PartialEq + Eq + Display + Hash,
548    {
549        match self {
550            GenericFact::Eq(span, e1, e2) => {
551                GenericFact::Eq(span, e1.map_symbols(head, leaf), e2.map_symbols(head, leaf))
552            }
553            GenericFact::Fact(expr) => GenericFact::Fact(expr.map_symbols(head, leaf)),
554        }
555    }
556
557    /// Converts all heads and leaves to strings.
558    pub fn make_unresolved(self) -> GenericFact<String, String> {
559        let mut map_head = |h: Head| h.to_string();
560        let mut map_leaf = |l: Leaf| l.to_string();
561        self.map_symbols(&mut map_head, &mut map_leaf)
562    }
563}
564
565impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq> GenericExpr<Head, Leaf> {
566    pub fn visit_vars(&self, f: &mut impl FnMut(&Span, &Leaf)) {
567        let mut visit = |expr: GenericExpr<Head, Leaf>| match expr {
568            GenericExpr::Var(span, var) => {
569                f(&span, &var);
570                GenericExpr::Var(span, var)
571            }
572            other => other,
573        };
574        let _ = self.clone().visit_exprs(&mut visit);
575    }
576
577    pub fn span(&self) -> Span {
578        match self {
579            GenericExpr::Lit(span, _) => span.clone(),
580            GenericExpr::Var(span, _) => span.clone(),
581            GenericExpr::Call(span, _, _) => span.clone(),
582        }
583    }
584
585    pub fn is_var(&self) -> bool {
586        matches!(self, GenericExpr::Var(_, _))
587    }
588
589    pub fn get_var(&self) -> Option<Leaf> {
590        match self {
591            GenericExpr::Var(_ann, v) => Some(v.clone()),
592            _ => None,
593        }
594    }
595
596    fn children(&self) -> &[Self] {
597        match self {
598            GenericExpr::Var(_, _) | GenericExpr::Lit(_, _) => &[],
599            GenericExpr::Call(_, _, children) => children,
600        }
601    }
602
603    pub fn ast_size(&self) -> usize {
604        let mut size = 0;
605        self.walk(&mut |_e| size += 1, &mut |_| {});
606        size
607    }
608
609    /// Traverse the expression tree, calling `pre` before visiting children
610    /// and `post` after visiting children. Visits all nodes in the tree.
611    pub fn walk(&self, pre: &mut impl FnMut(&Self), post: &mut impl FnMut(&Self)) {
612        pre(self);
613        self.children()
614            .iter()
615            .for_each(|child| child.walk(pre, post));
616        post(self);
617    }
618
619    /// Fold over the expression tree bottom-up, collecting results from children.
620    /// The function `f` is called on each node with the node itself and the results
621    /// from folding over its children. Results are computed from leaves to root.
622    pub fn fold<Out>(&self, f: &mut impl FnMut(&Self, Vec<Out>) -> Out) -> Out {
623        let ts = self.children().iter().map(|child| child.fold(f)).collect();
624        f(self, ts)
625    }
626
627    /// Search for the first node matching a predicate, returning early once found.
628    /// Traverses the tree in pre-order (top-down).
629    pub fn find<Out>(&self, f: &mut impl FnMut(&Self) -> Option<Out>) -> Option<Out> {
630        // Check current node first
631        if let Some(result) = f(self) {
632            return Some(result);
633        }
634
635        // Then check children
636        for child in self.children().iter() {
637            if let Some(result) = child.find(f) {
638                return Some(result);
639            }
640        }
641
642        None
643    }
644
645    /// Applys `f` to all sub-expressions (including `self`)
646    /// bottom-up, collecting the results.
647    pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self {
648        match self {
649            GenericExpr::Lit(..) => f(self),
650            GenericExpr::Var(..) => f(self),
651            GenericExpr::Call(span, op, children) => {
652                let children = children.into_iter().map(|c| c.visit_exprs(f)).collect();
653                f(GenericExpr::Call(span, op.clone(), children))
654            }
655        }
656    }
657
658    /// `subst` replaces occurrences of variables and head symbols in the expression.
659    pub fn subst<Head2, Leaf2>(
660        &self,
661        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
662        subst_head: &mut impl FnMut(&Head) -> Head2,
663    ) -> GenericExpr<Head2, Leaf2> {
664        match self {
665            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span.clone(), lit.clone()),
666            GenericExpr::Var(span, v) => subst_leaf(span, v),
667            GenericExpr::Call(span, op, children) => {
668                let children = children
669                    .iter()
670                    .map(|c| c.subst(subst_leaf, subst_head))
671                    .collect();
672                GenericExpr::Call(span.clone(), subst_head(op), children)
673            }
674        }
675    }
676
677    pub fn subst_leaf<Leaf2>(
678        &self,
679        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf2>,
680    ) -> GenericExpr<Head, Leaf2> {
681        self.subst(subst_leaf, &mut |x| x.clone())
682    }
683
684    /// Applies the provided `head` and `leaf` mappings to every symbol within the expression.
685    pub fn map_symbols<Head2, Leaf2>(
686        self,
687        head: &mut impl FnMut(Head) -> Head2,
688        leaf: &mut impl FnMut(Leaf) -> Leaf2,
689    ) -> GenericExpr<Head2, Leaf2> {
690        match self {
691            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span, lit),
692            GenericExpr::Var(span, var) => GenericExpr::Var(span, leaf(var)),
693            GenericExpr::Call(span, op, children) => {
694                let mut mapped_children = Vec::with_capacity(children.len());
695                for child in children {
696                    mapped_children.push(child.map_symbols(head, leaf));
697                }
698                GenericExpr::Call(span, head(op), mapped_children)
699            }
700        }
701    }
702
703    /// Converts all heads and leaves to strings.
704    pub fn make_unresolved(self) -> GenericExpr<String, String> {
705        let mut map_head = |h: Head| h.to_string();
706        let mut map_leaf = |l: Leaf| l.to_string();
707        self.map_symbols(&mut map_head, &mut map_leaf)
708    }
709
710    pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
711        let iterator: Box<dyn Iterator<Item = Leaf>> = match self {
712            GenericExpr::Lit(_ann, _l) => Box::new(std::iter::empty()),
713            GenericExpr::Var(_ann, v) => Box::new(std::iter::once(v.clone())),
714            GenericExpr::Call(_ann, _head, exprs) => Box::new(exprs.iter().flat_map(|e| e.vars())),
715        };
716        iterator
717    }
718}
719
720impl Display for Literal {
721    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
722        match &self {
723            Literal::Int(i) => Display::fmt(i, f),
724            Literal::Float(n) => {
725                // need to display with decimal if there is none
726                let str = n.to_string();
727                if let Ok(_num) = str.parse::<i64>() {
728                    write!(f, "{str}.0")
729                } else {
730                    write!(f, "{str}")
731                }
732            }
733            Literal::Bool(b) => Display::fmt(b, f),
734            Literal::String(s) => write!(f, "\"{s}\""),
735            Literal::Unit => write!(f, "()"),
736        }
737    }
738}