egglog/ast/
mod.rs

1pub mod check_shadowing;
2pub mod desugar;
3mod expr;
4mod parse;
5pub mod remove_globals;
6
7use crate::core::{
8    GenericAtom, GenericAtomTerm, GenericExprExt, HeadOrEq, Query, ResolvedCall, ResolvedCoreRule,
9};
10use crate::util::sanitize_internal_name;
11use crate::*;
12pub use egglog_ast::generic_ast::{
13    Change, GenericAction, GenericActions, GenericExpr, GenericFact, GenericRule, Literal,
14};
15pub use egglog_ast::span::{RustSpan, Span};
16use egglog_ast::util::ListDisplay;
17pub use expr::*;
18pub use parse::*;
19
20#[derive(Clone, Debug)]
21/// The egglog internal representation of already compiled rules
22pub(crate) enum Ruleset {
23    /// Represents a ruleset with a set of rules.
24    Rules(IndexMap<String, (ResolvedCoreRule, egglog_bridge::RuleId)>),
25    /// A combined ruleset may contain other rulesets.
26    Combined(Vec<String>),
27}
28
29pub type NCommand = GenericNCommand<String, String>;
30/// [`ResolvedNCommand`] is another specialization of [`GenericNCommand`], which
31/// adds the type information to heads and leaves of commands.
32/// [`TypeInfo::typecheck_command`] turns an [`NCommand`] into a [`ResolvedNCommand`].
33pub(crate) type ResolvedNCommand = GenericNCommand<ResolvedCall, ResolvedVar>;
34
35/// A [`NCommand`] is a desugared [`Command`], where syntactic sugars
36/// like [`Command::Datatype`] and [`Command::Rewrite`]
37/// are eliminated.
38/// Most of the heavy lifting in egglog is done over [`NCommand`]s.
39///
40/// [`GenericNCommand`] is a generalization of [`NCommand`], like how [`GenericCommand`]
41/// is a generalization of [`Command`], allowing annotations over `Head` and `Leaf`.
42///
43/// TODO: The name "NCommand" used to denote normalized command, but this
44/// meaning is obsolete. A future PR should rename this type to something
45/// like "DCommand".
46#[derive(Debug, Clone, Eq, PartialEq, Hash)]
47pub enum GenericNCommand<Head, Leaf>
48where
49    Head: Clone + Display,
50    Leaf: Clone + PartialEq + Eq + Display + Hash,
51{
52    Sort(
53        Span,
54        String,
55        Option<(String, Vec<GenericExpr<String, String>>)>,
56    ),
57    Function(GenericFunctionDecl<Head, Leaf>),
58    AddRuleset(Span, String),
59    UnstableCombinedRuleset(Span, String, Vec<String>),
60    NormRule {
61        rule: GenericRule<Head, Leaf>,
62    },
63    CoreAction(GenericAction<Head, Leaf>),
64    Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
65    RunSchedule(GenericSchedule<Head, Leaf>),
66    PrintOverallStatistics(Span, Option<String>),
67    Check(Span, Vec<GenericFact<Head, Leaf>>),
68    PrintFunction(
69        Span,
70        String,
71        Option<usize>,
72        Option<String>,
73        PrintFunctionMode,
74    ),
75    PrintSize(Span, Option<String>),
76    Output {
77        span: Span,
78        file: String,
79        exprs: Vec<GenericExpr<Head, Leaf>>,
80    },
81    Push(usize),
82    Pop(Span, usize),
83    Fail(Span, Box<GenericNCommand<Head, Leaf>>),
84    Input {
85        span: Span,
86        name: String,
87        file: String,
88    },
89    UserDefined(Span, String, Vec<Expr>),
90}
91
92impl<Head, Leaf> GenericNCommand<Head, Leaf>
93where
94    Head: Clone + Display,
95    Leaf: Clone + PartialEq + Eq + Display + Hash,
96{
97    pub fn to_command(&self) -> GenericCommand<Head, Leaf> {
98        match self {
99            GenericNCommand::Sort(span, name, params) => {
100                GenericCommand::Sort(span.clone(), name.clone(), params.clone())
101            }
102            GenericNCommand::Function(f) => match f.subtype {
103                FunctionSubtype::Constructor => GenericCommand::Constructor {
104                    span: f.span.clone(),
105                    name: f.name.clone(),
106                    schema: f.schema.clone(),
107                    cost: f.cost,
108                    unextractable: f.unextractable,
109                },
110                FunctionSubtype::Relation => GenericCommand::Relation {
111                    span: f.span.clone(),
112                    name: f.name.clone(),
113                    inputs: f.schema.input.clone(),
114                },
115                FunctionSubtype::Custom => GenericCommand::Function {
116                    span: f.span.clone(),
117                    schema: f.schema.clone(),
118                    name: f.name.clone(),
119                    merge: f.merge.clone(),
120                },
121            },
122            GenericNCommand::AddRuleset(span, name) => {
123                GenericCommand::AddRuleset(span.clone(), name.clone())
124            }
125            GenericNCommand::UnstableCombinedRuleset(span, name, others) => {
126                GenericCommand::UnstableCombinedRuleset(span.clone(), name.clone(), others.clone())
127            }
128            GenericNCommand::NormRule { rule } => GenericCommand::Rule { rule: rule.clone() },
129            GenericNCommand::RunSchedule(schedule) => GenericCommand::RunSchedule(schedule.clone()),
130            GenericNCommand::PrintOverallStatistics(span, file) => {
131                GenericCommand::PrintOverallStatistics(span.clone(), file.clone())
132            }
133            GenericNCommand::CoreAction(action) => GenericCommand::Action(action.clone()),
134            GenericNCommand::Extract(span, expr, variants) => {
135                GenericCommand::Extract(span.clone(), expr.clone(), variants.clone())
136            }
137            GenericNCommand::Check(span, facts) => {
138                GenericCommand::Check(span.clone(), facts.clone())
139            }
140            GenericNCommand::PrintFunction(span, name, n, file, mode) => {
141                GenericCommand::PrintFunction(span.clone(), name.clone(), *n, file.clone(), *mode)
142            }
143            GenericNCommand::PrintSize(span, name) => {
144                GenericCommand::PrintSize(span.clone(), name.clone())
145            }
146            GenericNCommand::Output { span, file, exprs } => GenericCommand::Output {
147                span: span.clone(),
148                file: file.to_string(),
149                exprs: exprs.clone(),
150            },
151            GenericNCommand::Push(n) => GenericCommand::Push(*n),
152            GenericNCommand::Pop(span, n) => GenericCommand::Pop(span.clone(), *n),
153            GenericNCommand::Fail(span, cmd) => {
154                GenericCommand::Fail(span.clone(), Box::new(cmd.to_command()))
155            }
156            GenericNCommand::Input { span, name, file } => GenericCommand::Input {
157                span: span.clone(),
158                name: name.clone(),
159                file: file.clone(),
160            },
161            GenericNCommand::UserDefined(span, name, exprs) => {
162                GenericCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
163            }
164        }
165    }
166
167    pub fn visit_exprs(
168        self,
169        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
170    ) -> Self {
171        match self {
172            GenericNCommand::Sort(span, name, params) => GenericNCommand::Sort(span, name, params),
173            GenericNCommand::Function(func) => GenericNCommand::Function(func.visit_exprs(f)),
174            GenericNCommand::AddRuleset(span, name) => GenericNCommand::AddRuleset(span, name),
175            GenericNCommand::UnstableCombinedRuleset(span, name, rulesets) => {
176                GenericNCommand::UnstableCombinedRuleset(span, name, rulesets)
177            }
178            GenericNCommand::NormRule { rule } => GenericNCommand::NormRule {
179                rule: rule.visit_exprs(f),
180            },
181            GenericNCommand::RunSchedule(schedule) => {
182                GenericNCommand::RunSchedule(schedule.visit_exprs(f))
183            }
184            GenericNCommand::PrintOverallStatistics(span, file) => {
185                GenericNCommand::PrintOverallStatistics(span, file)
186            }
187            GenericNCommand::CoreAction(action) => {
188                GenericNCommand::CoreAction(action.visit_exprs(f))
189            }
190            GenericNCommand::Extract(span, expr, variants) => {
191                GenericNCommand::Extract(span, expr.visit_exprs(f), variants.visit_exprs(f))
192            }
193            GenericNCommand::Check(span, facts) => GenericNCommand::Check(
194                span,
195                facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
196            ),
197            GenericNCommand::PrintFunction(span, name, n, file, mode) => {
198                GenericNCommand::PrintFunction(span, name, n, file, mode)
199            }
200            GenericNCommand::PrintSize(span, name) => GenericNCommand::PrintSize(span, name),
201            GenericNCommand::Output { span, file, exprs } => GenericNCommand::Output {
202                span,
203                file,
204                exprs: exprs.into_iter().map(f).collect(),
205            },
206            GenericNCommand::Push(n) => GenericNCommand::Push(n),
207            GenericNCommand::Pop(span, n) => GenericNCommand::Pop(span, n),
208            GenericNCommand::Fail(span, cmd) => {
209                GenericNCommand::Fail(span, Box::new(cmd.visit_exprs(f)))
210            }
211            GenericNCommand::Input { span, name, file } => {
212                GenericNCommand::Input { span, name, file }
213            }
214            GenericNCommand::UserDefined(span, name, exprs) => {
215                // We can't map `f` over UserDefined because UserDefined always assumes plain `Expr`s
216                GenericNCommand::UserDefined(span, name, exprs)
217            }
218        }
219    }
220}
221
222pub type Schedule = GenericSchedule<String, String>;
223pub(crate) type ResolvedSchedule = GenericSchedule<ResolvedCall, ResolvedVar>;
224
225#[derive(Debug, Clone, PartialEq, Eq, Hash)]
226pub enum GenericSchedule<Head, Leaf> {
227    Saturate(Span, Box<GenericSchedule<Head, Leaf>>),
228    Repeat(Span, usize, Box<GenericSchedule<Head, Leaf>>),
229    Run(Span, GenericRunConfig<Head, Leaf>),
230    Sequence(Span, Vec<GenericSchedule<Head, Leaf>>),
231}
232
233impl<Head, Leaf> GenericSchedule<Head, Leaf>
234where
235    Head: Clone + Display,
236    Leaf: Clone + PartialEq + Eq + Display + Hash,
237{
238    fn visit_exprs(
239        self,
240        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
241    ) -> Self {
242        match self {
243            GenericSchedule::Saturate(span, sched) => {
244                GenericSchedule::Saturate(span, Box::new(sched.visit_exprs(f)))
245            }
246            GenericSchedule::Repeat(span, size, sched) => {
247                GenericSchedule::Repeat(span, size, Box::new(sched.visit_exprs(f)))
248            }
249            GenericSchedule::Run(span, config) => GenericSchedule::Run(span, config.visit_exprs(f)),
250            GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence(
251                span,
252                scheds.into_iter().map(|s| s.visit_exprs(f)).collect(),
253            ),
254        }
255    }
256}
257
258impl<Head: Display, Leaf: Display> Display for GenericSchedule<Head, Leaf> {
259    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
260        match self {
261            GenericSchedule::Saturate(_ann, sched) => write!(f, "(saturate {sched})"),
262            GenericSchedule::Repeat(_ann, size, sched) => write!(f, "(repeat {size} {sched})"),
263            GenericSchedule::Run(_ann, config) => write!(f, "{config}"),
264            GenericSchedule::Sequence(_ann, scheds) => {
265                write!(f, "(seq {})", ListDisplay(scheds, " "))
266            }
267        }
268    }
269}
270
271pub type Command = GenericCommand<String, String>;
272
273pub type Subsume = bool;
274
275#[derive(Debug, Clone, PartialEq, Eq)]
276pub enum Subdatatypes {
277    Variants(Vec<Variant>),
278    NewSort(String, Vec<Expr>),
279}
280
/// How a function is printed.
///
/// The default mode prints the function in a user-friendly way and has an
/// unreliable interface, while the CSV mode prints the function in the CSV
/// format.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum PrintFunctionMode {
    Default,
    CSV,
}

/// Writes the lowercase name of the mode: `default` or `csv`.
impl Display for PrintFunctionMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Default => "default",
            Self::CSV => "csv",
        };
        f.write_str(label)
    }
}
298
299/// A [`Command`] is the top-level construct in egglog.
300/// It includes defining rules, declaring functions,
301/// adding to tables, and running rules (via a [`Schedule`]).
302///
303/// # Binding naming convention
304/// Bindings introduced by commands fall into two categories:
305/// - **Global bindings** must start with [`$`](crate::GLOBAL_NAME_PREFIX).
306/// - **Non-global bindings** must *not* start with [`$`](crate::GLOBAL_NAME_PREFIX).
307///
308/// When `--strict-mode` is enabled, violating these conventions is a type error;
309/// otherwise, egglog emits a single warning per program.
310#[derive(Debug, Clone)]
311pub enum GenericCommand<Head, Leaf>
312where
313    Head: Clone + Display,
314    Leaf: Clone + PartialEq + Eq + Display + Hash,
315{
316    /// Create a new user-defined sort, which can then
317    /// be used in new [`Command::Function`] declarations.
318    /// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`]
319    /// per constructor.
320    /// The main use of this command (as opposed to using [`Command::Datatype`]) is for forward-declaring a sort for mutually-recursive datatypes.
321    ///
322    /// It can also be used to create
323    /// a container sort.
324    /// For example, here's how to make a sort for vectors
325    /// of some user-defined sort `Math`:
326    /// ```text
327    /// (sort MathVec (Vec Math))
328    /// ```
329    ///
330    /// Now `MathVec` can be used as an input or output sort.
331    Sort(Span, String, Option<(String, Vec<Expr>)>),
332
333    /// Egglog supports three types of functions
334    ///
335    /// A constructor models an egg-style user-defined datatype
336    /// It can only be defined through the `datatype`/`datatype*` command
337    /// or the `constructor` command
338    ///
339    /// A relation models a datalog-style mathematical relation
340    /// It can only be defined through the `relation` command
341    ///
342    /// A custom function is a dictionary
343    /// It can only be defined through the `function` command
344    ///
345    /// The `datatype` command declares a user-defined datatype.
346    /// Datatypes can be unioned with [`Action::Union`] either
347    /// at the top level or in the actions of a rule.
348    /// This makes them equal in the implicit, global equality relation.
349    ///
350    /// Example:
351    /// ```text
352    /// (datatype Math
353    ///   (Num i64)
354    ///   (Var String)
355    ///   (Add Math Math)
356    ///   (Mul Math Math))
357    /// ```
358    ///
359    /// defines a simple `Math` datatype with variants for numbers, named variables, addition and multiplication.
360    ///
361    /// Datatypes desugar directly to a [`Command::Sort`] and a [`Command::Constructor`] for each constructor.
362    /// The code above becomes:
363    /// ```text
364    /// (sort Math)
365    /// (constructor Num (i64) Math)
366    /// (constructor Var (String) Math)
367    /// (constructor Add (Math Math) Math)
368    /// (constructor Mul (Math Math) Math)
369    ///
370    /// Datatypes are also known as algebraic data types, tagged unions and sum types.
371    Datatype {
372        span: Span,
373        name: String,
374        variants: Vec<Variant>,
375    },
376    Datatypes {
377        span: Span,
378        datatypes: Vec<(Span, String, Subdatatypes)>,
379    },
380
381    /// The `constructor` command defines a new constructor for a user-defined datatype
382    /// Example:
383    /// ```text
384    /// (constructor Add (i64 i64) Math)
385    /// ```
386    ///
387    Constructor {
388        span: Span,
389        name: String,
390        schema: Schema,
391        cost: Option<DefaultCost>,
392        unextractable: bool,
393    },
394
395    /// The `relation` command declares a named relation
396    /// Example:
397    /// ```text
398    /// (relation path (i64 i64))
399    /// (relation edge (i64 i64))
400    /// ```
401    Relation {
402        span: Span,
403        name: String,
404        inputs: Vec<String>,
405    },
406
407    /// The `function` command declare an egglog custom function, which is a database table with a
408    /// a functional dependency (also called a primary key) on its inputs to one output.
409    ///
410    /// ```text
411    /// (function <name:Ident> <schema:Schema> <cost:Cost>
412    ///        (:on_merge <List<Action>>)?
413    ///        (:merge <Expr>)?)
414    ///```
415    /// A function can have a `cost` for extraction.
416    ///
417    /// Finally, it can have a `merge` and `on_merge`, which are triggered when
418    /// the function dependency is violated.
419    /// In this case, the merge expression determines which of the two outputs
420    /// for the same input is used.
421    /// The `on_merge` actions are run after the merge expression is evaluated.
422    ///
423    /// Note that the `:merge` expression must be monotonic
424    /// for the behavior of the egglog program to be consistent and defined.
425    /// In other words, the merge function must define a lattice on the output of the function.
426    /// If values are merged in different orders, they should still result in the same output.
427    /// If the merge expression is not monotonic, the behavior can vary as
428    /// actions may be applied more than once with different results.
429    ///
430    /// ```text
431    /// (function LowerBound (Math) i64 :merge (max old new))
432    /// ```
433    ///
434    /// Specifically, a custom function can also have an EqSort output type:
435    ///
436    /// ```text
437    /// (function Add (i64 i64) Math)
438    /// ```
439    ///
440    /// All functions can be `set`
441    /// with [`Action::Set`].
442    ///
443    /// Output of a function, if being the EqSort type, can be unioned with [`Action::Union`]
444    /// with another datatype of the same `sort`.
445    ///
446    Function {
447        span: Span,
448        name: String,
449        schema: Schema,
450        merge: Option<GenericExpr<Head, Leaf>>,
451    },
452
453    /// Using the `ruleset` command, defines a new
454    /// ruleset that can be added to in [`Command::Rule`]s.
455    /// Rulesets are used to group rules together
456    /// so that they can be run together in a [`Schedule`].
457    ///
458    /// Example:
459    /// Ruleset allows users to define a ruleset- a set of rules
460    ///
461    /// ```text
462    /// (ruleset myrules)
463    /// (rule ((edge x y))
464    ///       ((path x y))
465    ///       :ruleset myrules)
466    /// (run myrules 2)
467    /// ```
468    AddRuleset(Span, String),
469    /// Using the `combined-ruleset` command, construct another ruleset
470    /// which runs all the rules in the given rulesets.
471    /// This is useful for running multiple rulesets together.
472    /// The combined ruleset also inherits any rules added to the individual rulesets
473    /// after the combined ruleset is declared.
474    ///
475    /// Example:
476    /// ```text
477    /// (ruleset myrules1)
478    /// (rule ((edge x y))
479    ///       ((path x y))
480    ///      :ruleset myrules1)
481    /// (ruleset myrules2)
482    /// (rule ((path x y) (edge y z))
483    ///       ((path x z))
484    ///       :ruleset myrules2)
485    /// (combined-ruleset myrules-combined myrules1 myrules2)
486    /// ```
487    UnstableCombinedRuleset(Span, String, Vec<String>),
488    /// ```text
489    /// (rule <body:List<Fact>> <head:List<Action>>)
490    /// ```
491    ///
492    /// defines an egglog rule.
493    /// The rule matches a list of facts with respect to
494    /// the global database, and runs the list of actions
495    /// for each match.
496    /// The matches are done *modulo equality*, meaning
497    /// equal datatypes in the database are considered
498    /// equal.
499    ///
500    /// Example:
501    /// ```text
502    /// (rule ((edge x y))
503    ///       ((path x y)))
504    ///
505    /// (rule ((path x y) (edge y z))
506    ///       ((path x z)))
507    /// ```
508    Rule { rule: GenericRule<Head, Leaf> },
509    /// `rewrite` is syntactic sugar for a specific form of `rule`
510    /// which simply unions the left and right hand sides.
511    ///
512    /// Example:
513    /// ```text
514    /// (rewrite (Add a b)
515    ///          (Add b a))
516    /// ```
517    ///
518    /// Desugars to:
519    /// ```text
520    /// (rule ((= lhs (Add a b)))
521    ///       ((union lhs (Add b a))))
522    /// ```
523    ///
524    /// Additionally, additional facts can be specified
525    /// using a `:when` clause.
526    /// For example, the same rule can be run only
527    /// when `a` is zero:
528    ///
529    /// ```text
530    /// (rewrite (Add a b)
531    ///          (Add b a)
532    ///          :when ((= a (Num 0)))
533    /// ```
534    ///
535    /// Add the `:subsume` flag to cause the left hand side to be subsumed after matching, which means it can
536    /// no longer be matched in a rule, but can still be checked against (See [`Change`] for more details.)
537    ///
538    /// ```text
539    /// (rewrite (Mul a 2) (bitshift-left a 1) :subsume)
540    /// ```
541    ///
542    /// Desugars to:
543    /// ```text
544    /// (rule ((= lhs (Mul a 2)))
545    ///       ((union lhs (bitshift-left a 1))
546    ///        (subsume (Mul a 2))))
547    /// ```
548    Rewrite(String, GenericRewrite<Head, Leaf>, Subsume),
549    /// Similar to [`Command::Rewrite`], but
550    /// generates two rules, one for each direction.
551    ///
552    /// Example:
553    /// ```text
554    /// (bi-rewrite (Mul (Var x) (Num 0))
555    ///             (Var x))
556    /// ```
557    ///
558    /// Becomes:
559    /// ```text
560    /// (rule ((= lhs (Mul (Var x) (Num 0))))
561    ///       ((union lhs (Var x))))
562    /// (rule ((= lhs (Var x)))
563    ///       ((union lhs (Mul (Var x) (Num 0)))))
564    /// ```
565    BiRewrite(String, GenericRewrite<Head, Leaf>),
566    /// Perform an [`Action`] on the global database
567    /// (see documentation for [`Action`] for more details).
568    /// Example:
569    /// ```text
570    /// (let xplusone (Add (Var "x") (Num 1)))
571    /// ```
572    Action(GenericAction<Head, Leaf>),
573    /// `extract` a datatype from the egraph, choosing
574    /// the smallest representative.
575    /// By default, each constructor costs 1 to extract
576    /// (common subexpressions are not shared in the cost
577    /// model).
578    Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
579    /// Runs a [`Schedule`], which specifies
580    /// rulesets and the number of times to run them.
581    ///
582    /// Example:
583    /// ```text
584    /// (run-schedule
585    ///     (saturate my-ruleset-1)
586    ///     (run my-ruleset-2 4))
587    /// ```
588    ///
589    /// Runs `my-ruleset-1` until saturation,
590    /// then runs `my-ruleset-2` four times.
591    ///
592    /// See [`Schedule`] for more details.
593    RunSchedule(GenericSchedule<Head, Leaf>),
594    /// Print runtime statistics about rules
595    /// and rulesets so far.
596    PrintOverallStatistics(Span, Option<String>),
597    /// The `check` command checks that the given facts
598    /// match at least once in the current database.
599    /// The list of facts is matched in the same way a [`Command::Rule`] is matched.
600    ///
601    /// Example:
602    ///
603    /// ```text
604    /// (check (= (+ 1 2) 3))
605    /// (check (<= 0 3) (>= 3 0))
606    /// (fail (check (= 1 2)))
607    /// ```
608    ///
609    /// prints
610    ///
611    /// ```text
612    /// [INFO ] Checked.
613    /// [INFO ] Checked.
614    /// [ERROR] Check failed
615    /// [INFO ] Command failed as expected.
616    /// ```
617    Check(Span, Vec<GenericFact<Head, Leaf>>),
618    /// Print out rows of a given function, extracting each of the elements of the function.
619    /// Example:
620    ///
621    /// ```text
622    /// (print-function Add 20)
623    /// ```
624    /// prints the first 20 rows of the `Add` function.
625    ///
626    /// ```text
627    /// (print-function Add)
628    /// ```
629    /// prints all rows of the `Add` function.
630    ///
631    /// ```text
632    /// (print-function Add :file "add.csv")
633    /// ```
634    /// prints all rows of the `Add` function to a CSV file.
635    PrintFunction(
636        Span,
637        String,
638        Option<usize>,
639        Option<String>,
640        PrintFunctionMode,
641    ),
642    /// Print out the number of rows in a function or all functions.
643    PrintSize(Span, Option<String>),
644    /// Input a CSV file directly into a function.
645    Input {
646        span: Span,
647        name: String,
648        file: String,
649    },
650    /// Extract and output a set of expressions to a file.
651    Output {
652        span: Span,
653        file: String,
654        exprs: Vec<GenericExpr<Head, Leaf>>,
655    },
656    /// `push` the current egraph `n` times so that it is saved.
657    /// Later, the current database and rules can be restored using `pop`.
658    Push(usize),
659    /// `pop` the current egraph, restoring the previous one.
660    /// The argument specifies how many egraphs to pop.
661    Pop(Span, usize),
662    /// Assert that a command fails with an error.
663    Fail(Span, Box<GenericCommand<Head, Leaf>>),
664    /// Include another egglog file directly as text and run it.
665    Include(Span, String),
666    /// User-defined command.
667    UserDefined(Span, String, Vec<Expr>),
668}
669
670impl<Head, Leaf> Display for GenericCommand<Head, Leaf>
671where
672    Head: Clone + Display,
673    Leaf: Clone + PartialEq + Eq + Display + Hash,
674{
675    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
676        match self {
677            GenericCommand::Rewrite(name, rewrite, subsume) => {
678                rewrite.fmt_with_ruleset(f, name, false, *subsume)
679            }
680            GenericCommand::BiRewrite(name, rewrite) => {
681                rewrite.fmt_with_ruleset(f, name, true, false)
682            }
683            GenericCommand::Datatype {
684                span: _,
685                name,
686                variants,
687            } => {
688                let name = sanitize_internal_name(name);
689                write!(f, "(datatype {name} {})", ListDisplay(variants, " "))
690            }
691            GenericCommand::Action(a) => write!(f, "{a}"),
692            GenericCommand::Extract(_span, expr, variants) => {
693                write!(f, "(extract {expr} {variants})")
694            }
695            GenericCommand::Sort(_span, name, None) => {
696                let name = sanitize_internal_name(name);
697                write!(f, "(sort {name})")
698            }
699            GenericCommand::Sort(_span, name, Some((name2, args))) => {
700                let name = sanitize_internal_name(name);
701                write!(f, "(sort {name} ({name2} {}))", ListDisplay(args, " "))
702            }
703            GenericCommand::Function {
704                span: _,
705                name,
706                schema,
707                merge,
708            } => {
709                let name = sanitize_internal_name(name);
710                write!(f, "(function {name} {schema}")?;
711                if let Some(merge) = &merge {
712                    write!(f, " :merge {merge}")?;
713                } else {
714                    write!(f, " :no-merge")?;
715                }
716                write!(f, ")")
717            }
718            GenericCommand::Constructor {
719                span: _,
720                name,
721                schema,
722                cost,
723                unextractable,
724            } => {
725                let name = sanitize_internal_name(name);
726                write!(f, "(constructor {name} {schema}")?;
727                if let Some(cost) = cost {
728                    write!(f, " :cost {cost}")?;
729                }
730                if *unextractable {
731                    write!(f, " :unextractable")?;
732                }
733                write!(f, ")")
734            }
735            GenericCommand::Relation {
736                span: _,
737                name,
738                inputs,
739            } => {
740                let name = sanitize_internal_name(name);
741                write!(f, "(relation {name} ({}))", ListDisplay(inputs, " "))
742            }
743            GenericCommand::AddRuleset(_span, name) => {
744                let name = sanitize_internal_name(name);
745                write!(f, "(ruleset {name})")
746            }
747            GenericCommand::UnstableCombinedRuleset(_span, name, others) => {
748                let name = sanitize_internal_name(name);
749                let others: Vec<_> = others
750                    .iter()
751                    .map(|other| sanitize_internal_name(other).into_owned())
752                    .collect();
753                write!(
754                    f,
755                    "(unstable-combined-ruleset {name} {})",
756                    ListDisplay(&others, " ")
757                )
758            }
759            GenericCommand::Rule { rule } => rule.fmt(f),
760            GenericCommand::RunSchedule(sched) => write!(f, "(run-schedule {sched})"),
761            GenericCommand::PrintOverallStatistics(_span, file) => match file {
762                Some(file) => write!(f, "(print-stats :file {file})"),
763                None => write!(f, "(print-stats)"),
764            },
765            GenericCommand::Check(_ann, facts) => {
766                write!(f, "(check {})", ListDisplay(facts, "\n"))
767            }
768            GenericCommand::Push(n) => write!(f, "(push {n})"),
769            GenericCommand::Pop(_span, n) => write!(f, "(pop {n})"),
770            GenericCommand::PrintFunction(_span, name, n, file, mode) => {
771                let name = sanitize_internal_name(name);
772                write!(f, "(print-function {name}")?;
773                if let Some(n) = n {
774                    write!(f, " {n}")?;
775                }
776                if let Some(file) = file {
777                    write!(f, " :file {file:?}")?;
778                }
779                match mode {
780                    PrintFunctionMode::Default => {}
781                    PrintFunctionMode::CSV => write!(f, " :mode csv")?,
782                }
783                write!(f, ")")
784            }
785            GenericCommand::PrintSize(_span, name) => {
786                let name: Option<_> = name
787                    .as_ref()
788                    .map(|value| sanitize_internal_name(value).into_owned());
789                write!(f, "(print-size {})", ListDisplay(name, " "))
790            }
791            GenericCommand::Input {
792                span: _,
793                name,
794                file,
795            } => {
796                let name = sanitize_internal_name(name);
797                write!(f, "(input {name} {file:?})")
798            }
799            GenericCommand::Output {
800                span: _,
801                file,
802                exprs,
803            } => write!(f, "(output {file:?} {})", ListDisplay(exprs, " ")),
804            GenericCommand::Fail(_span, cmd) => write!(f, "(fail {cmd})"),
805            GenericCommand::Include(_span, file) => write!(f, "(include {file:?})"),
806            GenericCommand::Datatypes { span: _, datatypes } => {
807                let datatypes: Vec<_> = datatypes
808                    .iter()
809                    .map(|(_, name, variants)| {
810                        let name = sanitize_internal_name(name);
811                        match variants {
812                            Subdatatypes::Variants(variants) => {
813                                format!("({name} {})", ListDisplay(variants, " "))
814                            }
815                            Subdatatypes::NewSort(head, args) => {
816                                format!("(sort {name} ({head} {}))", ListDisplay(args, " "))
817                            }
818                        }
819                    })
820                    .collect();
821                write!(f, "(datatype* {})", ListDisplay(datatypes, " "))
822            }
823            GenericCommand::UserDefined(_span, name, exprs) => {
824                let name = sanitize_internal_name(name);
825                write!(f, "({name} {})", ListDisplay(exprs, " "))
826            }
827        }
828    }
829}
830
/// An identifier paired with the name of its sort.
///
/// Displayed as the s-expression `(ident sort)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IdentSort {
    /// The identifier being annotated.
    pub ident: String,
    /// The sort name attached to `ident`.
    pub sort: String,
}

impl Display for IdentSort {
    /// Writes the pair as `(ident sort)`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let Self { ident, sort } = self;
        write!(f, "({ident} {sort})")
    }
}
842
843pub type RunConfig = GenericRunConfig<String, String>;
844pub(crate) type ResolvedRunConfig = GenericRunConfig<ResolvedCall, ResolvedVar>;
845
846#[derive(Clone, Debug, PartialEq, Eq, Hash)]
847pub struct GenericRunConfig<Head, Leaf> {
848    pub ruleset: String,
849    pub until: Option<Vec<GenericFact<Head, Leaf>>>,
850}
851
852impl<Head, Leaf> GenericRunConfig<Head, Leaf>
853where
854    Head: Clone + Display,
855    Leaf: Clone + PartialEq + Eq + Display + Hash,
856{
857    pub fn visit_exprs(
858        self,
859        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
860    ) -> Self {
861        Self {
862            ruleset: self.ruleset,
863            until: self
864                .until
865                .map(|until| until.into_iter().map(|fact| fact.visit_exprs(f)).collect()),
866        }
867    }
868}
869
870impl<Head: Display, Leaf: Display> Display for GenericRunConfig<Head, Leaf>
871where
872    Head: Display,
873    Leaf: Display,
874{
875    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
876        write!(f, "(run")?;
877        let ruleset = sanitize_internal_name(&self.ruleset);
878        if !ruleset.is_empty() {
879            write!(f, " {ruleset}")?;
880        }
881        if let Some(until) = &self.until {
882            write!(f, " :until {}", ListDisplay(until, " "))?;
883        }
884        write!(f, ")")
885    }
886}
887
888pub type FunctionDecl = GenericFunctionDecl<String, String>;
889pub(crate) type ResolvedFunctionDecl = GenericFunctionDecl<ResolvedCall, ResolvedVar>;
890
/// The kind of a function declaration.
///
/// Displayed as the keyword `constructor`, `relation` or `function`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FunctionSubtype {
    /// Printed as `constructor`.
    Constructor,
    /// Printed as `relation`.
    Relation,
    /// Printed as `function`.
    Custom,
}

impl Display for FunctionSubtype {
    /// Writes the keyword that corresponds to this subtype.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let keyword = match self {
            Self::Constructor => "constructor",
            Self::Relation => "relation",
            Self::Custom => "function",
        };
        f.write_str(keyword)
    }
}
907
908/// Represents the declaration of a function
909/// directly parsed from source syntax.
910#[derive(Clone, Debug, PartialEq, Eq, Hash)]
911pub struct GenericFunctionDecl<Head, Leaf>
912where
913    Head: Clone + Display,
914    Leaf: Clone + PartialEq + Eq + Display + Hash,
915{
916    pub name: String,
917    pub subtype: FunctionSubtype,
918    pub schema: Schema,
919    pub merge: Option<GenericExpr<Head, Leaf>>,
920    pub cost: Option<DefaultCost>,
921    pub unextractable: bool,
922    /// Globals are desugared to functions, with this flag set to true.
923    /// This is used by visualization to handle globals differently.
924    pub let_binding: bool,
925    pub span: Span,
926}
927
928#[derive(Clone, Debug, PartialEq, Eq, Hash)]
929pub struct Variant {
930    pub span: Span,
931    pub name: String,
932    pub types: Vec<String>,
933    pub cost: Option<DefaultCost>,
934    pub unextractable: bool,
935}
936
937impl Display for Variant {
938    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
939        let name = sanitize_internal_name(&self.name);
940        write!(f, "({name}")?;
941        if !self.types.is_empty() {
942            write!(f, " {}", ListDisplay(&self.types, " "))?;
943        }
944        if let Some(cost) = self.cost {
945            write!(f, " :cost {cost}")?;
946        }
947        write!(f, ")")
948    }
949}
950
/// The signature of a function: a list of input sort names and one output
/// sort name.
///
/// Displayed as `(<inputs separated by spaces>) <output>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Schema {
    /// Names of the input sorts, in order.
    pub input: Vec<String>,
    /// Name of the output sort.
    pub output: String,
}
956
957impl Display for Schema {
958    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
959        write!(f, "({}) {}", ListDisplay(&self.input, " "), self.output)
960    }
961}
962
963impl Schema {
964    pub fn new(input: Vec<String>, output: String) -> Self {
965        Self { input, output }
966    }
967}
968
969impl FunctionDecl {
970    /// Constructs a `function`
971    pub fn function(
972        span: Span,
973        name: String,
974        schema: Schema,
975        merge: Option<GenericExpr<String, String>>,
976    ) -> Self {
977        Self {
978            name,
979            subtype: FunctionSubtype::Custom,
980            schema,
981            merge,
982            cost: None,
983            unextractable: true,
984            let_binding: false,
985            span,
986        }
987    }
988
989    /// Constructs a `constructor`
990    pub fn constructor(
991        span: Span,
992        name: String,
993        schema: Schema,
994        cost: Option<DefaultCost>,
995        unextractable: bool,
996    ) -> Self {
997        Self {
998            name,
999            subtype: FunctionSubtype::Constructor,
1000            schema,
1001            merge: None,
1002            cost,
1003            unextractable,
1004            let_binding: false,
1005            span,
1006        }
1007    }
1008
1009    /// Constructs a `relation`
1010    pub fn relation(span: Span, name: String, input: Vec<String>) -> Self {
1011        Self {
1012            name,
1013            subtype: FunctionSubtype::Relation,
1014            schema: Schema {
1015                input,
1016                output: String::from("Unit"),
1017            },
1018            merge: None,
1019            cost: None,
1020            unextractable: true,
1021            let_binding: false,
1022            span,
1023        }
1024    }
1025}
1026
1027impl<Head, Leaf> GenericFunctionDecl<Head, Leaf>
1028where
1029    Head: Clone + Display,
1030    Leaf: Clone + PartialEq + Eq + Display + Hash,
1031{
1032    pub fn visit_exprs(
1033        self,
1034        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1035    ) -> GenericFunctionDecl<Head, Leaf> {
1036        GenericFunctionDecl {
1037            name: self.name,
1038            subtype: self.subtype,
1039            schema: self.schema,
1040            merge: self.merge.map(|expr| expr.visit_exprs(f)),
1041            cost: self.cost,
1042            unextractable: self.unextractable,
1043            let_binding: self.let_binding,
1044            span: self.span,
1045        }
1046    }
1047}
1048
1049pub type Fact = GenericFact<String, String>;
1050pub(crate) type ResolvedFact = GenericFact<ResolvedCall, ResolvedVar>;
1051pub(crate) type MappedFact<Head, Leaf> = GenericFact<CorrespondingVar<Head, Leaf>, Leaf>;
1052
1053pub struct Facts<Head, Leaf>(pub Vec<GenericFact<Head, Leaf>>);
1054
1055impl<Head, Leaf> Facts<Head, Leaf>
1056where
1057    Head: Clone + Display,
1058    Leaf: Clone + PartialEq + Eq + Display + Hash,
1059{
1060    /// Flattens a list of facts into a Query.
1061    /// For typechecking, we need the correspondence between the original ast
1062    /// and the flattened one, so that we can annotate the original with types.
1063    /// That's why this function produces a corresponding list of facts, annotated with
1064    /// the variable names in the flattened Query.
1065    /// (Typechecking preserves the original AST this way,
1066    /// and allows terms and proof instrumentation to do the same).
1067    pub(crate) fn to_query(
1068        &self,
1069        typeinfo: &TypeInfo,
1070        fresh_gen: &mut impl FreshGen<Head, Leaf>,
1071    ) -> (Query<HeadOrEq<Head>, Leaf>, Vec<MappedFact<Head, Leaf>>) {
1072        let mut atoms = vec![];
1073        let mut new_body = vec![];
1074
1075        for fact in self.0.iter() {
1076            match fact {
1077                GenericFact::Eq(span, e1, e2) => {
1078                    let mut to_equate = vec![];
1079                    let mut process = |expr: &GenericExpr<Head, Leaf>| {
1080                        let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1081                        atoms.extend(child_atoms);
1082                        to_equate.push(expr.get_corresponding_var_or_lit(typeinfo));
1083                        expr
1084                    };
1085                    let e1 = process(e1);
1086                    let e2 = process(e2);
1087                    atoms.push(GenericAtom {
1088                        span: span.clone(),
1089                        head: HeadOrEq::Eq,
1090                        args: to_equate,
1091                    });
1092                    new_body.push(GenericFact::Eq(span.clone(), e1, e2));
1093                }
1094                GenericFact::Fact(expr) => {
1095                    let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1096                    atoms.extend(child_atoms);
1097                    new_body.push(GenericFact::Fact(expr));
1098                }
1099            }
1100        }
1101        (Query { atoms }, new_body)
1102    }
1103}
1104
/// A head paired with the leaf it corresponds to.
///
/// Displayed as `head -> to`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// The original head.
    pub head: Head,
    /// The leaf that `head` is mapped to.
    pub to: Leaf,
}

impl<Head, Leaf> CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Pairs `head` with `leaf`; `leaf` is stored in the `to` field.
    pub fn new(head: Head, leaf: Leaf) -> Self {
        CorrespondingVar { head, to: leaf }
    }
}

impl<Head, Leaf> Display for CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Writes the pair as `head -> to`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let Self { head, to } = self;
        write!(f, "{head} -> {to}")
    }
}
1134pub type Action = GenericAction<String, String>;
1135pub(crate) type MappedAction = GenericAction<CorrespondingVar<String, String>, String>;
1136pub(crate) type ResolvedAction = GenericAction<ResolvedCall, ResolvedVar>;
1137
1138pub type Actions = GenericActions<String, String>;
1139pub(crate) type ResolvedActions = GenericActions<ResolvedCall, ResolvedVar>;
1140pub(crate) type MappedActions<Head, Leaf> = GenericActions<CorrespondingVar<Head, Leaf>, Leaf>;
1141
1142pub type Rule = GenericRule<String, String>;
1143pub(crate) type ResolvedRule = GenericRule<ResolvedCall, ResolvedVar>;
1144
1145pub type Rewrite = GenericRewrite<String, String>;
1146
1147#[derive(Clone, Debug)]
1148pub struct GenericRewrite<Head, Leaf> {
1149    pub span: Span,
1150    pub lhs: GenericExpr<Head, Leaf>,
1151    pub rhs: GenericExpr<Head, Leaf>,
1152    pub conditions: Vec<GenericFact<Head, Leaf>>,
1153}
1154
1155impl<Head: Display, Leaf: Display> GenericRewrite<Head, Leaf> {
1156    /// Converts the rewrite into an s-expression.
1157    pub fn fmt_with_ruleset(
1158        &self,
1159        f: &mut Formatter,
1160        ruleset: &str,
1161        is_bidirectional: bool,
1162        subsume: bool,
1163    ) -> std::fmt::Result {
1164        let direction = if is_bidirectional {
1165            "birewrite"
1166        } else {
1167            "rewrite"
1168        };
1169        write!(f, "({direction} {} {}", self.lhs, self.rhs)?;
1170        if subsume {
1171            write!(f, " :subsume")?;
1172        }
1173        if !self.conditions.is_empty() {
1174            write!(f, " :when ({})", ListDisplay(&self.conditions, " "))?;
1175        }
1176        if !ruleset.is_empty() {
1177            let ruleset = sanitize_internal_name(ruleset);
1178            write!(f, " :ruleset {ruleset}")?;
1179        }
1180        write!(f, ")")
1181    }
1182}
1183
1184pub(crate) trait MappedExprExt<Head, Leaf>
1185where
1186    Head: Clone + Display,
1187    Leaf: Clone + PartialEq + Eq + Display + Hash,
1188{
1189    fn get_corresponding_var_or_lit(&self, typeinfo: &TypeInfo) -> GenericAtomTerm<Leaf>;
1190}
1191
1192impl<Head, Leaf> MappedExprExt<Head, Leaf> for MappedExpr<Head, Leaf>
1193where
1194    Head: Clone + Display,
1195    Leaf: Clone + PartialEq + Eq + Display + Hash,
1196{
1197    fn get_corresponding_var_or_lit(&self, typeinfo: &TypeInfo) -> GenericAtomTerm<Leaf> {
1198        // Note: need typeinfo to resolve whether a symbol is a global or not
1199        // This is error-prone and the complexities can be avoided by treating globals
1200        // as nullary functions.
1201        match self {
1202            GenericExpr::Var(span, v) => {
1203                if typeinfo.is_global(&v.to_string()) {
1204                    GenericAtomTerm::Global(span.clone(), v.clone())
1205                } else {
1206                    GenericAtomTerm::Var(span.clone(), v.clone())
1207                }
1208            }
1209            GenericExpr::Lit(span, lit) => GenericAtomTerm::Literal(span.clone(), lit.clone()),
1210            GenericExpr::Call(span, head, _) => GenericAtomTerm::Var(span.clone(), head.to.clone()),
1211        }
1212    }
1213}