egglog/ast/
mod.rs

1pub mod check_shadowing;
2pub mod desugar;
3mod expr;
4mod parse;
5pub mod proof_global_remover;
6pub mod remove_globals;
7
8use std::cmp::max;
9
10use crate::core::{
11    GenericAtom, GenericAtomTerm, GenericExprExt, HeadOrEq, Query, ResolvedCall, ResolvedCoreRule,
12};
13use crate::*;
14pub use egglog_ast::generic_ast::{
15    Change, GenericAction, GenericActions, GenericExpr, GenericFact, GenericRule, Literal,
16};
17pub use egglog_ast::span::{RustSpan, Span};
18use egglog_ast::util::ListDisplay;
19pub use expr::*;
20pub use parse::*;
21
22#[derive(Clone, Debug)]
23/// The egglog internal representation of already compiled rules
24pub(crate) enum Ruleset {
25    /// Represents a ruleset with a set of rules.
26    Rules(IndexMap<String, (ResolvedCoreRule, egglog_bridge::RuleId)>),
27    /// A combined ruleset may contain other rulesets.
28    Combined(Vec<String>),
29}
30
31pub type NCommand = GenericNCommand<String, String>;
32/// [`ResolvedNCommand`] is another specialization of [`GenericNCommand`], which
33/// adds the type information to heads and leaves of commands.
34/// [`TypeInfo::typecheck_command`] turns an [`NCommand`] into a [`ResolvedNCommand`].
35pub(crate) type ResolvedNCommand = GenericNCommand<ResolvedCall, ResolvedVar>;
36
37/// A [`NCommand`] is a desugared [`Command`], where syntactic sugars
38/// like [`Command::Datatype`] and [`Command::Rewrite`]
39/// are eliminated.
40/// Most of the heavy lifting in egglog is done over [`NCommand`]s.
41///
42/// [`GenericNCommand`] is a generalization of [`NCommand`], like how [`GenericCommand`]
43/// is a generalization of [`Command`], allowing annotations over `Head` and `Leaf`.
44///
45/// TODO: The name "NCommand" used to denote normalized command, but this
46/// meaning is obsolete. A future PR should rename this type to something
47/// like "DCommand".
48#[derive(Debug, Clone, Eq, PartialEq, Hash)]
49pub enum GenericNCommand<Head, Leaf>
50where
51    Head: Clone + Display,
52    Leaf: Clone + PartialEq + Eq + Display + Hash,
53{
54    Sort {
55        span: Span,
56        name: String,
57        presort_and_args: Option<(String, Vec<GenericExpr<String, String>>)>,
58        /// The name of the union-find function for this sort.
59        /// Used in term encoding to canonicalize values during extraction.
60        uf: Option<String>,
61        /// The name of the proof function for this sort.
62        /// Set by proof desugaring to record where proofs are stored for this sort.
63        proof_func: Option<String>,
64        /// Whether values of this sort can be unioned.
65        /// Defaults to true for user-defined sorts.
66        /// Set to false for relations and term tables that should not allow union.
67        unionable: bool,
68    },
69    Function(GenericFunctionDecl<Head, Leaf>),
70    AddRuleset(Span, String),
71    UnstableCombinedRuleset(Span, String, Vec<String>),
72    NormRule {
73        rule: GenericRule<Head, Leaf>,
74    },
75    CoreAction(GenericAction<Head, Leaf>),
76    Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
77    RunSchedule(GenericSchedule<Head, Leaf>),
78    PrintOverallStatistics(Span, Option<String>),
79    Check(Span, Vec<GenericFact<Head, Leaf>>),
80    PrintFunction(
81        Span,
82        String,
83        Option<usize>,
84        Option<String>,
85        PrintFunctionMode,
86    ),
87    ProveExists(Span, Head),
88    PrintSize(Span, Option<String>),
89    Output {
90        span: Span,
91        file: String,
92        exprs: Vec<GenericExpr<Head, Leaf>>,
93    },
94    Push(usize),
95    Pop(Span, usize),
96    Fail(Span, Box<GenericNCommand<Head, Leaf>>),
97    Input {
98        span: Span,
99        name: String,
100        file: String,
101    },
102    UserDefined(Span, String, Vec<Expr>),
103}
104
105impl<Head, Leaf> GenericNCommand<Head, Leaf>
106where
107    Head: Clone + Display,
108    Leaf: Clone + PartialEq + Eq + Display + Hash,
109{
110    pub fn to_command(&self) -> GenericCommand<Head, Leaf> {
111        match self {
112            GenericNCommand::Sort {
113                span,
114                name,
115                presort_and_args,
116                uf,
117                proof_func,
118                unionable,
119            } => GenericCommand::Sort {
120                span: span.clone(),
121                name: name.clone(),
122                presort_and_args: presort_and_args.clone(),
123                uf: uf.clone(),
124                proof_func: proof_func.clone(),
125                unionable: *unionable,
126            },
127            GenericNCommand::Function(f) => match f.subtype {
128                FunctionSubtype::Constructor => GenericCommand::Constructor {
129                    span: f.span.clone(),
130                    name: f.name.clone(),
131                    schema: f.schema.clone(),
132                    cost: f.cost,
133                    unextractable: f.unextractable,
134                    hidden: f.internal_hidden,
135                    let_binding: f.internal_let,
136                    term_constructor: f.term_constructor.clone(),
137                },
138                FunctionSubtype::Custom => GenericCommand::Function {
139                    span: f.span.clone(),
140                    schema: f.schema.clone(),
141                    name: f.name.clone(),
142                    merge: f.merge.clone(),
143                    hidden: f.internal_hidden,
144                    let_binding: f.internal_let,
145                    term_constructor: f.term_constructor.clone(),
146                    unextractable: f.unextractable,
147                },
148            },
149            GenericNCommand::AddRuleset(span, name) => {
150                GenericCommand::AddRuleset(span.clone(), name.clone())
151            }
152            GenericNCommand::UnstableCombinedRuleset(span, name, others) => {
153                GenericCommand::UnstableCombinedRuleset(span.clone(), name.clone(), others.clone())
154            }
155            GenericNCommand::NormRule { rule } => GenericCommand::Rule { rule: rule.clone() },
156            GenericNCommand::RunSchedule(schedule) => GenericCommand::RunSchedule(schedule.clone()),
157            GenericNCommand::PrintOverallStatistics(span, file) => {
158                GenericCommand::PrintOverallStatistics(span.clone(), file.clone())
159            }
160            GenericNCommand::CoreAction(action) => GenericCommand::Action(action.clone()),
161            GenericNCommand::Extract(span, expr, variants) => {
162                GenericCommand::Extract(span.clone(), expr.clone(), variants.clone())
163            }
164            GenericNCommand::Check(span, facts) => {
165                GenericCommand::Check(span.clone(), facts.clone())
166            }
167            GenericNCommand::PrintFunction(span, name, n, file, mode) => {
168                GenericCommand::PrintFunction(span.clone(), name.clone(), *n, file.clone(), *mode)
169            }
170            GenericNCommand::ProveExists(span, constructor) => {
171                GenericCommand::ProveExists(span.clone(), constructor.clone())
172            }
173            GenericNCommand::PrintSize(span, name) => {
174                GenericCommand::PrintSize(span.clone(), name.clone())
175            }
176            GenericNCommand::Output { span, file, exprs } => GenericCommand::Output {
177                span: span.clone(),
178                file: file.to_string(),
179                exprs: exprs.clone(),
180            },
181            GenericNCommand::Push(n) => GenericCommand::Push(*n),
182            GenericNCommand::Pop(span, n) => GenericCommand::Pop(span.clone(), *n),
183            GenericNCommand::Fail(span, cmd) => {
184                GenericCommand::Fail(span.clone(), Box::new(cmd.to_command()))
185            }
186            GenericNCommand::Input { span, name, file } => GenericCommand::Input {
187                span: span.clone(),
188                name: name.clone(),
189                file: file.clone(),
190            },
191            GenericNCommand::UserDefined(span, name, exprs) => {
192                GenericCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
193            }
194        }
195    }
196
197    /// Applies `f` to
198    pub fn visit_queries(
199        self,
200        f: &mut impl FnMut(Vec<GenericFact<Head, Leaf>>) -> Vec<GenericFact<Head, Leaf>>,
201    ) -> Self {
202        match self {
203            GenericNCommand::Check(span, query) => GenericNCommand::Check(span, f(query)),
204            GenericNCommand::NormRule { mut rule } => {
205                rule.body = f(rule.body);
206                GenericNCommand::NormRule { rule }
207            }
208            GenericNCommand::RunSchedule(schedule) => {
209                GenericNCommand::RunSchedule(schedule.visit_queries(f))
210            }
211            GenericNCommand::Fail(span, cmd) => {
212                GenericNCommand::Fail(span, Box::new(cmd.visit_queries(f)))
213            }
214            GenericNCommand::Sort { .. }
215            | GenericNCommand::Function(..)
216            | GenericNCommand::AddRuleset(..)
217            | GenericNCommand::UnstableCombinedRuleset(..)
218            | GenericNCommand::CoreAction(..)
219            | GenericNCommand::Extract(..)
220            | GenericNCommand::PrintOverallStatistics(..)
221            | GenericNCommand::PrintFunction(..)
222            | GenericNCommand::PrintSize(..)
223            | GenericNCommand::Output { .. }
224            | GenericNCommand::Push(..)
225            | GenericNCommand::Pop(..)
226            | GenericNCommand::Input { .. }
227            | GenericNCommand::UserDefined(..)
228            | GenericNCommand::ProveExists(..) => self,
229        }
230    }
231
232    /// Applies `f` to all expressions in the command, bottom-up.
233    pub fn visit_exprs(
234        self,
235        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
236    ) -> Self {
237        match self {
238            GenericNCommand::Sort {
239                span,
240                name,
241                presort_and_args,
242                uf,
243                proof_func,
244                unionable,
245            } => GenericNCommand::Sort {
246                span,
247                name,
248                presort_and_args,
249                uf,
250                proof_func,
251                unionable,
252            },
253            GenericNCommand::Function(func) => GenericNCommand::Function(func.visit_exprs(f)),
254            GenericNCommand::AddRuleset(span, name) => GenericNCommand::AddRuleset(span, name),
255            GenericNCommand::UnstableCombinedRuleset(span, name, rulesets) => {
256                GenericNCommand::UnstableCombinedRuleset(span, name, rulesets)
257            }
258            GenericNCommand::NormRule { rule } => GenericNCommand::NormRule {
259                rule: rule.visit_exprs(f),
260            },
261            GenericNCommand::RunSchedule(schedule) => {
262                GenericNCommand::RunSchedule(schedule.visit_exprs(f))
263            }
264            GenericNCommand::PrintOverallStatistics(span, file) => {
265                GenericNCommand::PrintOverallStatistics(span, file)
266            }
267            GenericNCommand::CoreAction(action) => {
268                GenericNCommand::CoreAction(action.visit_exprs(f))
269            }
270            GenericNCommand::Extract(span, expr, variants) => {
271                GenericNCommand::Extract(span, expr.visit_exprs(f), variants.visit_exprs(f))
272            }
273            GenericNCommand::Check(span, facts) => GenericNCommand::Check(
274                span,
275                facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
276            ),
277            GenericNCommand::PrintFunction(span, name, n, file, mode) => {
278                GenericNCommand::PrintFunction(span, name, n, file, mode)
279            }
280            GenericNCommand::ProveExists(span, constructor) => {
281                GenericNCommand::ProveExists(span, constructor)
282            }
283            GenericNCommand::PrintSize(span, name) => GenericNCommand::PrintSize(span, name),
284            GenericNCommand::Output { span, file, exprs } => GenericNCommand::Output {
285                span,
286                file,
287                exprs: exprs.into_iter().map(f).collect(),
288            },
289            GenericNCommand::Push(n) => GenericNCommand::Push(n),
290            GenericNCommand::Pop(span, n) => GenericNCommand::Pop(span, n),
291            GenericNCommand::Fail(span, cmd) => {
292                GenericNCommand::Fail(span, Box::new(cmd.visit_exprs(f)))
293            }
294            GenericNCommand::Input { span, name, file } => {
295                GenericNCommand::Input { span, name, file }
296            }
297            GenericNCommand::UserDefined(span, name, exprs) => {
298                // We can't map `f` over UserDefined because UserDefined always assumes plain `Expr`s
299                GenericNCommand::UserDefined(span, name, exprs)
300            }
301        }
302    }
303}
304
305impl<Head, Leaf> Display for GenericNCommand<Head, Leaf>
306where
307    Head: Clone + Display,
308    Leaf: Clone + PartialEq + Eq + Display + Hash,
309{
310    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
311        let command = self.to_command();
312        command.fmt(f)
313    }
314}
315
316pub type Schedule = GenericSchedule<String, String>;
317pub(crate) type ResolvedSchedule = GenericSchedule<ResolvedCall, ResolvedVar>;
318
319#[derive(Debug, Clone, PartialEq, Eq, Hash)]
320pub enum GenericSchedule<Head, Leaf> {
321    Saturate(Span, Box<GenericSchedule<Head, Leaf>>),
322    Repeat(Span, usize, Box<GenericSchedule<Head, Leaf>>),
323    Run(Span, GenericRunConfig<Head, Leaf>),
324    Sequence(Span, Vec<GenericSchedule<Head, Leaf>>),
325}
326
327impl<Head, Leaf> GenericSchedule<Head, Leaf>
328where
329    Head: Clone + Display,
330    Leaf: Clone + PartialEq + Eq + Display + Hash,
331{
332    /// Applies `f` to all the queries in the schedule.
333    pub fn visit_queries(
334        self,
335        f: &mut impl FnMut(Vec<GenericFact<Head, Leaf>>) -> Vec<GenericFact<Head, Leaf>>,
336    ) -> Self {
337        match self {
338            GenericSchedule::Saturate(span, generic_schedule) => {
339                GenericSchedule::Saturate(span, Box::new(generic_schedule.visit_queries(f)))
340            }
341            GenericSchedule::Repeat(span, iters, generic_schedule) => {
342                GenericSchedule::Repeat(span, iters, Box::new(generic_schedule.visit_queries(f)))
343            }
344            GenericSchedule::Run(span, run_config) => GenericSchedule::Run(
345                span,
346                GenericRunConfig {
347                    ruleset: run_config.ruleset,
348                    until: run_config.until.map(f),
349                },
350            ),
351            GenericSchedule::Sequence(span, generic_schedules) => GenericSchedule::Sequence(
352                span,
353                generic_schedules
354                    .into_iter()
355                    .map(|schedule| schedule.visit_queries(f))
356                    .collect(),
357            ),
358        }
359    }
360
361    /// Recursively flattens nested `Sequence` nodes into a single level.
362    /// For example, `(seq (seq a b) c)` becomes `(seq a b c)`.
363    /// Also unwraps single-element sequences into their inner schedule.
364    fn flatten_sequences(self) -> Self {
365        match self {
366            GenericSchedule::Saturate(span, sched) => {
367                GenericSchedule::Saturate(span, Box::new(sched.flatten_sequences()))
368            }
369            GenericSchedule::Repeat(span, size, sched) => {
370                GenericSchedule::Repeat(span, size, Box::new(sched.flatten_sequences()))
371            }
372            GenericSchedule::Run(span, config) => GenericSchedule::Run(span, config),
373            GenericSchedule::Sequence(span, scheds) => {
374                let mut flattened = Vec::new();
375                for sched in scheds.into_iter().map(Self::flatten_sequences) {
376                    match sched {
377                        GenericSchedule::Sequence(_, nested) => flattened.extend(nested),
378                        other => flattened.push(other),
379                    }
380                }
381
382                match flattened.len() {
383                    0 => GenericSchedule::Sequence(span, flattened),
384                    1 => flattened.into_iter().next().unwrap(),
385                    _ => GenericSchedule::Sequence(span, flattened),
386                }
387            }
388        }
389    }
390
391    fn visit_exprs(
392        self,
393        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
394    ) -> Self {
395        match self {
396            GenericSchedule::Saturate(span, sched) => {
397                GenericSchedule::Saturate(span, Box::new(sched.visit_exprs(f)))
398            }
399            GenericSchedule::Repeat(span, size, sched) => {
400                GenericSchedule::Repeat(span, size, Box::new(sched.visit_exprs(f)))
401            }
402            GenericSchedule::Run(span, config) => GenericSchedule::Run(span, config.visit_exprs(f)),
403            GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence(
404                span,
405                scheds.into_iter().map(|s| s.visit_exprs(f)).collect(),
406            ),
407        }
408    }
409
410    /// Remaps every head and leaf symbol in the schedule using the supplied closures.
411    pub fn map_symbols<Head2, Leaf2>(
412        self,
413        head: &mut impl FnMut(Head) -> Head2,
414        leaf: &mut impl FnMut(Leaf) -> Leaf2,
415    ) -> GenericSchedule<Head2, Leaf2>
416    where
417        Head2: Clone + Display,
418        Leaf2: Clone + PartialEq + Eq + Display + Hash,
419    {
420        match self {
421            GenericSchedule::Saturate(span, sched) => {
422                GenericSchedule::Saturate(span, Box::new(sched.map_symbols(head, leaf)))
423            }
424            GenericSchedule::Repeat(span, size, sched) => {
425                GenericSchedule::Repeat(span, size, Box::new(sched.map_symbols(head, leaf)))
426            }
427            GenericSchedule::Run(span, config) => {
428                GenericSchedule::Run(span, config.map_symbols(head, leaf))
429            }
430            GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence(
431                span,
432                scheds
433                    .into_iter()
434                    .map(|sched| sched.map_symbols(head, leaf))
435                    .collect(),
436            ),
437        }
438    }
439
440    /// Applies `fun` to every string-valued symbol contained in the schedule,
441    /// normalizes result with `flatten_sequences`.
442    pub fn map_string_symbols(
443        self,
444        fun: &mut impl FnMut(String) -> String,
445    ) -> GenericSchedule<Head, Leaf> {
446        let mapped = match self {
447            GenericSchedule::Saturate(span, sched) => {
448                GenericSchedule::Saturate(span, Box::new(sched.map_string_symbols(fun)))
449            }
450            GenericSchedule::Repeat(span, size, sched) => {
451                GenericSchedule::Repeat(span, size, Box::new(sched.map_string_symbols(fun)))
452            }
453            GenericSchedule::Run(span, config) => {
454                GenericSchedule::Run(span, config.map_string_symbols(fun))
455            }
456            GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence(
457                span,
458                scheds
459                    .into_iter()
460                    .map(|sched| sched.map_string_symbols(fun))
461                    .collect(),
462            ),
463        };
464
465        mapped.flatten_sequences()
466    }
467
468    /// Converts all heads and leaves to strings.
469    pub fn make_unresolved(self) -> GenericSchedule<String, String> {
470        let mut map_head = |h: Head| h.to_string();
471        let mut map_leaf = |l: Leaf| l.to_string();
472        self.map_symbols(&mut map_head, &mut map_leaf)
473    }
474}
475
476impl<Head: Display, Leaf: Display> Display for GenericSchedule<Head, Leaf> {
477    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
478        match self {
479            GenericSchedule::Saturate(_ann, sched) => write!(f, "(saturate {sched})"),
480            GenericSchedule::Repeat(_ann, size, sched) => write!(f, "(repeat {size} {sched})"),
481            GenericSchedule::Run(_ann, config) => write!(f, "{config}"),
482            GenericSchedule::Sequence(_ann, scheds) => {
483                write!(f, "(seq {})", ListDisplay(scheds, " "))
484            }
485        }
486    }
487}
488
489pub type Command = GenericCommand<String, String>;
490pub type ResolvedCommand = GenericCommand<ResolvedCall, ResolvedVar>;
491
492pub type Subsume = bool;
493
494#[derive(Debug, Clone, PartialEq, Eq)]
495pub enum Subdatatypes {
496    Variants(Vec<Variant>),
497    NewSort(String, Vec<Expr>),
498}
499
/// How a `PrintFunction` command renders a function's contents.
///
/// [`PrintFunctionMode::Default`] is meant for humans, and its exact layout is
/// not a stable interface. [`PrintFunctionMode::CSV`] emits CSV.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrintFunctionMode {
    /// User-friendly output with an unreliable (unstable) format.
    Default,
    /// Comma-separated values.
    CSV,
}

impl Display for PrintFunctionMode {
    /// Writes the lowercase keyword for the mode: `default` or `csv`.
    /// Width and alignment flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Default => "default",
            Self::CSV => "csv",
        };
        f.write_str(keyword)
    }
}
517
518/// A [`Command`] is the top-level construct in egglog.
519/// It includes defining rules, declaring functions,
520/// adding to tables, and running rules (via a [`Schedule`]).
521///
522/// # Binding naming convention
523/// Bindings introduced by commands fall into two categories:
524/// - **Global bindings** must start with [`$`](crate::GLOBAL_NAME_PREFIX).
525/// - **Non-global bindings** must *not* start with [`$`](crate::GLOBAL_NAME_PREFIX).
526///
527/// When `--strict-mode` is enabled, violating these conventions is a type error;
528/// otherwise, egglog emits a single warning per program.
529#[derive(Debug, Clone)]
530pub enum GenericCommand<Head, Leaf>
531where
532    Head: Clone + Display,
533    Leaf: Clone + PartialEq + Eq + Display + Hash,
534{
535    /// Create a new user-defined sort, which can then
536    /// be used in new [`Command::Function`] declarations.
537    /// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`]
538    /// per constructor.
539    /// The main use of this command (as opposed to using [`Command::Datatype`]) is for forward-declaring a sort for mutually-recursive datatypes.
540    ///
541    /// It can also be used to create
542    /// a container sort.
543    /// For example, here's how to make a sort for vectors
544    /// of some user-defined sort `Math`:
545    /// ```text
546    /// (sort MathVec (Vec Math))
547    /// ```
548    ///
549    /// Now `MathVec` can be used as an input or output sort.
550    Sort {
551        span: Span,
552        name: String,
553        presort_and_args: Option<(String, Vec<Expr>)>,
554        /// The name of the union-find function for this sort.
555        /// Used in term encoding to canonicalize values during extraction.
556        uf: Option<String>,
557        /// The name of the proof function for this sort.
558        /// Set by proof desugaring to record where proofs are stored for this sort.
559        proof_func: Option<String>,
560        /// Whether values of this sort can be unioned.
561        /// Defaults to true for user-defined sorts.
562        /// Set to false for relations and term tables that should not allow union.
563        unionable: bool,
564    },
565
566    /// Egglog supports three types of functions
567    ///
568    /// A constructor models an egg-style user-defined datatype
569    /// It can only be defined through the `datatype`/`datatype*` command
570    /// or the `constructor` command
571    ///
572    /// A relation models a datalog-style mathematical relation
573    /// It can only be defined through the `relation` command
574    ///
575    /// A custom function is a dictionary
576    /// It can only be defined through the `function` command
577    ///
578    /// The `datatype` command declares a user-defined datatype.
579    /// Datatypes can be unioned with [`Action::Union`] either
580    /// at the top level or in the actions of a rule.
581    /// This makes them equal in the implicit, global equality relation.
582    ///
583    /// Example:
584    /// ```text
585    /// (datatype Math
586    ///   (Num i64)
587    ///   (Var String)
588    ///   (Add Math Math)
589    ///   (Mul Math Math))
590    /// ```
591    ///
592    /// defines a simple `Math` datatype with variants for numbers, named variables, addition and multiplication.
593    ///
594    /// Datatypes desugar directly to a [`Command::Sort`] and a [`Command::Constructor`] for each constructor.
595    /// The code above becomes:
596    /// ```text
597    /// (sort Math)
598    /// (constructor Num (i64) Math)
599    /// (constructor Var (String) Math)
600    /// (constructor Add (Math Math) Math)
601    /// (constructor Mul (Math Math) Math)
602    ///
603    /// Datatypes are also known as algebraic data types, tagged unions and sum types.
604    Datatype {
605        span: Span,
606        name: String,
607        variants: Vec<Variant>,
608    },
609    Datatypes {
610        span: Span,
611        datatypes: Vec<(Span, String, Subdatatypes)>,
612    },
613
614    /// The `constructor` command defines a new constructor for a user-defined datatype
615    /// Example:
616    /// ```text
617    /// (constructor Add (i64 i64) Math)
618    /// ```
619    ///
620    /// A constructor can be `:unextractable`, in which case extraction skip this constructor entirely.
621    /// ```text
622    /// (constructor UnextractableNode (i64) Math :unextractable)
623    /// ```
624    ///
625    /// A constructor can also have a `cost` for extraction, which is used in the cost model for extraction. It defaults to 1 if not specified.
626    /// ```text
627    /// (constructor ExpensiveNode (i64) Math :cost 10)
628    /// ```
629    Constructor {
630        span: Span,
631        name: String,
632        schema: Schema,
633        cost: Option<DefaultCost>,
634        unextractable: bool,
635        /// Internal-hidden constructors are excluded from print-size output and extraction.
636        /// Used for internal tables generated by proof production.
637        hidden: bool,
638        /// Internal-let constructors are let bindings, excluded from print-size output.
639        /// Used for global let bindings that are converted to constructors.
640        let_binding: bool,
641        /// Internal-only metadata for proof-encoding view tables.
642        /// Parsed user syntax only supports `:internal-term-constructor` on `function`.
643        term_constructor: Option<String>,
644    },
645
646    /// The `relation` command declares a named relation
647    /// Example:
648    /// ```text
649    /// (relation path (i64 i64))
650    /// (relation edge (i64 i64))
651    /// ```
652    Relation {
653        span: Span,
654        name: String,
655        inputs: Vec<String>,
656    },
657
658    /// The `function` command declare an egglog custom function, which is a database table with a
659    /// a functional dependency (also called a primary key) on its inputs to one output.
660    ///
661    /// ```text
662    /// (function <name:Ident> <schema:Schema> <cost:Cost>
663    ///        (:on_merge <List<Action>>)?
664    ///        (:merge <Expr>)?)
665    ///```
666    /// A function can have a `cost` for extraction.
667    ///
668    /// Finally, it can have a `merge` and `on_merge`, which are triggered when
669    /// the function dependency is violated.
670    /// In this case, the merge expression determines which of the two outputs
671    /// for the same input is used.
672    /// The `on_merge` actions are run after the merge expression is evaluated.
673    ///
674    /// Note that the `:merge` expression must be monotonic
675    /// for the behavior of the egglog program to be consistent and defined.
676    /// In other words, the merge function must define a lattice on the output of the function.
677    /// If values are merged in different orders, they should still result in the same output.
678    /// If the merge expression is not monotonic, the behavior can vary as
679    /// actions may be applied more than once with different results.
680    ///
681    /// ```text
682    /// (function LowerBound (Math) i64 :merge (max old new))
683    /// ```
684    ///
685    /// Specifically, a custom function can also have an EqSort output type:
686    ///
687    /// ```text
688    /// (function Add (i64 i64) Math)
689    /// ```
690    ///
691    /// All functions can be `set`
692    /// with [`Action::Set`].
693    ///
694    /// Output of a function, if being the EqSort type, can be unioned with [`Action::Union`]
695    /// with another datatype of the same `sort`.
696    ///
697    Function {
698        span: Span,
699        name: String,
700        schema: Schema,
701        merge: Option<GenericExpr<Head, Leaf>>,
702        hidden: bool,
703        let_binding: bool,
704        term_constructor: Option<String>,
705        unextractable: bool,
706    },
707
708    /// Using the `ruleset` command, defines a new
709    /// ruleset that can be added to in [`Command::Rule`]s.
710    /// Rulesets are used to group rules together
711    /// so that they can be run together in a [`Schedule`].
712    ///
713    /// Example:
714    /// Ruleset allows users to define a ruleset- a set of rules
715    ///
716    /// ```text
717    /// (ruleset myrules)
718    /// (rule ((edge x y))
719    ///       ((path x y))
720    ///       :ruleset myrules)
721    /// (run myrules 2)
722    /// ```
723    AddRuleset(Span, String),
724    /// Using the `combined-ruleset` command, construct another ruleset
725    /// which runs all the rules in the given rulesets.
726    /// This is useful for running multiple rulesets together.
727    /// The combined ruleset also inherits any rules added to the individual rulesets
728    /// after the combined ruleset is declared.
729    ///
730    /// Example:
731    /// ```text
732    /// (ruleset myrules1)
733    /// (rule ((edge x y))
734    ///       ((path x y))
735    ///      :ruleset myrules1)
736    /// (ruleset myrules2)
737    /// (rule ((path x y) (edge y z))
738    ///       ((path x z))
739    ///       :ruleset myrules2)
740    /// (combined-ruleset myrules-combined myrules1 myrules2)
741    /// ```
742    UnstableCombinedRuleset(Span, String, Vec<String>),
743    /// ```text
744    /// (rule <body:List<Fact>> <head:List<Action>>)
745    /// ```
746    ///
747    /// defines an egglog rule.
748    /// The rule matches a list of facts with respect to
749    /// the global database, and runs the list of actions
750    /// for each match.
751    /// The matches are done *modulo equality*, meaning
752    /// equal datatypes in the database are considered
753    /// equal.
754    ///
755    /// Example:
756    /// ```text
757    /// (rule ((edge x y))
758    ///       ((path x y)))
759    ///
760    /// (rule ((path x y) (edge y z))
761    ///       ((path x z)))
762    /// ```
763    Rule {
764        rule: GenericRule<Head, Leaf>,
765    },
766    /// `rewrite` is syntactic sugar for a specific form of `rule`
767    /// which simply unions the left and right hand sides.
768    ///
769    /// Example:
770    /// ```text
771    /// (rewrite (Add a b)
772    ///          (Add b a))
773    /// ```
774    ///
775    /// Desugars to:
776    /// ```text
777    /// (rule ((= lhs (Add a b)))
778    ///       ((union lhs (Add b a))))
779    /// ```
780    ///
781    /// Additionally, additional facts can be specified
782    /// using a `:when` clause.
783    /// For example, the same rule can be run only
784    /// when `a` is zero:
785    ///
786    /// ```text
787    /// (rewrite (Add a b)
788    ///          (Add b a)
789    ///          :when ((= a (Num 0)))
790    /// ```
791    ///
792    /// Add the `:subsume` flag to cause the left hand side to be subsumed after matching, which means it can
793    /// no longer be matched in a rule, but can still be checked against (See [`Change`] for more details.)
794    ///
795    /// ```text
796    /// (rewrite (Mul a 2) (bitshift-left a 1) :subsume)
797    /// ```
798    ///
799    /// Desugars to:
800    /// ```text
801    /// (rule ((= lhs (Mul a 2)))
802    ///       ((union lhs (bitshift-left a 1))
803    ///        (subsume (Mul a 2))))
804    /// ```
805    Rewrite(String, GenericRewrite<Head, Leaf>, Subsume),
806    /// Similar to [`Command::Rewrite`], but
807    /// generates two rules, one for each direction.
808    ///
809    /// Example:
810    /// ```text
811    /// (bi-rewrite (Mul (Var x) (Num 0))
812    ///             (Var x))
813    /// ```
814    ///
815    /// Becomes:
816    /// ```text
817    /// (rule ((= lhs (Mul (Var x) (Num 0))))
818    ///       ((union lhs (Var x))))
819    /// (rule ((= lhs (Var x)))
820    ///       ((union lhs (Mul (Var x) (Num 0)))))
821    /// ```
822    BiRewrite(String, GenericRewrite<Head, Leaf>),
823    /// Perform an [`Action`] on the global database
824    /// (see documentation for [`Action`] for more details).
825    /// Example:
826    /// ```text
827    /// (let xplusone (Add (Var "x") (Num 1)))
828    /// ```
829    Action(GenericAction<Head, Leaf>),
830    /// `extract` a datatype from the egraph, choosing
831    /// the smallest representative.
832    /// By default, each constructor costs 1 to extract
833    /// (common subexpressions are not shared in the cost
834    /// model).
835    Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
836    /// Runs a [`Schedule`], which specifies
837    /// rulesets and the number of times to run them.
838    ///
839    /// Example:
840    /// ```text
841    /// (run-schedule
842    ///     (saturate my-ruleset-1)
843    ///     (run my-ruleset-2 4))
844    /// ```
845    ///
846    /// Runs `my-ruleset-1` until saturation,
847    /// then runs `my-ruleset-2` four times.
848    ///
849    /// See [`Schedule`] for more details.
850    RunSchedule(GenericSchedule<Head, Leaf>),
851    /// Print runtime statistics about rules
852    /// and rulesets so far.
853    PrintOverallStatistics(Span, Option<String>),
854    /// The `check` command checks that the given facts
855    /// match at least once in the current database.
856    /// The list of facts is matched in the same way a [`Command::Rule`] is matched.
857    ///
858    /// Example:
859    ///
860    /// ```text
861    /// (check (= (+ 1 2) 3))
862    /// (check (<= 0 3) (>= 3 0))
863    /// (fail (check (= 1 2)))
864    /// ```
865    ///
866    /// prints
867    ///
868    /// ```text
869    /// [INFO ] Checked.
870    /// [INFO ] Checked.
871    /// [ERROR] Check failed
872    /// [INFO ] Command failed as expected.
873    /// ```
874    Check(Span, Vec<GenericFact<Head, Leaf>>),
875    Prove(Span, Vec<GenericFact<Head, Leaf>>),
876    ProveExists(Span, Head),
877    /// Print out rows of a given function, extracting each of the elements of the function.
878    /// Example:
879    ///
880    /// ```text
881    /// (print-function Add 20)
882    /// ```
883    /// prints the first 20 rows of the `Add` function.
884    ///
885    /// ```text
886    /// (print-function Add)
887    /// ```
888    /// prints all rows of the `Add` function.
889    ///
890    /// ```text
891    /// (print-function Add :file "add.csv")
892    /// ```
893    /// prints all rows of the `Add` function to a CSV file.
894    PrintFunction(
895        Span,
896        String,
897        Option<usize>,
898        Option<String>,
899        PrintFunctionMode,
900    ),
901    /// Print out the number of rows in a function or all functions.
902    PrintSize(Span, Option<String>),
903    /// Input a CSV file directly into a function.
904    Input {
905        span: Span,
906        name: String,
907        file: String,
908    },
909    /// Extract and output a set of expressions to a file.
910    Output {
911        span: Span,
912        file: String,
913        exprs: Vec<GenericExpr<Head, Leaf>>,
914    },
915    /// `push` the current egraph `n` times so that it is saved.
916    /// Later, the current database and rules can be restored using `pop`.
917    Push(usize),
918    /// `pop` the current egraph, restoring the previous one.
919    /// The argument specifies how many egraphs to pop.
920    Pop(Span, usize),
921    /// Assert that a command fails with an error.
922    Fail(Span, Box<GenericCommand<Head, Leaf>>),
923    /// Include another egglog file directly as text and run it.
924    Include(Span, String),
925    /// User-defined command.
926    UserDefined(Span, String, Vec<Expr>),
927}
928
929impl<Head, Leaf> Display for GenericCommand<Head, Leaf>
930where
931    Head: Clone + Display,
932    Leaf: Clone + PartialEq + Eq + Display + Hash,
933{
934    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
935        match self {
936            GenericCommand::Rewrite(name, rewrite, subsume) => {
937                rewrite.fmt_with_ruleset(f, name, false, *subsume)
938            }
939            GenericCommand::BiRewrite(name, rewrite) => {
940                rewrite.fmt_with_ruleset(f, name, true, false)
941            }
942            GenericCommand::Datatype {
943                span: _,
944                name,
945                variants,
946            } => {
947                write!(f, "(datatype {name} {})", ListDisplay(variants, " "))
948            }
949            GenericCommand::Action(a) => write!(f, "{a}"),
950            GenericCommand::Extract(_span, expr, variants) => {
951                write!(f, "(extract {expr} {variants})")
952            }
953            GenericCommand::Sort {
954                name,
955                presort_and_args: None,
956                uf,
957                proof_func,
958                ..
959            } => {
960                write!(f, "(sort {name}")?;
961                if let Some(uf) = uf {
962                    write!(f, " :internal-uf {uf}")?;
963                }
964                if let Some(pf) = proof_func {
965                    write!(f, " :internal-proof-func {pf}")?;
966                }
967                write!(f, ")")
968            }
969            GenericCommand::Sort {
970                name,
971                presort_and_args: Some((name2, args)),
972                ..
973            } => {
974                write!(f, "(sort {name} ({name2} {}))", ListDisplay(args, " "))
975            }
976            GenericCommand::Function {
977                span: _,
978                name,
979                schema,
980                merge,
981                hidden,
982                let_binding,
983                term_constructor,
984                unextractable,
985            } => {
986                write!(f, "(function {name} {schema}")?;
987                if let Some(merge) = &merge {
988                    write!(f, " :merge {merge}")?;
989                } else {
990                    write!(f, " :no-merge")?;
991                }
992                if *unextractable {
993                    write!(f, " :unextractable")?;
994                }
995                if *hidden {
996                    write!(f, " :internal-hidden")?;
997                }
998                if *let_binding {
999                    write!(f, " :internal-let")?;
1000                }
1001                if let Some(tc) = term_constructor {
1002                    write!(f, " :internal-term-constructor {tc}")?;
1003                }
1004                write!(f, ")")
1005            }
1006            GenericCommand::Constructor {
1007                span: _,
1008                name,
1009                schema,
1010                cost,
1011                unextractable,
1012                hidden,
1013                let_binding,
1014                term_constructor,
1015            } => {
1016                write!(f, "(constructor {name} {schema}")?;
1017                if let Some(cost) = cost {
1018                    write!(f, " :cost {cost}")?;
1019                }
1020                if *unextractable {
1021                    write!(f, " :unextractable")?;
1022                }
1023                if *hidden {
1024                    write!(f, " :internal-hidden")?;
1025                }
1026                if *let_binding {
1027                    write!(f, " :internal-let")?;
1028                }
1029                if let Some(tc) = term_constructor {
1030                    write!(f, " :internal-term-constructor {tc}")?;
1031                }
1032                write!(f, ")")
1033            }
1034            GenericCommand::Relation {
1035                span: _,
1036                name,
1037                inputs,
1038            } => {
1039                write!(f, "(relation {name} ({}))", ListDisplay(inputs, " "))
1040            }
1041            GenericCommand::AddRuleset(_span, name) => {
1042                write!(f, "(ruleset {name})")
1043            }
1044            GenericCommand::UnstableCombinedRuleset(_span, name, others) => {
1045                write!(
1046                    f,
1047                    "(unstable-combined-ruleset {name} {})",
1048                    ListDisplay(others, " ")
1049                )
1050            }
1051            GenericCommand::Rule { rule } => rule.fmt(f),
1052            GenericCommand::RunSchedule(sched) => write!(f, "(run-schedule {sched})"),
1053            GenericCommand::PrintOverallStatistics(_span, file) => match file {
1054                Some(file) => write!(f, "(print-stats :file {file})"),
1055                None => write!(f, "(print-stats)"),
1056            },
1057            GenericCommand::Check(_ann, facts) => {
1058                write!(f, "(check {})", ListDisplay(facts, "\n"))
1059            }
1060            GenericCommand::Prove(_span, facts) => {
1061                if facts.is_empty() {
1062                    write!(f, "(prove)")
1063                } else {
1064                    write!(f, "(prove {})", ListDisplay(facts, " "))
1065                }
1066            }
1067            GenericCommand::ProveExists(_span, constructor) => {
1068                write!(f, "(prove-exists {constructor})")
1069            }
1070            GenericCommand::Push(n) => write!(f, "(push {n})"),
1071            GenericCommand::Pop(_span, n) => write!(f, "(pop {n})"),
1072            GenericCommand::PrintFunction(_span, name, n, file, mode) => {
1073                write!(f, "(print-function {name}")?;
1074                if let Some(n) = n {
1075                    write!(f, " {n}")?;
1076                }
1077                if let Some(file) = file {
1078                    write!(f, " :file {file:?}")?;
1079                }
1080                match mode {
1081                    PrintFunctionMode::Default => {}
1082                    PrintFunctionMode::CSV => write!(f, " :mode csv")?,
1083                }
1084                write!(f, ")")
1085            }
1086            GenericCommand::PrintSize(_span, name) => {
1087                write!(f, "(print-size {})", ListDisplay(name, " "))
1088            }
1089            GenericCommand::Input {
1090                span: _,
1091                name,
1092                file,
1093            } => {
1094                write!(f, "(input {name} {file:?})")
1095            }
1096            GenericCommand::Output {
1097                span: _,
1098                file,
1099                exprs,
1100            } => write!(f, "(output {file:?} {})", ListDisplay(exprs, " ")),
1101            GenericCommand::Fail(_span, cmd) => write!(f, "(fail {cmd})"),
1102            GenericCommand::Include(_span, file) => write!(f, "(include {file:?})"),
1103            GenericCommand::Datatypes { span: _, datatypes } => {
1104                let datatypes: Vec<_> = datatypes
1105                    .iter()
1106                    .map(|(_, name, variants)| match variants {
1107                        Subdatatypes::Variants(variants) => {
1108                            format!("({name} {})", ListDisplay(variants, " "))
1109                        }
1110                        Subdatatypes::NewSort(head, args) => {
1111                            format!("(sort {name} ({head} {}))", ListDisplay(args, " "))
1112                        }
1113                    })
1114                    .collect();
1115                write!(f, "(datatype* {})", ListDisplay(datatypes, " "))
1116            }
1117            GenericCommand::UserDefined(_span, name, exprs) => {
1118                write!(f, "({name} {})", ListDisplay(exprs, " "))
1119            }
1120        }
1121    }
1122}
1123
/// An identifier paired with the name of a sort.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IdentSort {
    /// The identifier.
    pub ident: String,
    /// The name of the sort associated with `ident`.
    pub sort: String,
}

impl Display for IdentSort {
    /// Prints the pair as `(ident sort)`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let IdentSort { ident, sort } = self;
        write!(f, "({ident} {sort})")
    }
}
1135
/// Unresolved run configuration: heads and leaves are plain strings.
pub type RunConfig = GenericRunConfig<String, String>;
/// Run configuration whose heads and leaves have been resolved
/// (the same specialization used by `ResolvedNCommand`).
pub(crate) type ResolvedRunConfig = GenericRunConfig<ResolvedCall, ResolvedVar>;

/// Configuration for a single `run`, displayed as
/// `(run [ruleset] [:until fact ...])`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GenericRunConfig<Head, Leaf> {
    /// Name of the ruleset to run. An empty string is omitted when displayed.
    pub ruleset: String,
    /// Optional facts printed after `:until` when displayed.
    /// NOTE(review): presumably the run stops once these facts hold — confirm.
    pub until: Option<Vec<GenericFact<Head, Leaf>>>,
}
1144
1145impl<Head, Leaf> GenericRunConfig<Head, Leaf>
1146where
1147    Head: Clone + Display,
1148    Leaf: Clone + PartialEq + Eq + Display + Hash,
1149{
1150    pub fn visit_exprs(
1151        self,
1152        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1153    ) -> Self {
1154        Self {
1155            ruleset: self.ruleset,
1156            until: self
1157                .until
1158                .map(|until| until.into_iter().map(|fact| fact.visit_exprs(f)).collect()),
1159        }
1160    }
1161
1162    /// Remaps every head and leaf symbol in the run configuration.
1163    pub fn map_symbols<Head2, Leaf2>(
1164        self,
1165        head: &mut impl FnMut(Head) -> Head2,
1166        leaf: &mut impl FnMut(Leaf) -> Leaf2,
1167    ) -> GenericRunConfig<Head2, Leaf2>
1168    where
1169        Head2: Clone + Display,
1170        Leaf2: Clone + PartialEq + Eq + Display + Hash,
1171    {
1172        GenericRunConfig {
1173            ruleset: self.ruleset,
1174            until: self.until.map(|facts| {
1175                facts
1176                    .into_iter()
1177                    .map(|fact| fact.map_symbols(head, leaf))
1178                    .collect()
1179            }),
1180        }
1181    }
1182
1183    /// Applies `fun` to string-valued symbols within the run configuration.
1184    pub fn map_string_symbols(
1185        self,
1186        fun: &mut impl FnMut(String) -> String,
1187    ) -> GenericRunConfig<Head, Leaf> {
1188        GenericRunConfig {
1189            ruleset: fun(self.ruleset),
1190            until: self.until,
1191        }
1192    }
1193
1194    pub fn make_unresolved(self) -> GenericRunConfig<String, String> {
1195        let mut map_head = |h: Head| h.to_string();
1196        let mut map_leaf = |l: Leaf| l.to_string();
1197        self.map_symbols(&mut map_head, &mut map_leaf)
1198    }
1199}
1200
1201impl<Head: Display, Leaf: Display> Display for GenericRunConfig<Head, Leaf>
1202where
1203    Head: Display,
1204    Leaf: Display,
1205{
1206    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1207        write!(f, "(run")?;
1208        if !self.ruleset.is_empty() {
1209            write!(f, " {}", self.ruleset)?;
1210        }
1211        if let Some(until) = &self.until {
1212            write!(f, " :until {}", ListDisplay(until, " "))?;
1213        }
1214        write!(f, ")")
1215    }
1216}
1217
/// A function declaration whose heads and leaves are plain strings (unresolved).
pub type FunctionDecl = GenericFunctionDecl<String, String>;
/// A function declaration whose heads and leaves have been resolved.
pub(crate) type ResolvedFunctionDecl = GenericFunctionDecl<ResolvedCall, ResolvedVar>;
1220
/// Which surface-level declaration a function came from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FunctionSubtype {
    /// Declared with `constructor`.
    Constructor,
    /// Declared with `function`.
    // TODO rename to [`Function`] to match surface syntax and terminology.
    Custom,
}

impl Display for FunctionSubtype {
    /// Prints the surface keyword for this subtype.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let keyword = match self {
            Self::Custom => "function",
            Self::Constructor => "constructor",
        };
        f.write_str(keyword)
    }
}
1236
/// Represents the declaration of a function
/// directly parsed from source syntax.
///
/// Both `function` and `constructor` declarations use this shape; `subtype`
/// records which one it was (see `FunctionDecl::function` and
/// `FunctionDecl::constructor`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GenericFunctionDecl<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Name of the declared function or constructor.
    pub name: String,
    /// Whether this came from a `constructor` or a `function` declaration.
    pub subtype: FunctionSubtype,
    /// Untyped schema
    pub schema: Schema,
    /// Resolved schema after typechecking is stored here, otherwise "".
    pub resolved_schema: Head,
    /// Optional merge expression; `FunctionDecl::constructor` always sets `None`.
    pub merge: Option<GenericExpr<Head, Leaf>>,
    /// Optional cost annotation; `FunctionDecl::function` always sets `None`.
    pub cost: Option<DefaultCost>,
    /// Whether the function is marked unextractable.
    /// `FunctionDecl::function` always sets this to `true`.
    pub unextractable: bool,
    /// Hidden functions are excluded from print-size output.
    /// Used for internal tables generated by proof production.
    pub internal_hidden: bool,
    /// Globals are desugared to functions, with this flag set to true.
    /// This is used by visualization to handle globals differently.
    pub internal_let: bool,
    /// Source span of the declaration.
    pub span: Span,
    /// For view tables in proof encoding: the constructor to use for building
    /// terms from the first n-1 children during extraction.
    pub term_constructor: Option<String>,
}
1265
1266#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1267pub struct Variant {
1268    pub span: Span,
1269    pub name: String,
1270    pub types: Vec<String>,
1271    pub cost: Option<DefaultCost>,
1272    pub unextractable: bool,
1273}
1274
1275impl Display for Variant {
1276    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1277        write!(f, "({}", self.name)?;
1278        if !self.types.is_empty() {
1279            write!(f, " {}", ListDisplay(&self.types, " "))?;
1280        }
1281        if let Some(cost) = self.cost {
1282            write!(f, " :cost {cost}")?;
1283        }
1284        write!(f, ")")
1285    }
1286}
1287
1288#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1289pub struct Schema {
1290    pub input: Vec<String>,
1291    pub output: String,
1292}
1293
1294impl Display for Schema {
1295    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1296        write!(f, "({}) {}", ListDisplay(&self.input, " "), self.output)
1297    }
1298}
1299
1300impl Schema {
1301    pub fn new(input: Vec<String>, output: String) -> Self {
1302        Self { input, output }
1303    }
1304}
1305
1306impl FunctionDecl {
1307    /// Constructs a `function`
1308    pub fn function(
1309        span: Span,
1310        name: String,
1311        schema: Schema,
1312        merge: Option<GenericExpr<String, String>>,
1313    ) -> Self {
1314        Self {
1315            name,
1316            subtype: FunctionSubtype::Custom,
1317            schema,
1318            resolved_schema: String::new(),
1319            merge,
1320            cost: None,
1321            unextractable: true,
1322            internal_hidden: false,
1323            internal_let: false,
1324            span,
1325            term_constructor: None,
1326        }
1327    }
1328
1329    /// Constructs a `constructor`
1330    pub fn constructor(
1331        span: Span,
1332        name: String,
1333        schema: Schema,
1334        cost: Option<DefaultCost>,
1335        unextractable: bool,
1336        hidden: bool,
1337    ) -> Self {
1338        Self {
1339            name,
1340            subtype: FunctionSubtype::Constructor,
1341            resolved_schema: String::new(),
1342            schema,
1343            merge: None,
1344            cost,
1345            unextractable,
1346            internal_hidden: hidden,
1347            internal_let: false,
1348            span,
1349            term_constructor: None,
1350        }
1351    }
1352}
1353
1354impl<Head, Leaf> GenericFunctionDecl<Head, Leaf>
1355where
1356    Head: Clone + Display,
1357    Leaf: Clone + PartialEq + Eq + Display + Hash,
1358{
1359    pub fn visit_exprs(
1360        self,
1361        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1362    ) -> GenericFunctionDecl<Head, Leaf> {
1363        GenericFunctionDecl {
1364            name: self.name,
1365            subtype: self.subtype,
1366            schema: self.schema,
1367            resolved_schema: self.resolved_schema,
1368            merge: self.merge.map(|expr| expr.visit_exprs(f)),
1369            cost: self.cost,
1370            unextractable: self.unextractable,
1371            internal_hidden: self.internal_hidden,
1372            internal_let: self.internal_let,
1373            span: self.span,
1374            term_constructor: self.term_constructor,
1375        }
1376    }
1377}
1378
/// A fact whose heads and leaves are plain strings (unresolved).
pub type Fact = GenericFact<String, String>;
/// A fact whose heads and leaves have been resolved.
pub type ResolvedFact = GenericFact<ResolvedCall, ResolvedVar>;
/// A fact whose call heads are paired with the variable each call was
/// flattened into (produced by `Facts::to_query`).
pub(crate) type MappedFact<Head, Leaf> = GenericFact<CorrespondingVar<Head, Leaf>, Leaf>;

/// A list of facts that can be flattened into a `Query` via `to_query`.
pub struct Facts<Head, Leaf>(pub Vec<GenericFact<Head, Leaf>>);
1384
1385impl<Head, Leaf> Facts<Head, Leaf>
1386where
1387    Head: Clone + Display,
1388    Leaf: Clone + PartialEq + Eq + Display + Hash,
1389{
1390    /// Flattens a list of facts into a Query.
1391    /// For typechecking, we need the correspondence between the original ast
1392    /// and the flattened one, so that we can annotate the original with types.
1393    /// That's why this function produces a corresponding list of facts, annotated with
1394    /// the variable names in the flattened Query.
1395    /// (Typechecking preserves the original AST this way,
1396    /// and allows terms and proof instrumentation to do the same).
1397    pub(crate) fn to_query(
1398        &self,
1399        typeinfo: &TypeInfo,
1400        fresh_gen: &mut impl FreshGen<Head, Leaf>,
1401    ) -> (Query<HeadOrEq<Head>, Leaf>, Vec<MappedFact<Head, Leaf>>) {
1402        let mut atoms = vec![];
1403        let mut new_body = vec![];
1404
1405        for fact in self.0.iter() {
1406            match fact {
1407                GenericFact::Eq(span, e1, e2) => {
1408                    let mut to_equate = vec![];
1409                    let mut process = |expr: &GenericExpr<Head, Leaf>| {
1410                        let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1411                        atoms.extend(child_atoms);
1412                        to_equate.push(expr.get_corresponding_var_or_lit(typeinfo));
1413                        expr
1414                    };
1415                    let e1 = process(e1);
1416                    let e2 = process(e2);
1417                    atoms.push(GenericAtom {
1418                        span: span.clone(),
1419                        head: HeadOrEq::Eq,
1420                        args: to_equate,
1421                    });
1422                    new_body.push(GenericFact::Eq(span.clone(), e1, e2));
1423                }
1424                GenericFact::Fact(expr) => {
1425                    let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1426                    atoms.extend(child_atoms);
1427                    new_body.push(GenericFact::Fact(expr));
1428                }
1429            }
1430        }
1431        (Query { atoms }, new_body)
1432    }
1433}
1434
/// A head paired with the variable its result corresponds to.
/// For a flattened call, `to` names the variable the call's result is bound to.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// The original head.
    pub head: Head,
    /// The variable corresponding to `head`.
    pub to: Leaf,
}

impl<Head, Leaf> CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Pairs `head` with the variable `leaf`.
    pub fn new(head: Head, leaf: Leaf) -> Self {
        let to = leaf;
        Self { head, to }
    }
}

impl<Head, Leaf> Display for CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Prints the pair as `head -> to`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let Self { head, to } = self;
        write!(f, "{head} -> {to}")
    }
}
/// An action whose heads and leaves are plain strings (unresolved).
pub type Action = GenericAction<String, String>;
/// An action whose call heads are paired with their corresponding variables.
pub(crate) type MappedAction = GenericAction<CorrespondingVar<String, String>, String>;
/// An action whose heads and leaves have been resolved.
pub(crate) type ResolvedAction = GenericAction<ResolvedCall, ResolvedVar>;

/// Unresolved `GenericActions` (string heads and leaves).
pub type Actions = GenericActions<String, String>;
/// Resolved `GenericActions`.
pub(crate) type ResolvedActions = GenericActions<ResolvedCall, ResolvedVar>;
/// `GenericActions` whose call heads are paired with their corresponding variables.
pub(crate) type MappedActions<Head, Leaf> = GenericActions<CorrespondingVar<Head, Leaf>, Leaf>;

/// A rule whose heads and leaves are plain strings (unresolved).
pub type Rule = GenericRule<String, String>;
/// A rule whose heads and leaves have been resolved.
pub(crate) type ResolvedRule = GenericRule<ResolvedCall, ResolvedVar>;

/// A rewrite whose heads and leaves are plain strings (unresolved).
pub type Rewrite = GenericRewrite<String, String>;
1476
/// A rewrite from `lhs` to `rhs`, optionally guarded by extra `conditions`.
///
/// Rendered by `fmt_with_ruleset` as
/// `(rewrite lhs rhs [:subsume] [:when (...)] [:ruleset r])`,
/// or with `birewrite` for the bidirectional form.
#[derive(Clone, Debug)]
pub struct GenericRewrite<Head, Leaf> {
    /// Source span of the rewrite.
    pub span: Span,
    /// Pattern matched against the database.
    pub lhs: GenericExpr<Head, Leaf>,
    /// Expression unioned with the matched left-hand side.
    pub rhs: GenericExpr<Head, Leaf>,
    /// Additional facts that must hold (the `:when` clause).
    pub conditions: Vec<GenericFact<Head, Leaf>>,
    /// Name of the rewrite. NOTE(review): not emitted by `fmt_with_ruleset`.
    pub name: String,
}
1485
1486impl<Head, Leaf> GenericRewrite<Head, Leaf>
1487where
1488    Head: Clone + Display,
1489    Leaf: Clone + PartialEq + Eq + Display + Hash,
1490{
1491    /// Remaps every head and leaf symbol in the rewrite, including the optional conditions.
1492    pub fn map_symbols<Head2, Leaf2>(
1493        self,
1494        head: &mut impl FnMut(Head) -> Head2,
1495        leaf: &mut impl FnMut(Leaf) -> Leaf2,
1496    ) -> GenericRewrite<Head2, Leaf2>
1497    where
1498        Head2: Clone + Display,
1499        Leaf2: Clone + PartialEq + Eq + Display + Hash,
1500    {
1501        GenericRewrite {
1502            span: self.span,
1503            lhs: self.lhs.map_symbols(head, leaf),
1504            rhs: self.rhs.map_symbols(head, leaf),
1505            conditions: self
1506                .conditions
1507                .into_iter()
1508                .map(|fact| fact.map_symbols(head, leaf))
1509                .collect(),
1510            name: self.name,
1511        }
1512    }
1513
1514    pub fn make_unresolved(self) -> GenericRewrite<String, String> {
1515        let mut map_head = |h: Head| h.to_string();
1516        let mut map_leaf = |l: Leaf| l.to_string();
1517        self.map_symbols(&mut map_head, &mut map_leaf)
1518    }
1519}
1520
1521impl<Head: Display, Leaf: Display> GenericRewrite<Head, Leaf> {
1522    /// Converts the rewrite into an s-expression.
1523    pub fn fmt_with_ruleset(
1524        &self,
1525        f: &mut Formatter,
1526        ruleset: &str,
1527        is_bidirectional: bool,
1528        subsume: bool,
1529    ) -> std::fmt::Result {
1530        let direction = if is_bidirectional {
1531            "birewrite"
1532        } else {
1533            "rewrite"
1534        };
1535        write!(f, "({direction} {} {}", self.lhs, self.rhs)?;
1536        if subsume {
1537            write!(f, " :subsume")?;
1538        }
1539        if !self.conditions.is_empty() {
1540            write!(f, " :when ({})", ListDisplay(&self.conditions, " "))?;
1541        }
1542        if !ruleset.is_empty() {
1543            write!(f, " :ruleset {ruleset}")?;
1544        }
1545        write!(f, ")")
1546    }
1547}
1548
1549pub(crate) trait MappedExprExt<Head, Leaf>
1550where
1551    Head: Clone + Display,
1552    Leaf: Clone + PartialEq + Eq + Display + Hash,
1553{
1554    fn get_corresponding_var_or_lit(&self, typeinfo: &TypeInfo) -> GenericAtomTerm<Leaf>;
1555}
1556
1557impl<Head, Leaf> MappedExprExt<Head, Leaf> for MappedExpr<Head, Leaf>
1558where
1559    Head: Clone + Display,
1560    Leaf: Clone + PartialEq + Eq + Display + Hash,
1561{
1562    fn get_corresponding_var_or_lit(&self, typeinfo: &TypeInfo) -> GenericAtomTerm<Leaf> {
1563        // Note: need typeinfo to resolve whether a symbol is a global or not
1564        // This is error-prone and the complexities can be avoided by treating globals
1565        // as nullary functions.
1566        match self {
1567            GenericExpr::Var(span, v) => {
1568                if typeinfo.is_global(&v.to_string()) {
1569                    GenericAtomTerm::Global(span.clone(), v.clone())
1570                } else {
1571                    GenericAtomTerm::Var(span.clone(), v.clone())
1572                }
1573            }
1574            GenericExpr::Lit(span, lit) => GenericAtomTerm::Literal(span.clone(), lit.clone()),
1575            GenericExpr::Call(span, head, _) => GenericAtomTerm::Var(span.clone(), head.to.clone()),
1576        }
1577    }
1578}
1579
1580impl<Head, Leaf> GenericCommand<Head, Leaf>
1581where
1582    Head: Clone + Display,
1583    Leaf: Clone + PartialEq + Eq + Display + Hash,
1584{
1585    /// The current egglog AST has strings even when resolved.
1586    /// We map over those strings with this function, used by sanitize_internal_symbols.
1587    pub fn map_string_symbols(
1588        self,
1589        fun: &mut impl FnMut(String) -> String,
1590    ) -> GenericCommand<Head, Leaf>
1591    where
1592        Head: Clone + Display,
1593        Leaf: Clone + PartialEq + Eq + Display + Hash,
1594    {
1595        match self {
1596            GenericCommand::Sort {
1597                span,
1598                name,
1599                presort_and_args,
1600                uf,
1601                proof_func,
1602                unionable,
1603            } => GenericCommand::Sort {
1604                span,
1605                name: fun(name),
1606                presort_and_args,
1607                uf: uf.map(&mut *fun),
1608                proof_func: proof_func.map(&mut *fun),
1609                unionable,
1610            },
1611            GenericCommand::Datatype {
1612                span,
1613                name,
1614                variants,
1615            } => GenericCommand::Datatype {
1616                span,
1617                name: fun(name),
1618                variants: variants
1619                    .into_iter()
1620                    .map(|variant| Variant {
1621                        span: variant.span,
1622                        name: fun(variant.name),
1623                        types: variant.types.into_iter().map(&mut *fun).collect(),
1624                        cost: variant.cost,
1625                        unextractable: variant.unextractable,
1626                    })
1627                    .collect(),
1628            },
1629            GenericCommand::Datatypes { span, datatypes } => GenericCommand::Datatypes {
1630                span,
1631                datatypes: datatypes
1632                    .into_iter()
1633                    .map(|(span, name, variants)| {
1634                        let new_name = fun(name);
1635                        let new_variants = match variants {
1636                            Subdatatypes::Variants(variants) => Subdatatypes::Variants(
1637                                variants
1638                                    .into_iter()
1639                                    .map(|variant| Variant {
1640                                        span: variant.span,
1641                                        name: fun(variant.name),
1642                                        // Redundant closure helps with type inference here
1643                                        #[allow(clippy::redundant_closure)]
1644                                        types: variant
1645                                            .types
1646                                            .into_iter()
1647                                            .map(|ty| fun(ty))
1648                                            .collect(),
1649                                        cost: variant.cost,
1650                                        unextractable: variant.unextractable,
1651                                    })
1652                                    .collect(),
1653                            ),
1654                            Subdatatypes::NewSort(head, args) => {
1655                                Subdatatypes::NewSort(fun(head), args)
1656                            }
1657                        };
1658                        (span, new_name, new_variants)
1659                    })
1660                    .collect(),
1661            },
1662            GenericCommand::Constructor {
1663                span,
1664                name,
1665                schema,
1666                cost,
1667                unextractable,
1668                hidden,
1669                let_binding,
1670                term_constructor,
1671            } => GenericCommand::Constructor {
1672                span,
1673                name: fun(name),
1674                schema: Schema {
1675                    input: schema.input.into_iter().map(&mut *fun).collect(),
1676                    output: fun(schema.output),
1677                },
1678                cost,
1679                unextractable,
1680                hidden,
1681                let_binding,
1682                term_constructor: term_constructor.map(&mut *fun),
1683            },
1684            GenericCommand::Relation { span, name, inputs } => GenericCommand::Relation {
1685                span,
1686                name: fun(name),
1687                inputs: inputs.into_iter().map(&mut *fun).collect(),
1688            },
1689            GenericCommand::Function {
1690                span,
1691                name,
1692                schema,
1693                merge,
1694                hidden,
1695                let_binding,
1696                term_constructor,
1697                unextractable,
1698            } => GenericCommand::Function {
1699                span,
1700                name: fun(name),
1701                schema: Schema {
1702                    input: schema.input.into_iter().map(&mut *fun).collect(),
1703                    output: fun(schema.output),
1704                },
1705                merge,
1706                hidden,
1707                let_binding,
1708                term_constructor: term_constructor.map(&mut *fun),
1709                unextractable,
1710            },
1711            GenericCommand::AddRuleset(span, name) => GenericCommand::AddRuleset(span, fun(name)),
1712            GenericCommand::UnstableCombinedRuleset(span, name, others) => {
1713                GenericCommand::UnstableCombinedRuleset(
1714                    span,
1715                    fun(name),
1716                    others.into_iter().map(&mut *fun).collect(),
1717                )
1718            }
1719            GenericCommand::Rule { rule } => {
1720                let rule = GenericRule {
1721                    span: rule.span,
1722                    name: fun(rule.name),
1723                    ruleset: fun(rule.ruleset),
1724                    head: rule.head,
1725                    body: rule.body,
1726                };
1727                GenericCommand::Rule { rule }
1728            }
1729            GenericCommand::Rewrite(name, rewrite, subsume) => {
1730                GenericCommand::Rewrite(fun(name), rewrite, subsume)
1731            }
1732            GenericCommand::BiRewrite(name, rewrite) => {
1733                GenericCommand::BiRewrite(fun(name), rewrite)
1734            }
1735            GenericCommand::Action(action) => GenericCommand::Action(action),
1736            GenericCommand::Extract(span, expr, variants) => {
1737                GenericCommand::Extract(span, expr, variants)
1738            }
1739            GenericCommand::RunSchedule(schedule) => {
1740                GenericCommand::RunSchedule(schedule.map_string_symbols(fun))
1741            }
1742            GenericCommand::PrintOverallStatistics(span, file) => {
1743                GenericCommand::PrintOverallStatistics(span, file)
1744            }
1745            GenericCommand::Check(span, facts) => GenericCommand::Check(span, facts),
1746            GenericCommand::Prove(span, facts) => GenericCommand::Prove(span, facts),
1747            GenericCommand::ProveExists(span, constructor) => {
1748                GenericCommand::ProveExists(span, constructor)
1749            }
1750            GenericCommand::PrintFunction(span, name, n, file, mode) => {
1751                GenericCommand::PrintFunction(span, fun(name), n, file, mode)
1752            }
1753            GenericCommand::PrintSize(span, name) => GenericCommand::PrintSize(span, name.map(fun)),
1754            GenericCommand::Input { span, name, file } => GenericCommand::Input {
1755                span,
1756                name: fun(name),
1757                file,
1758            },
1759            GenericCommand::Output { span, file, exprs } => {
1760                GenericCommand::Output { span, file, exprs }
1761            }
1762            GenericCommand::Push(n) => GenericCommand::Push(n),
1763            GenericCommand::Pop(span, n) => GenericCommand::Pop(span, n),
1764            GenericCommand::Fail(span, cmd) => {
1765                GenericCommand::Fail(span, Box::new(cmd.map_string_symbols(fun)))
1766            }
1767            GenericCommand::Include(span, file) => GenericCommand::Include(span, file),
1768            GenericCommand::UserDefined(span, name, exprs) => {
1769                GenericCommand::UserDefined(span, name, exprs)
1770            }
1771        }
1772    }
1773
1774    /// Applies `f` to all expressions in the command, bottom-up.
1775    pub fn visit_exprs(
1776        self,
1777        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1778    ) -> Self {
1779        match self {
1780            GenericCommand::Function {
1781                span,
1782                name,
1783                schema,
1784                merge,
1785                hidden,
1786                let_binding,
1787                term_constructor,
1788                unextractable,
1789            } => GenericCommand::Function {
1790                span,
1791                name,
1792                schema,
1793                merge: merge.map(|e| e.visit_exprs(f)),
1794                hidden,
1795                let_binding,
1796                term_constructor,
1797                unextractable,
1798            },
1799            GenericCommand::Rule { rule } => GenericCommand::Rule {
1800                rule: rule.visit_exprs(f),
1801            },
1802            GenericCommand::Rewrite(name, rewrite, subsume) => GenericCommand::Rewrite(
1803                name,
1804                GenericRewrite {
1805                    span: rewrite.span,
1806                    lhs: rewrite.lhs.visit_exprs(f),
1807                    rhs: rewrite.rhs.visit_exprs(f),
1808                    conditions: rewrite
1809                        .conditions
1810                        .into_iter()
1811                        .map(|fact| fact.visit_exprs(f))
1812                        .collect(),
1813                    name: rewrite.name,
1814                },
1815                subsume,
1816            ),
1817            GenericCommand::BiRewrite(name, rewrite) => GenericCommand::BiRewrite(
1818                name,
1819                GenericRewrite {
1820                    span: rewrite.span,
1821                    lhs: rewrite.lhs.visit_exprs(f),
1822                    rhs: rewrite.rhs.visit_exprs(f),
1823                    conditions: rewrite
1824                        .conditions
1825                        .into_iter()
1826                        .map(|fact| fact.visit_exprs(f))
1827                        .collect(),
1828                    name: rewrite.name,
1829                },
1830            ),
1831            GenericCommand::Action(action) => GenericCommand::Action(action.visit_exprs(f)),
1832            GenericCommand::Extract(span, expr1, expr2) => {
1833                GenericCommand::Extract(span, expr1.visit_exprs(f), expr2.visit_exprs(f))
1834            }
1835            GenericCommand::Check(span, facts) => GenericCommand::Check(
1836                span,
1837                facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
1838            ),
1839            GenericCommand::Prove(span, facts) => GenericCommand::Prove(
1840                span,
1841                facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
1842            ),
1843            GenericCommand::Output { span, file, exprs } => GenericCommand::Output {
1844                span,
1845                file,
1846                exprs: exprs.into_iter().map(|e| e.visit_exprs(f)).collect(),
1847            },
1848            GenericCommand::RunSchedule(schedule) => {
1849                GenericCommand::RunSchedule(schedule.visit_exprs(f))
1850            }
1851            GenericCommand::Fail(span, cmd) => {
1852                GenericCommand::Fail(span, Box::new(cmd.visit_exprs(f)))
1853            }
1854            // All other commands don't contain expressions
1855            cmd => cmd,
1856        }
1857    }
1858
1859    /// Remaps every head and leaf symbol contained in the command.
1860    pub fn map_symbols<Head2, Leaf2>(
1861        self,
1862        head: &mut impl FnMut(Head) -> Head2,
1863        leaf: &mut impl FnMut(Leaf) -> Leaf2,
1864    ) -> GenericCommand<Head2, Leaf2>
1865    where
1866        Head2: Clone + Display,
1867        Leaf2: Clone + PartialEq + Eq + Display + Hash,
1868    {
1869        match self {
1870            GenericCommand::Sort {
1871                span,
1872                name,
1873                presort_and_args,
1874                uf,
1875                proof_func,
1876                unionable,
1877            } => GenericCommand::Sort {
1878                span,
1879                name,
1880                presort_and_args,
1881                uf,
1882                proof_func,
1883                unionable,
1884            },
1885            GenericCommand::Datatype {
1886                span,
1887                name,
1888                variants,
1889            } => GenericCommand::Datatype {
1890                span,
1891                name,
1892                variants,
1893            },
1894            GenericCommand::Datatypes { span, datatypes } => {
1895                GenericCommand::Datatypes { span, datatypes }
1896            }
1897            GenericCommand::Constructor {
1898                span,
1899                name,
1900                schema,
1901                cost,
1902                unextractable,
1903                hidden,
1904                let_binding,
1905                term_constructor,
1906            } => GenericCommand::Constructor {
1907                span,
1908                name,
1909                schema,
1910                cost,
1911                unextractable,
1912                hidden,
1913                let_binding,
1914                term_constructor,
1915            },
1916            GenericCommand::Relation { span, name, inputs } => {
1917                GenericCommand::Relation { span, name, inputs }
1918            }
1919            GenericCommand::Function {
1920                span,
1921                name,
1922                schema,
1923                merge,
1924                hidden,
1925                let_binding,
1926                term_constructor,
1927                unextractable,
1928            } => GenericCommand::Function {
1929                span,
1930                name,
1931                schema,
1932                merge: merge.map(|expr| expr.map_symbols(head, leaf)),
1933                hidden,
1934                let_binding,
1935                term_constructor,
1936                unextractable,
1937            },
1938            GenericCommand::AddRuleset(span, name) => GenericCommand::AddRuleset(span, name),
1939            GenericCommand::UnstableCombinedRuleset(span, name, others) => {
1940                GenericCommand::UnstableCombinedRuleset(span, name, others)
1941            }
1942            GenericCommand::Rule { rule } => GenericCommand::Rule {
1943                rule: rule.map_symbols(head, leaf),
1944            },
1945            GenericCommand::Rewrite(name, rewrite, subsume) => {
1946                GenericCommand::Rewrite(name, rewrite.map_symbols(head, leaf), subsume)
1947            }
1948            GenericCommand::BiRewrite(name, rewrite) => {
1949                GenericCommand::BiRewrite(name, rewrite.map_symbols(head, leaf))
1950            }
1951            GenericCommand::Action(action) => {
1952                GenericCommand::Action(action.map_symbols(head, leaf))
1953            }
1954            GenericCommand::Extract(span, expr, variants) => GenericCommand::Extract(
1955                span,
1956                expr.map_symbols(head, leaf),
1957                variants.map_symbols(head, leaf),
1958            ),
1959            GenericCommand::RunSchedule(schedule) => {
1960                GenericCommand::RunSchedule(schedule.map_symbols(head, leaf))
1961            }
1962            GenericCommand::PrintOverallStatistics(span, file) => {
1963                GenericCommand::PrintOverallStatistics(span, file)
1964            }
1965            GenericCommand::Check(span, facts) => GenericCommand::Check(
1966                span,
1967                facts
1968                    .into_iter()
1969                    .map(|fact| fact.map_symbols(head, leaf))
1970                    .collect(),
1971            ),
1972            GenericCommand::Prove(span, facts) => GenericCommand::Prove(
1973                span,
1974                facts
1975                    .into_iter()
1976                    .map(|fact| fact.map_symbols(head, leaf))
1977                    .collect(),
1978            ),
1979            GenericCommand::ProveExists(span, constructor) => {
1980                GenericCommand::ProveExists(span, head(constructor))
1981            }
1982            GenericCommand::PrintFunction(span, name, n, file, mode) => {
1983                GenericCommand::PrintFunction(span, name, n, file, mode)
1984            }
1985            GenericCommand::PrintSize(span, name) => GenericCommand::PrintSize(span, name),
1986            GenericCommand::Input { span, name, file } => {
1987                GenericCommand::Input { span, name, file }
1988            }
1989            GenericCommand::Output { span, file, exprs } => GenericCommand::Output {
1990                span,
1991                file,
1992                exprs: exprs
1993                    .into_iter()
1994                    .map(|expr| expr.map_symbols(head, leaf))
1995                    .collect(),
1996            },
1997            GenericCommand::Push(n) => GenericCommand::Push(n),
1998            GenericCommand::Pop(span, n) => GenericCommand::Pop(span, n),
1999            GenericCommand::Fail(span, cmd) => {
2000                GenericCommand::Fail(span, Box::new(cmd.map_symbols(head, leaf)))
2001            }
2002            GenericCommand::Include(span, file) => GenericCommand::Include(span, file),
2003            GenericCommand::UserDefined(span, name, exprs) => {
2004                GenericCommand::UserDefined(span, name, exprs)
2005            }
2006        }
2007    }
2008
2009    /// Makes the command unresolved by converting all Head and Leaf types to String.
2010    pub fn make_unresolved(self) -> GenericCommand<String, String> {
2011        let mut map_head = |h: Head| h.to_string();
2012        let mut map_leaf = |l: Leaf| l.to_string();
2013        self.map_symbols(&mut map_head, &mut map_leaf)
2014    }
2015
2016    pub fn visit_actions(
2017        self,
2018        f: &mut impl FnMut(GenericAction<Head, Leaf>) -> GenericAction<Head, Leaf>,
2019    ) -> Self {
2020        match self {
2021            GenericCommand::Rule { rule } => GenericCommand::Rule {
2022                rule: rule.visit_actions(f),
2023            },
2024            GenericCommand::Action(action) => GenericCommand::Action(f(action)),
2025            GenericCommand::Fail(span, cmd) => {
2026                GenericCommand::Fail(span, Box::new(cmd.visit_actions(f)))
2027            }
2028            other => other,
2029        }
2030    }
2031}
2032
2033/// Computes the maximum number of underscores in any symbol name in the program.
2034pub fn get_max_underscores(program: &[GenericCommand<String, String>]) -> usize {
2035    // now count the max number of underscores in any name
2036    let mut max_underscores = 0;
2037    let mut max_underscores2 = 0;
2038    for cmd in program {
2039        cmd.clone().map_symbols(
2040            &mut |h: String| {
2041                let count = h.matches(INTERNAL_SYMBOL_PREFIX).count();
2042                if count > max_underscores {
2043                    max_underscores = count;
2044                }
2045                h
2046            },
2047            &mut |l: String| {
2048                let count = l.matches(INTERNAL_SYMBOL_PREFIX).count();
2049                if count > max_underscores2 {
2050                    max_underscores2 = count;
2051                }
2052                l
2053            },
2054        );
2055        cmd.clone().map_string_symbols(&mut |s: String| {
2056            let count = s.matches(INTERNAL_SYMBOL_PREFIX).count();
2057            if count > max_underscores {
2058                max_underscores = count;
2059            }
2060            s
2061        });
2062    }
2063    max(max_underscores, max_underscores2)
2064}
2065
2066/// Replaces all identifiers containing the internal symbol prefix with the given replacement string.
2067pub fn replace_internal_symbol_with(
2068    program: &[GenericCommand<String, String>],
2069    replacement: &str,
2070) -> Vec<GenericCommand<String, String>> {
2071    program
2072        .iter()
2073        .map(|cmd| {
2074            let cmd = cmd.clone().map_symbols(
2075                &mut |h: String| h.replace(INTERNAL_SYMBOL_PREFIX, replacement),
2076                &mut |l: String| l.replace(INTERNAL_SYMBOL_PREFIX, replacement),
2077            );
2078            cmd.map_string_symbols(&mut |s: String| s.replace(INTERNAL_SYMBOL_PREFIX, replacement))
2079        })
2080        .collect()
2081}
2082
2083/// Sanitizes internal names so they do not contain any internal characters.
2084/// This enables printing desugared egglog in a way that can be re-parsed.
2085pub fn sanitize_internal_names<Head, Leaf>(
2086    program: &[GenericCommand<Head, Leaf>],
2087) -> Vec<GenericCommand<String, String>>
2088where
2089    Head: Clone + Display,
2090    Leaf: Clone + PartialEq + Eq + Display + Hash,
2091{
2092    // first convert to unresolved
2093    let unresolved = program
2094        .iter()
2095        .map(|cmd| cmd.clone().make_unresolved())
2096        .collect::<Vec<_>>();
2097    // get the maximum number of underscores currently present, that way we can soundly add more than that to get fresh symbols
2098    let max_underscores = get_max_underscores(&unresolved);
2099    let replacement = "_".repeat(max_underscores + 1);
2100    // replace occurances of the internal symbol with replacement
2101    replace_internal_symbol_with(&unresolved, &replacement)
2102}