egglog/
lib.rs

1//! # egglog
2//! egglog is a language specialized for writing equality saturation
3//! applications. It is the successor to the rust library [egg](https://github.com/egraphs-good/egg).
4//! egglog is faster and more general than egg.
5//!
6//! # Documentation
7//! Documentation for the egglog language can be found here: [`Command`].
8//!
9//! # Tutorial
10//! We have a [text tutorial](https://egraphs-good.github.io/egglog-tutorial/01-basics.html) on egglog and how to use it.
11//! We also have a slightly outdated [video tutorial](https://www.youtube.com/watch?v=N2RDQGRBrSY).
12//!
13//!
14//!
15pub mod ast;
16#[cfg(feature = "bin")]
17mod cli;
18mod command_macro;
19pub mod constraint;
20mod core;
21pub mod extract;
22pub mod prelude;
23mod proofs;
24
25pub mod scheduler;
26mod serialize;
27pub mod sort;
28mod termdag;
29mod typechecking;
30pub mod util;
31pub use command_macro::{CommandMacro, CommandMacroRegistry};
32
33// This is used to allow the `add_primitive` macro to work in
34// both this crate and other crates by referring to `::egglog`.
35extern crate self as egglog;
36pub use ast::{ResolvedExpr, ResolvedFact, ResolvedVar};
37#[cfg(feature = "bin")]
38pub use cli::*;
39use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint};
40pub use core::{Atom, AtomTerm};
41use core::{CoreActionContext, ResolvedAtomTerm};
42pub use core::{ResolvedCall, SpecializedPrimitive};
43pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value};
44use core_relations::{ExternalFunctionId, make_external_func};
45use csv::Writer;
46pub use egglog_add_primitive::add_literal_prim;
47pub use egglog_add_primitive::add_primitive;
48pub use egglog_add_primitive::add_primitive_with_validator;
49use egglog_ast::generic_ast::{Change, GenericExpr, Literal};
50use egglog_ast::span::Span;
51use egglog_ast::util::ListDisplay;
52pub use egglog_bridge::FunctionRow;
53use egglog_bridge::{ColumnTy, QueryEntry};
54use egglog_core_relations as core_relations;
55use egglog_numeric_id as numeric_id;
56use egglog_reports::{ReportLevel, RunReport};
57use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel};
58use indexmap::map::Entry;
59use log::{Level, log_enabled};
60use numeric_id::DenseIdMap;
61use prelude::*;
62pub use proofs::proof_encoding_helpers::{file_supports_proofs, program_supports_proofs};
63use scheduler::{SchedulerId, SchedulerRecord};
64pub use serialize::{SerializeConfig, SerializeOutput, SerializedNode};
65use sort::*;
66use std::fmt::{Debug, Display, Formatter};
67use std::fs::File;
68use std::hash::Hash;
69use std::io::{Read, Write as _};
70use std::iter::once;
71use std::ops::Deref;
72use std::path::PathBuf;
73use std::str::FromStr;
74use std::sync::Arc;
75pub use termdag::{Term, TermDag, TermId};
76use thiserror::Error;
77pub use typechecking::PrimitiveValidator;
78pub use typechecking::TypeError;
79pub use typechecking::TypeInfo;
80use util::*;
81
82use crate::ast::desugar::desugar_command;
83use crate::ast::*;
84use crate::core::{GenericActionsExt, ResolvedRuleExt};
85use crate::proofs::proof_encoding::{EncodingState, ProofInstrumentor};
86use crate::proofs::proof_encoding_helpers::command_supports_proof_encoding;
87use crate::proofs::proof_extraction::ProveExistsError;
88use crate::proofs::proof_format::{ProofId, ProofStore};
89use crate::proofs::proof_normal_form::proof_form;
90
/// Prefix that names of global variables are expected to start with.
///
/// Globals without it (and non-globals with it) are reported as warnings, or as
/// errors when strict mode is enabled (see `EGraph::set_strict_mode`).
pub const GLOBAL_NAME_PREFIX: &str = "$";

/// A shared, dynamically typed handle to a sort.
pub type ArcSort = Arc<dyn Sort>;
94
/// A trait for implementing custom primitive operations in egglog.
///
/// Primitives are built-in functions that can be called in both rule queries and actions.
pub trait Primitive {
    /// Returns the name of this primitive operation.
    fn name(&self) -> &str;

    /// Constructs a type constraint for the primitive, using `span` so that any
    /// type error it reports points at the right source location.
    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint>;

    /// Applies the primitive operation to the given arguments.
    ///
    /// Returns `Some(value)` if the operation succeeds, or `None` if it fails.
    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value>;
}
111
/// Trait for the values carried by `CommandOutput::UserDefined`.
///
/// It is blanket-implemented for every `Debug + Display + Send + Sync` type, so no
/// manual impl is ever needed.
pub trait UserDefinedCommandOutput: Debug + std::fmt::Display + Send + Sync {}
impl<T> UserDefinedCommandOutput for T where T: Debug + std::fmt::Display + Send + Sync {}
115
/// Output from a command.
#[derive(Clone, Debug)]
// NOTE(review): some variants (e.g. `PrintFunction`, which stores a `Function` and a
// `TermDag` inline) are much larger than others. The lint is silenced rather than
// boxing them — presumably a deliberate trade-off; TODO confirm.
#[allow(clippy::large_enum_variant)]
pub enum CommandOutput {
    /// The size (number of rows) of a single function's table
    PrintFunctionSize(usize),
    /// The name and size of every function
    PrintAllFunctionsSize(Vec<(String, usize)>),
    /// The best term found by extraction, together with its cost
    ExtractBest(TermDag, DefaultCost, TermId),
    /// Several extracted variants of a term
    ExtractVariants(TermDag, Vec<TermId>),
    /// A high-level proof witnessing constructor existence
    ProveExists {
        proof_store: ProofStore,
        proof_id: ProofId,
    },
    /// The report accumulated over all runs
    OverallStatistics(RunReport),
    /// A printed function with its (input term, output term) rows and the print mode
    PrintFunction(Function, TermDag, Vec<(TermId, TermId)>, PrintFunctionMode),
    /// The report from a single run (its `Display` output is empty)
    RunSchedule(RunReport),
    /// A user defined output
    UserDefined(Arc<dyn UserDefinedCommandOutput>),
}
142
143impl std::fmt::Display for CommandOutput {
144    /// Format the command output for display, ending with a newline.
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            CommandOutput::PrintFunctionSize(size) => writeln!(f, "{}", size),
148            CommandOutput::PrintAllFunctionsSize(names_and_sizes) => {
149                for name in names_and_sizes {
150                    writeln!(f, "{}: {}", name.0, name.1)?;
151                }
152                Ok(())
153            }
154            CommandOutput::ExtractBest(termdag, _cost, term) => {
155                writeln!(f, "{}", termdag.to_string(*term))
156            }
157            CommandOutput::ExtractVariants(termdag, terms) => {
158                writeln!(f, "(")?;
159                for expr in terms {
160                    writeln!(f, "   {}", termdag.to_string(*expr))?;
161                }
162                writeln!(f, ")")
163            }
164            CommandOutput::ProveExists {
165                proof_store,
166                proof_id,
167            } => write!(f, "{}", proof_store.proof_to_string(*proof_id)),
168            CommandOutput::OverallStatistics(run_report) => {
169                write!(f, "Overall statistics:\n{}", run_report)
170            }
171            CommandOutput::PrintFunction(function, termdag, terms_and_outputs, mode) => {
172                let out_is_unit = function.schema.output.name() == UnitSort.name();
173                if *mode == PrintFunctionMode::CSV {
174                    let mut wtr = Writer::from_writer(vec![]);
175                    for (term_id, output) in terms_and_outputs {
176                        let term = termdag.get(*term_id);
177                        match term {
178                            Term::App(name, children) => {
179                                let mut values = vec![name.clone()];
180                                for child_id in children {
181                                    values.push(termdag.to_string(*child_id));
182                                }
183
184                                if !out_is_unit {
185                                    values.push(termdag.to_string(*output));
186                                }
187                                wtr.write_record(&values).map_err(|_| std::fmt::Error)?;
188                            }
189                            _ => panic!("Expect function_to_dag to return a list of apps."),
190                        }
191                    }
192                    let csv_bytes = wtr.into_inner().map_err(|_| std::fmt::Error)?;
193                    f.write_str(&String::from_utf8(csv_bytes).map_err(|_| std::fmt::Error)?)
194                } else {
195                    writeln!(f, "(")?;
196                    for (term, output) in terms_and_outputs.iter() {
197                        write!(f, "   {}", termdag.to_string(*term))?;
198                        if !out_is_unit {
199                            write!(f, " -> {}", termdag.to_string(*output))?;
200                        }
201                        writeln!(f)?;
202                    }
203                    writeln!(f, ")")
204                }
205            }
206            CommandOutput::RunSchedule(_report) => Ok(()),
207            CommandOutput::UserDefined(output) => {
208                write!(f, "{}", *output)
209            }
210        }
211    }
212}
213
/// The main interface for an e-graph in egglog.
///
/// An [`EGraph`] maintains a collection of equivalence classes of terms and provides
/// operations for adding facts, running rules, and extracting optimal terms.
///
/// # Examples
///
/// ```
/// use egglog::*;
///
/// let mut egraph = EGraph::default();
/// egraph.parse_and_run_program(None, "(datatype Math (Num i64) (Add Math Math))").unwrap();
/// ```
#[derive(Clone)]
pub struct EGraph {
    /// Backend e-graph that owns the table backing each declared function.
    backend: egglog_bridge::EGraph,
    /// Parser for egglog source; names of user-defined commands are registered here.
    pub parser: Parser,
    /// Name bookkeeping from `check_shadowing` (presumably used to detect shadowed
    /// names — TODO confirm).
    names: check_shadowing::Names,
    /// pushed_egraph forms a linked list of pushed egraphs.
    /// Pop reverts the egraph to the last pushed egraph.
    pushed_egraph: Option<Box<Self>>,
    /// Declared functions, keyed by name.
    functions: IndexMap<String, Function>,
    /// Rulesets, keyed by name; `Default` creates the unnamed ruleset `""`.
    rulesets: IndexMap<String, Ruleset>,
    /// Directory associated with fact files (presumably used to resolve relative
    /// input paths — TODO confirm).
    pub fact_directory: Option<PathBuf>,
    /// Semi-naive evaluation switch; `true` by default.
    pub seminaive: bool,
    /// Sorts and primitives known to the typechecker.
    type_info: TypeInfo,
    /// The run report unioned over all runs so far.
    overall_run_report: RunReport,
    /// Registered schedulers, keyed by scheduler id.
    schedulers: DenseIdMap<SchedulerId, SchedulerRecord>,
    /// User-defined commands registered through `add_command`.
    commands: IndexMap<String, Arc<dyn UserDefinedCommand>>,
    /// When `true`, `$`-prefix violations are errors instead of warnings.
    strict_mode: bool,
    /// Set once a `$`-prefix warning has been logged; later warnings of either kind
    /// (missing prefix or unexpected prefix) are then suppressed.
    warned_about_global_prefix: bool,
    /// Registry for command-level macros
    command_macros: CommandMacroRegistry,
    /// Term-encoding and proof-generation state.
    proof_state: EncodingState,
    /// In proof mode, we keep track of all desugared commands so we can check proofs later.
    desugared_commands: Vec<ResolvedNCommand>,
}
252
/// A user-defined command lets users inject custom commands that can be called
/// from an egglog program.
///
/// Compared to an external function, a user-defined command is more powerful because
/// it has exclusive (`&mut`) access to the e-graph.
pub trait UserDefinedCommand: Send + Sync {
    /// Run the command with the given argument expressions.
    ///
    /// `Ok(Some(output))` produces command output; `Ok(None)` produces none.
    fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<Option<CommandOutput>, Error>;
}
262
/// A function in the e-graph.
///
/// This contains the schema information of the function and
/// the backend id of the function in the e-graph.
#[derive(Clone)]
pub struct Function {
    /// The resolved declaration this function was created from.
    decl: ResolvedFunctionDecl,
    /// Input and output sorts resolved from the declaration's schema.
    schema: ResolvedSchema,
    /// Whether rows may be subsumed; `declare_function` sets this for constructors only.
    can_subsume: bool,
    /// Id of the table backing this function in the backend e-graph.
    backend_id: egglog_bridge::FunctionId,
}

impl Function {
    /// Get the name of the function.
    pub fn name(&self) -> &str {
        &self.decl.name
    }

    /// Get the resolved input and output sorts of the function.
    pub fn schema(&self) -> &ResolvedSchema {
        &self.schema
    }

    /// Whether this function supports subsumption (true for constructors).
    pub fn can_subsume(&self) -> bool {
        self.can_subsume
    }
}
291
292#[derive(Clone, Debug)]
293pub struct ResolvedSchema {
294    pub input: Vec<ArcSort>,
295    pub output: ArcSort,
296}
297
298impl ResolvedSchema {
299    /// Get the type at position `index`, counting the `output` sort as at position `input.len()`.
300    pub fn get_by_pos(&self, index: usize) -> Option<&ArcSort> {
301        if self.input.len() == index {
302            Some(&self.output)
303        } else {
304            self.input.get(index)
305        }
306    }
307}
308
309impl Debug for Function {
310    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
311        f.debug_struct("Function")
312            .field("decl", &self.decl)
313            .field("schema", &self.schema)
314            .finish()
315    }
316}
317
impl Default for EGraph {
    /// Build an e-graph with egglog's built-in sorts and generic primitives, plus an
    /// empty default ruleset named `""`.
    fn default() -> Self {
        let mut parser = Parser::default();
        // NOTE(review): the encoding state is initialised from the parser's symbol
        // generator, presumably so both draw fresh names from one source — TODO confirm.
        let proof_state = EncodingState::new(&mut parser.symbol_gen);
        let mut eg = Self {
            backend: Default::default(),
            parser,
            names: Default::default(),
            pushed_egraph: Default::default(),
            functions: Default::default(),
            rulesets: Default::default(),
            fact_directory: None,
            seminaive: true,
            overall_run_report: Default::default(),
            type_info: Default::default(),
            schedulers: Default::default(),
            commands: Default::default(),
            strict_mode: false,
            warned_about_global_prefix: false,
            command_macros: Default::default(),
            proof_state,
            desugared_commands: vec![],
        };

        // Built-in base sorts. Registering them on a fresh e-graph is not expected to
        // fail, hence the `unwrap`s.
        add_base_sort(&mut eg, UnitSort, span!()).unwrap();
        add_base_sort(&mut eg, StringSort, span!()).unwrap();
        add_base_sort(&mut eg, BoolSort, span!()).unwrap();
        add_base_sort(&mut eg, I64Sort, span!()).unwrap();
        add_base_sort(&mut eg, F64Sort, span!()).unwrap();
        add_base_sort(&mut eg, BigIntSort, span!()).unwrap();
        add_base_sort(&mut eg, BigRatSort, span!()).unwrap();
        // Container sorts (map, set, vec, function, multiset) are registered as
        // presorts rather than base sorts.
        eg.type_info.add_presort::<MapSort>(span!()).unwrap();
        eg.type_info.add_presort::<SetSort>(span!()).unwrap();
        eg.type_info.add_presort::<VecSort>(span!()).unwrap();
        eg.type_info.add_presort::<FunctionSort>(span!()).unwrap();
        eg.type_info.add_presort::<MultiSetSort>(span!()).unwrap();

        // Add != with a validator that computes inequality result
        // NOTE(review): the validator compares `TermId`s, so it treats distinct term ids
        // as unequal terms — presumably relying on the TermDag deduplicating terms;
        // TODO confirm.
        let neq_validator = |termdag: &mut TermDag, args: &[TermId]| -> Option<TermId> {
            if args.len() == 2 && args[0] != args[1] {
                // Return unit literal for successful inequality
                Some(termdag.lit(Literal::Unit))
            } else {
                None
            }
        };
        add_primitive_with_validator!(
            &mut eg,
            "!=" = |a: #, b: #| -?> () {
                (a != b).then_some(())
            },
            neq_validator
        );

        // NOTE(review): `#` in these signatures appears to stand for an argument of any
        // sort — TODO confirm against the `add_primitive` macro.
        add_primitive!(&mut eg, "value-eq" = |a: #, b: #| -?> () {
            (a == b).then_some(())
        });
        add_primitive!(&mut eg, "ordering-min" = |a: #, b: #| -> # {
            if a < b { a } else { b }
        });
        add_primitive!(&mut eg, "ordering-max" = |a: #, b: #| -> # {
            if a > b { a } else { b }
        });

        // The unnamed default ruleset.
        eg.rulesets
            .insert("".into(), Ruleset::Rules(Default::default()));

        eg
    }
}
388
/// Error for a failed lookup; it displays as `Not found: <message>`.
#[derive(Debug, Error)]
#[error("Not found: {0}")]
pub struct NotFoundError(String);
392
393impl EGraph {
394    /// Create a new e-graph with the term-encoding pipeline enabled.
395    ///
396    /// In term-encoding mode the e-graph eagerly instruments every constructor
397    /// and function with auxiliary term tables, view tables, and per-sort
398    /// union-finds so that canonical representatives and their justifications are
399    /// materialized explicitly.  This makes it possible to record and emit
400    /// equality proofs while preserving the observable behaviour of supported
401    /// commands.
402    pub fn new_with_term_encoding() -> Self {
403        let mut egraph = EGraph::default();
404        egraph.proof_state.original_typechecking = Some(Box::new(egraph.clone()));
405        egraph
406    }
407
408    /// Create a new e-graph with proof generation enabled.
409    pub fn new_with_proofs() -> Self {
410        let mut egraph = EGraph::new_with_term_encoding();
411        egraph.proof_state.proofs_enabled = true;
412        egraph
413    }
414
415    /// Enable the term-encoding pipeline on an existing `EGraph`.
416    ///
417    /// This method is to support the current CLI implementation with egglog-experimental (https://github.com/egraphs-good/egglog/issues/768)
418    pub(crate) fn with_term_encoding_enabled(mut self) -> Self {
419        self.proof_state.original_typechecking = Some(Box::new(self.clone()));
420        self
421    }
422
423    /// Enable proof generation on this e-graph.
424    /// TODO proofs should be turned on during creation of the e-graph, not afterwards.
425    /// This method is to support the current CLI implementation with egglog-experimental (https://github.com/egraphs-good/egglog/issues/768)
426    pub(crate) fn with_proofs_enabled(mut self) -> Self {
427        self = self.with_term_encoding_enabled();
428        self.proof_state.proofs_enabled = true;
429        self
430    }
431
432    /// Enable testing of getting proofs for all `check` commands.
433    pub fn with_proof_testing(mut self) -> Self {
434        self.proof_state.proof_testing = true;
435        self
436    }
437
438    /// Add a user-defined command to the e-graph
439    /// Get the type information for this e-graph
440    pub fn type_info(&mut self) -> &mut TypeInfo {
441        &mut self.type_info
442    }
443
444    /// Get read-only access to the command macro registry
445    pub fn command_macros(&self) -> &CommandMacroRegistry {
446        &self.command_macros
447    }
448
449    /// Get mutable access to the command macro registry
450    pub fn command_macros_mut(&mut self) -> &mut CommandMacroRegistry {
451        &mut self.command_macros
452    }
453
454    pub fn add_command(
455        &mut self,
456        name: String,
457        command: Arc<dyn UserDefinedCommand>,
458    ) -> Result<(), Error> {
459        if self.commands.contains_key(&name)
460            || self.functions.contains_key(&name)
461            || self.type_info.get_prims(&name).is_some()
462        {
463            return Err(Error::CommandAlreadyExists(name, span!()));
464        }
465        self.commands.insert(name.clone(), command);
466        self.parser.add_user_defined(name)?;
467        Ok(())
468    }
469
470    /// Configure whether globals missing the required `$` prefix are treated as errors.
471    pub fn set_strict_mode(&mut self, strict_mode: bool) {
472        self.strict_mode = strict_mode;
473    }
474
475    /// Returns `true` when missing `$` prefixes on globals are treated as errors.
476    pub fn strict_mode(&self) -> bool {
477        self.strict_mode
478    }
479
480    fn ensure_global_name_prefix(&mut self, span: &Span, name: &str) -> Result<(), TypeError> {
481        if name.starts_with(GLOBAL_NAME_PREFIX) {
482            return Ok(());
483        }
484        if self.strict_mode {
485            Err(TypeError::GlobalMissingPrefix {
486                name: name.to_owned(),
487                span: span.clone(),
488            })
489        } else {
490            self.warn_missing_global_prefix(span, name)?;
491            Ok(())
492        }
493    }
494
495    fn warn_missing_global_prefix(
496        &mut self,
497        span: &Span,
498        canonical_name: &str,
499    ) -> Result<(), TypeError> {
500        if self.strict_mode {
501            return Err(TypeError::GlobalMissingPrefix {
502                name: format!("{}{}", GLOBAL_NAME_PREFIX, canonical_name),
503                span: span.clone(),
504            });
505        }
506        if self.warned_about_global_prefix {
507            return Ok(());
508        }
509        self.warned_about_global_prefix = true;
510        log::warn!(
511            "{}\nGlobal `{}` should start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
512            span,
513            canonical_name,
514            GLOBAL_NAME_PREFIX
515        );
516        Ok(())
517    }
518
519    fn warn_prefixed_non_globals(
520        &mut self,
521        span: &Span,
522        canonical_name: &str,
523    ) -> Result<(), TypeError> {
524        if self.strict_mode {
525            return Err(TypeError::NonGlobalPrefixed {
526                name: canonical_name.to_string(),
527                span: span.clone(),
528            });
529        }
530        if self.warned_about_global_prefix {
531            return Ok(());
532        }
533        self.warned_about_global_prefix = true;
534        log::warn!(
535            "{}\nNon-global `{}` should not start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
536            span,
537            canonical_name,
538            GLOBAL_NAME_PREFIX
539        );
540        Ok(())
541    }
542
543    /// Push a snapshot of the e-graph into the stack.
544    ///
545    /// See [`EGraph::pop`].
546    pub fn push(&mut self) {
547        let prev_prev: Option<Box<Self>> = self.pushed_egraph.take();
548        let mut prev = self.clone();
549        prev.pushed_egraph = prev_prev;
550        self.pushed_egraph = Some(Box::new(prev));
551    }
552
553    /// Pop the current egraph off the stack, replacing
554    /// it with the previously pushed egraph.
555    /// It preserves the run report and messages from the popped
556    /// egraph.
557    pub fn pop(&mut self) -> Result<(), Error> {
558        match self.pushed_egraph.take() {
559            Some(e) => {
560                // Copy the overall report from the popped egraph
561                let overall_run_report = self.overall_run_report.clone();
562                *self = *e;
563                self.overall_run_report = overall_run_report;
564                Ok(())
565            }
566            None => Err(Error::Pop(span!())),
567        }
568    }
569
570    fn translate_expr_to_mergefn(
571        &self,
572        expr: &ResolvedExpr,
573    ) -> Result<egglog_bridge::MergeFn, Error> {
574        match expr {
575            GenericExpr::Lit(_, literal) => {
576                let val = literal_to_value(&self.backend, literal);
577                Ok(egglog_bridge::MergeFn::Const(val))
578            }
579            GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() {
580                "old" => Ok(egglog_bridge::MergeFn::Old),
581                "new" => Ok(egglog_bridge::MergeFn::New),
582                // NB: type-checking should already catch unbound variables here.
583                _ => Err(TypeError::Unbound(resolved_var.name.clone(), span.clone()).into()),
584            },
585            GenericExpr::Call(_, ResolvedCall::Func(f), args) => {
586                let translated_args = args
587                    .iter()
588                    .map(|arg| self.translate_expr_to_mergefn(arg))
589                    .collect::<Result<Vec<_>, _>>()?;
590                Ok(egglog_bridge::MergeFn::Function(
591                    self.functions[&f.name].backend_id,
592                    translated_args,
593                ))
594            }
595            GenericExpr::Call(_, ResolvedCall::Primitive(p), args) => {
596                let translated_args = args
597                    .iter()
598                    .map(|arg| self.translate_expr_to_mergefn(arg))
599                    .collect::<Result<Vec<_>, _>>()?;
600                Ok(egglog_bridge::MergeFn::Primitive(
601                    p.external_id(),
602                    translated_args,
603                ))
604            }
605        }
606    }
607
608    fn declare_function(&mut self, decl: &ResolvedFunctionDecl) -> Result<(), Error> {
609        let get_sort = |name: &String| match self.type_info.get_sort_by_name(name) {
610            Some(sort) => Ok(sort.clone()),
611            None => Err(Error::TypeError(TypeError::UndefinedSort(
612                name.to_owned(),
613                decl.span.clone(),
614            ))),
615        };
616
617        let input = decl
618            .schema
619            .input
620            .iter()
621            .map(get_sort)
622            .collect::<Result<Vec<_>, _>>()?;
623        let output = get_sort(&decl.schema.output)?;
624
625        let can_subsume = match decl.subtype {
626            FunctionSubtype::Constructor => true,
627            FunctionSubtype::Custom => false,
628        };
629
630        use egglog_bridge::{DefaultVal, MergeFn};
631        let backend_id = self.backend.add_table(egglog_bridge::FunctionConfig {
632            schema: input
633                .iter()
634                .chain([&output])
635                .map(|sort| sort.column_ty(&self.backend))
636                .collect(),
637            default: match decl.subtype {
638                FunctionSubtype::Constructor => DefaultVal::FreshId,
639                FunctionSubtype::Custom => DefaultVal::Fail,
640            },
641            merge: match decl.subtype {
642                FunctionSubtype::Constructor => MergeFn::UnionId,
643                FunctionSubtype::Custom => match &decl.merge {
644                    None => MergeFn::AssertEq,
645                    Some(expr) => self.translate_expr_to_mergefn(expr)?,
646                },
647            },
648            name: decl.name.to_string(),
649            can_subsume,
650        });
651
652        let function = Function {
653            decl: decl.clone(),
654            schema: ResolvedSchema { input, output },
655            can_subsume,
656            backend_id,
657        };
658
659        let old = self.functions.insert(decl.name.clone(), function);
660        if old.is_some() {
661            panic!(
662                "Typechecking should have caught function already bound: {}",
663                decl.name
664            );
665        }
666
667        Ok(())
668    }
669
670    /// Extract rows of a table using the default cost model with name sym
671    /// The `include_output` parameter controls whether the output column is always extracted
672    /// For functions, the output column is usually useful
673    /// For constructors and relations, the output column can be ignored
674    pub fn function_to_dag(
675        &self,
676        sym: &str,
677        n: usize,
678        include_output: bool,
679    ) -> Result<(Vec<TermId>, Option<Vec<TermId>>, TermDag), Error> {
680        let func = self
681            .functions
682            .get(sym)
683            .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
684        let mut rootsorts = func.schema.input.clone();
685        if include_output {
686            rootsorts.push(func.schema.output.clone());
687        }
688        let extractor = Extractor::compute_costs_from_rootsorts(
689            Some(rootsorts),
690            self,
691            TreeAdditiveCostModel::default(),
692        );
693
694        let mut termdag = TermDag::default();
695        let mut inputs: Vec<TermId> = Vec::new();
696        let mut output: Option<Vec<TermId>> = if include_output {
697            Some(Vec::new())
698        } else {
699            None
700        };
701
702        let extract_row = |row: egglog_bridge::FunctionRow| {
703            if inputs.len() < n {
704                // include subsumed rows
705                let mut children: Vec<TermId> = Vec::new();
706                for (value, sort) in row.vals.iter().zip(&func.schema.input) {
707                    let (_, term_id) = extractor
708                        .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
709                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
710                    children.push(term_id);
711                }
712                inputs.push(termdag.app(sym.to_owned(), children));
713                if include_output {
714                    let value = row.vals[func.schema.input.len()];
715                    let sort = &func.schema.output;
716                    let (_, term) = extractor
717                        .extract_best_with_sort(self, &mut termdag, value, sort.clone())
718                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
719                    output.as_mut().unwrap().push(term);
720                }
721                true
722            } else {
723                false
724            }
725        };
726
727        self.backend.for_each_while(func.backend_id, extract_row);
728
729        Ok((inputs, output, termdag))
730    }
731
732    /// Print up to `n` the tuples in a given function.
733    /// Print all tuples if `n` is not provided.
734    pub fn print_function(
735        &mut self,
736        sym: &str,
737        n: Option<usize>,
738        file: Option<File>,
739        mode: PrintFunctionMode,
740    ) -> Result<Option<CommandOutput>, Error> {
741        let n = match n {
742            Some(n) => {
743                log::info!("Printing up to {n} tuples of function {sym} as {mode}");
744                n
745            }
746            None => {
747                log::info!("Printing all tuples of function {sym} as {mode}");
748                usize::MAX
749            }
750        };
751
752        let (terms, outputs, termdag) = self.function_to_dag(sym, n, true)?;
753        let f = self
754            .functions
755            .get(sym)
756            // function_to_dag should have checked this
757            .unwrap();
758        let terms_and_outputs: Vec<_> = terms.into_iter().zip(outputs.unwrap()).collect();
759        let output = CommandOutput::PrintFunction(f.clone(), termdag, terms_and_outputs, mode);
760        match file {
761            Some(mut file) => {
762                log::info!("Writing output to file");
763                file.write_all(output.to_string().as_bytes())
764                    .expect("Error writing to file");
765                Ok(None)
766            }
767            None => Ok(Some(output)),
768        }
769    }
770
771    /// Print the size of a function. If no function name is provided,
772    /// print the size of all functions in "name: len" pairs.
773    pub fn print_size(&self, sym: Option<&str>) -> Result<CommandOutput, Error> {
774        if let Some(sym) = sym {
775            let f = self
776                .functions
777                .get(sym)
778                .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
779            let size = self.backend.table_size(f.backend_id);
780            log::info!("Function {} has size {}", sym, size);
781            Ok(CommandOutput::PrintFunctionSize(size))
782        } else {
783            // Print size of all functions
784            let mut lens = self
785                .functions
786                .iter()
787                .map(|(sym, f)| (sym.clone(), self.backend.table_size(f.backend_id)))
788                .collect::<Vec<_>>();
789
790            // Function name's alphabetical order
791            lens.sort_by_key(|(name, _)| name.clone());
792            if log_enabled!(Level::Info) {
793                for (sym, len) in &lens {
794                    log::info!("Function {} has size {}", sym, len);
795                }
796            }
797            Ok(CommandOutput::PrintAllFunctionsSize(lens))
798        }
799    }
800
801    // returns whether the egraph was updated
802    fn run_schedule(&mut self, sched: &ResolvedSchedule) -> Result<RunReport, Error> {
803        match sched {
804            ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
805            ResolvedSchedule::Repeat(_span, limit, sched) => {
806                let mut report = RunReport::default();
807                for _i in 0..*limit {
808                    let rec = self.run_schedule(sched)?;
809                    let updated = rec.updated;
810                    report.union(rec);
811                    if !updated {
812                        break;
813                    }
814                }
815                Ok(report)
816            }
817            ResolvedSchedule::Saturate(_span, sched) => {
818                let mut report = RunReport::default();
819                loop {
820                    let rec = self.run_schedule(sched)?;
821                    let updated = rec.updated;
822                    report.union(rec);
823                    if !updated {
824                        break;
825                    }
826                }
827                Ok(report)
828            }
829            ResolvedSchedule::Sequence(_span, scheds) => {
830                let mut report = RunReport::default();
831                for sched in scheds {
832                    report.union(self.run_schedule(sched)?);
833                }
834                Ok(report)
835            }
836        }
837    }
838
839    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`] using the default cost model.
840    /// See also [`EGraph::extract_value_with_cost_model`] for more control.
841    pub fn extract_value(
842        &self,
843        sort: &ArcSort,
844        value: Value,
845    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
846        self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
847    }
848
849    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`].
850    /// Note that the `TermDag` may contain a superset of the nodes referenced by the returned `TermId`.
851    /// See also [`EGraph::extract_value_to_string`] for convenience.
852    pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
853        &self,
854        sort: &ArcSort,
855        value: Value,
856        cost_model: CM,
857    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
858        let extractor =
859            Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
860        let mut termdag = TermDag::default();
861        let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
862        Ok((termdag, term, cost))
863    }
864
865    /// Extract a value to a string for printing.
866    /// See also [`EGraph::extract_value`] for more control.
867    pub fn extract_value_to_string(
868        &self,
869        sort: &ArcSort,
870        value: Value,
871    ) -> Result<(String, DefaultCost), Error> {
872        let (termdag, term, cost) = self.extract_value(sort, value)?;
873        Ok((termdag.to_string(term), cost))
874    }
875
876    fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> Result<RunReport, Error> {
877        log::debug!("Running ruleset: {}", config.ruleset);
878        let mut report: RunReport = Default::default();
879
880        let GenericRunConfig { ruleset, until } = config;
881
882        if let Some(facts) = until {
883            if self.check_facts(span, facts).is_ok() {
884                log::info!(
885                    "Breaking early because of facts:\n {}!",
886                    ListDisplay(facts, "\n")
887                );
888                return Ok(report);
889            }
890        }
891
892        let subreport = self.step_rules(ruleset)?;
893        report.union(subreport);
894
895        if log_enabled!(Level::Debug) {
896            log::debug!("database size: {}", self.num_tuples());
897        }
898
899        Ok(report)
900    }
901
902    /// Runs a ruleset for an iteration.
903    ///
904    /// This applies every match it finds (under semi-naive).
905    /// See [`EGraph::step_rules_with_scheduler`] for more fine-grained control.
906    ///
907    /// This will return an error if an egglog primitive returns None in an action.
908    pub fn step_rules(&mut self, ruleset: &str) -> Result<RunReport, Error> {
909        fn collect_rule_ids(
910            ruleset: &str,
911            rulesets: &IndexMap<String, Ruleset>,
912            ids: &mut Vec<egglog_bridge::RuleId>,
913        ) {
914            match &rulesets[ruleset] {
915                Ruleset::Rules(rules) => {
916                    for (_, id) in rules.values() {
917                        ids.push(*id);
918                    }
919                }
920                Ruleset::Combined(sub_rulesets) => {
921                    for sub_ruleset in sub_rulesets {
922                        collect_rule_ids(sub_ruleset, rulesets, ids);
923                    }
924                }
925            }
926        }
927
928        let mut rule_ids = Vec::new();
929        collect_rule_ids(ruleset, &self.rulesets, &mut rule_ids);
930
931        let iteration_report = self
932            .backend
933            .run_rules(&rule_ids)
934            .map_err(|e| Error::BackendError(e.to_string()))?;
935
936        Ok(RunReport::singleton(ruleset, iteration_report))
937    }
938
939    fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result<String, Error> {
940        // Disable union_to_set optimization in proof or term encoding mode, since
941        // it expects only `union` on constructors (not set).
942        let core_rule = rule.to_canonicalized_core_rule(
943            &self.type_info,
944            &mut self.parser.symbol_gen,
945            self.proof_state.original_typechecking.is_none(),
946        )?;
947        let (query, actions) = (&core_rule.body, &core_rule.head);
948
949        let rule_id = {
950            let mut translator = BackendRule::new(
951                self.backend.new_rule(&rule.name, self.seminaive),
952                &self.functions,
953                &self.type_info,
954            );
955            translator.query(query, false);
956            translator.actions(actions)?;
957            translator.build()
958        };
959
960        if let Some(rules) = self.rulesets.get_mut(&rule.ruleset) {
961            match rules {
962                Ruleset::Rules(rules) => {
963                    match rules.entry(rule.name.clone()) {
964                        indexmap::map::Entry::Occupied(_) => {
965                            let name = rule.name;
966                            panic!("Rule '{name}' was already present")
967                        }
968                        indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)),
969                    };
970                    Ok(rule.name)
971                }
972                Ruleset::Combined(_) => Err(Error::CombinedRulesetError(rule.ruleset, rule.span)),
973            }
974        } else {
975            Err(Error::NoSuchRuleset(rule.ruleset, rule.span))
976        }
977    }
978
979    fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
980        let mut binding = IndexSet::default();
981        let mut ctx = CoreActionContext::new(
982            &self.type_info,
983            &mut binding,
984            &mut self.parser.symbol_gen,
985            self.proof_state.original_typechecking.is_none(),
986        );
987        let (actions, _) = actions.to_core_actions(&mut ctx)?;
988
989        let mut translator = BackendRule::new(
990            self.backend.new_rule("eval_actions", false),
991            &self.functions,
992            &self.type_info,
993        );
994        translator.actions(&actions)?;
995        let id = translator.build();
996        let result = self.backend.run_rules(&[id]);
997        self.backend.free_rule(id);
998
999        match result {
1000            Ok(_) => Ok(()),
1001            Err(e) => Err(Error::BackendError(e.to_string())),
1002        }
1003    }
1004
1005    /// Evaluates an expression, returns the sort of the expression and the evaluation result.
1006    pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
1007        let span = expr.span();
1008        let command = Command::Action(Action::Expr(span.clone(), expr.clone()));
1009        let resolved_commands = self.resolve_command(command)?;
1010        assert_eq!(resolved_commands.len(), 1);
1011        let resolved_command = resolved_commands.into_iter().next().unwrap();
1012        let resolved_expr = match resolved_command {
1013            ResolvedNCommand::CoreAction(ResolvedAction::Expr(_, resolved_expr)) => resolved_expr,
1014            _ => unreachable!(),
1015        };
1016        let sort = resolved_expr.output_type();
1017        let value = self.eval_resolved_expr(span, &resolved_expr)?;
1018        Ok((sort, value))
1019    }
1020
1021    fn eval_resolved_expr(&mut self, span: Span, expr: &ResolvedExpr) -> Result<Value, Error> {
1022        let unit_id = self.backend.base_values().get_ty::<()>();
1023        let unit_val = self.backend.base_values().get(());
1024
1025        let result: egglog_bridge::SideChannel<Value> = Default::default();
1026        let result_ref = result.clone();
1027        let ext_id = self
1028            .backend
1029            .register_external_func(Box::new(make_external_func(move |_es, vals| {
1030                debug_assert!(vals.len() == 1);
1031                *result_ref.lock().unwrap() = Some(vals[0]);
1032                Some(unit_val)
1033            })));
1034
1035        let mut translator = BackendRule::new(
1036            self.backend.new_rule("eval_resolved_expr", false),
1037            &self.functions,
1038            &self.type_info,
1039        );
1040
1041        let result_var = ResolvedVar {
1042            name: self.parser.symbol_gen.fresh("eval_resolved_expr"),
1043            sort: expr.output_type(),
1044            is_global_ref: false,
1045        };
1046        let actions = ResolvedActions::singleton(ResolvedAction::Let(
1047            span.clone(),
1048            result_var.clone(),
1049            expr.clone(),
1050        ));
1051        let mut binding = IndexSet::default();
1052        let mut ctx = CoreActionContext::new(
1053            &self.type_info,
1054            &mut binding,
1055            &mut self.parser.symbol_gen,
1056            self.proof_state.original_typechecking.is_none(),
1057        );
1058        let actions = actions.to_core_actions(&mut ctx)?.0;
1059        translator.actions(&actions)?;
1060
1061        let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var));
1062        translator.rb.call_external_func(
1063            ext_id,
1064            &[arg],
1065            egglog_bridge::ColumnTy::Base(unit_id),
1066            || "this function will never panic".to_string(),
1067        );
1068
1069        let id = translator.build();
1070        let rule_result = self.backend.run_rules(&[id]);
1071        self.backend.free_rule(id);
1072        self.backend.free_external_func(ext_id);
1073        let _ = rule_result.map_err(|e| {
1074            Error::BackendError(format!("Failed to evaluate expression '{}': {}", expr, e))
1075        })?;
1076
1077        let result = result.lock().unwrap().unwrap();
1078        Ok(result)
1079    }
1080
1081    fn add_combined_ruleset(&mut self, name: String, rulesets: Vec<String>) {
1082        match self.rulesets.entry(name.clone()) {
1083            Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
1084            Entry::Vacant(e) => e.insert(Ruleset::Combined(rulesets)),
1085        };
1086    }
1087
1088    fn add_ruleset(&mut self, name: String) {
1089        match self.rulesets.entry(name.clone()) {
1090            Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
1091            Entry::Vacant(e) => e.insert(Ruleset::Rules(Default::default())),
1092        };
1093    }
1094
1095    fn check_facts(&mut self, span: &Span, facts: &[ResolvedFact]) -> Result<(), Error> {
1096        let fresh_name = self.parser.symbol_gen.fresh("check_facts");
1097        let fresh_ruleset = self.parser.symbol_gen.fresh("check_facts_ruleset");
1098        let rule = ast::ResolvedRule {
1099            span: span.clone(),
1100            head: ResolvedActions::default(),
1101            body: facts.to_vec(),
1102            name: fresh_name.clone(),
1103            ruleset: fresh_ruleset.clone(),
1104        };
1105        let core_rule = rule.to_canonicalized_core_rule(
1106            &self.type_info,
1107            &mut self.parser.symbol_gen,
1108            self.proof_state.original_typechecking.is_none(),
1109        )?;
1110        let query = core_rule.body;
1111
1112        let ext_sc = egglog_bridge::SideChannel::default();
1113        let ext_sc_ref = ext_sc.clone();
1114        let ext_id = self
1115            .backend
1116            .register_external_func(Box::new(make_external_func(move |_, _| {
1117                *ext_sc_ref.lock().unwrap() = Some(());
1118                Some(Value::new_const(0))
1119            })));
1120
1121        let mut translator = BackendRule::new(
1122            self.backend.new_rule("check_facts", false),
1123            &self.functions,
1124            &self.type_info,
1125        );
1126        translator.query(&query, true);
1127        translator
1128            .rb
1129            .call_external_func(ext_id, &[], egglog_bridge::ColumnTy::Id, || {
1130                "this function will never panic".to_string()
1131            });
1132        let id = translator.build();
1133        let _ = self.backend.run_rules(&[id]).unwrap();
1134        self.backend.free_rule(id);
1135
1136        self.backend.free_external_func(ext_id);
1137
1138        let ext_sc_val = ext_sc.lock().unwrap().take();
1139        let matched = matches!(ext_sc_val, Some(()));
1140
1141        if !matched {
1142            Err(Error::CheckError(
1143                facts.iter().map(|f| f.clone().make_unresolved()).collect(),
1144                span.clone(),
1145            ))
1146        } else {
1147            Ok(())
1148        }
1149    }
1150
1151    fn run_command(&mut self, command: ResolvedNCommand) -> Result<Option<CommandOutput>, Error> {
1152        match command {
1153            // Sorts are already declared during typechecking
1154            ResolvedNCommand::Sort(_span, name, _presort_and_args) => {
1155                log::info!("Declared sort {}.", name)
1156            }
1157            ResolvedNCommand::Function(fdecl) => {
1158                self.declare_function(&fdecl)?;
1159                log::info!("Declared {} {}.", fdecl.subtype, fdecl.name)
1160            }
1161            ResolvedNCommand::AddRuleset(_span, name) => {
1162                self.add_ruleset(name.clone());
1163                log::info!("Declared ruleset {name}.");
1164            }
1165            ResolvedNCommand::UnstableCombinedRuleset(_span, name, others) => {
1166                self.add_combined_ruleset(name.clone(), others);
1167                log::info!("Declared ruleset {name}.");
1168            }
1169            ResolvedNCommand::NormRule { rule } => {
1170                let name = rule.name.clone();
1171                self.add_rule(rule)?;
1172                log::info!("Declared rule {name}.")
1173            }
1174            ResolvedNCommand::RunSchedule(sched) => {
1175                let report = self.run_schedule(&sched)?;
1176                log::info!("Ran schedule {}.", sched);
1177                log::info!("Report: {}", report);
1178                self.overall_run_report.union(report.clone());
1179                return Ok(Some(CommandOutput::RunSchedule(report)));
1180            }
1181            ResolvedNCommand::PrintOverallStatistics(span, file) => match file {
1182                None => {
1183                    log::info!("Printed overall statistics");
1184                    return Ok(Some(CommandOutput::OverallStatistics(
1185                        self.overall_run_report.clone(),
1186                    )));
1187                }
1188                Some(path) => {
1189                    let mut file = std::fs::File::create(&path)
1190                        .map_err(|e| Error::IoError(path.clone().into(), e, span.clone()))?;
1191                    log::info!("Printed overall statistics to json file {}", path);
1192
1193                    serde_json::to_writer(&mut file, &self.overall_run_report)
1194                        .expect("error serializing to json");
1195                }
1196            },
1197            ResolvedNCommand::Check(span, facts) => {
1198                self.check_facts(&span, &facts)?;
1199                log::info!("Checked fact {:?}.", facts);
1200            }
1201            ResolvedNCommand::CoreAction(action) => match &action {
1202                ResolvedAction::Let(_, name, contents) => {
1203                    panic!("Globals should have been desugared away: {name} = {contents}")
1204                }
1205                _ => {
1206                    self.eval_actions(&ResolvedActions::new(vec![action.clone()]))?;
1207                }
1208            },
1209            ResolvedNCommand::Extract(span, expr, variants) => {
1210                let sort = expr.output_type();
1211
1212                let x = self.eval_resolved_expr(span.clone(), &expr)?;
1213                let n = self.eval_resolved_expr(span, &variants)?;
1214                let n: i64 = self.backend.base_values().unwrap(n);
1215
1216                let mut termdag = TermDag::default();
1217
1218                let extractor = Extractor::compute_costs_from_rootsorts(
1219                    Some(vec![sort]),
1220                    self,
1221                    TreeAdditiveCostModel::default(),
1222                );
1223                return if n == 0 {
1224                    if let Some((cost, term)) = extractor.extract_best(self, &mut termdag, x) {
1225                        // dont turn termdag into a string if we have messages disabled for performance reasons
1226                        if log_enabled!(Level::Info) {
1227                            log::info!("extracted with cost {cost}: {}", termdag.to_string(term));
1228                        }
1229                        Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
1230                    } else {
1231                        Err(Error::ExtractError(
1232                            "Unable to find any valid extraction (likely due to subsume or delete)"
1233                                .to_string(),
1234                        ))
1235                    }
1236                } else {
1237                    if n < 0 {
1238                        panic!("Cannot extract negative number of variants");
1239                    }
1240                    let terms: Vec<TermId> = extractor
1241                        .extract_variants(self, &mut termdag, x, n as usize)
1242                        .iter()
1243                        .map(|e| e.1)
1244                        .collect();
1245                    if log_enabled!(Level::Info) {
1246                        let expr_str = expr.to_string();
1247                        log::info!("extracted {} variants for {expr_str}", terms.len());
1248                    }
1249                    Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
1250                };
1251            }
1252            ResolvedNCommand::Push(n) => {
1253                (0..n).for_each(|_| self.push());
1254                log::info!("Pushed {n} levels.")
1255            }
1256            ResolvedNCommand::Pop(span, n) => {
1257                for _ in 0..n {
1258                    self.pop().map_err(|err| {
1259                        if let Error::Pop(_) = err {
1260                            Error::Pop(span.clone())
1261                        } else {
1262                            err
1263                        }
1264                    })?;
1265                }
1266                log::info!("Popped {n} levels.")
1267            }
1268            ResolvedNCommand::PrintFunction(span, f, n, file, mode) => {
1269                let file = file
1270                    .map(|file| {
1271                        std::fs::File::create(&file)
1272                            .map_err(|e| Error::IoError(file.into(), e, span.clone()))
1273                    })
1274                    .transpose()?;
1275                return self.print_function(&f, n, file, mode).map_err(|e| match e {
1276                    Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1277                        Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1278                    }
1279                    // This case is currently impossible
1280                    _ => e,
1281                });
1282            }
1283            ResolvedNCommand::PrintSize(span, f) => {
1284                let res = self.print_size(f.as_deref()).map_err(|e| match e {
1285                    Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1286                        Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1287                    }
1288                    // This case is currently impossible
1289                    _ => e,
1290                })?;
1291                return Ok(Some(res));
1292            }
1293            ResolvedNCommand::Fail(span, c) => {
1294                let result = self.run_command(*c);
1295                if let Err(e) = result {
1296                    log::info!("Command failed as expected: {e}");
1297                } else {
1298                    return Err(Error::ExpectFail(span));
1299                }
1300            }
1301            ResolvedNCommand::Input {
1302                span: _,
1303                name,
1304                file,
1305            } => {
1306                self.input_file(&name, file)?;
1307            }
1308            ResolvedNCommand::Output { span, file, exprs } => {
1309                let mut filename = self.fact_directory.clone().unwrap_or_default();
1310                filename.push(file.as_str());
1311                // append to file
1312                let mut f = File::options()
1313                    .append(true)
1314                    .create(true)
1315                    .open(&filename)
1316                    .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1317
1318                let extractor = Extractor::compute_costs_from_rootsorts(
1319                    None,
1320                    self,
1321                    TreeAdditiveCostModel::default(),
1322                );
1323                let mut termdag: TermDag = Default::default();
1324
1325                use std::io::Write;
1326                for expr in exprs {
1327                    let value = self.eval_resolved_expr(span.clone(), &expr)?;
1328                    let expr_type = expr.output_type();
1329
1330                    let term = extractor
1331                        .extract_best_with_sort(self, &mut termdag, value, expr_type)
1332                        .unwrap()
1333                        .1;
1334                    writeln!(f, "{}", termdag.to_string(term))
1335                        .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1336                }
1337
1338                log::info!("Output to '{filename:?}'.")
1339            }
1340            ResolvedNCommand::UserDefined(_span, name, exprs) => {
1341                let command = self.commands.swap_remove(&name).unwrap_or_else(|| {
1342                    panic!("Unrecognized user-defined command: {}", name);
1343                });
1344                let res = command.update(self, &exprs);
1345                self.commands.insert(name, command);
1346                return res;
1347            }
1348            ResolvedNCommand::ProveExists(span, resolved_call) => {
1349                let mut instrument = ProofInstrumentor { egraph: self };
1350                let (proof_store, proof_id) =
1351                    instrument
1352                        .prove_exists(&resolved_call)
1353                        .map_err(|error| Error::ProofError {
1354                            span: span.clone(),
1355                            error,
1356                        })?;
1357                return Ok(Some(CommandOutput::ProveExists {
1358                    proof_store,
1359                    proof_id,
1360                }));
1361            }
1362        };
1363
1364        Ok(None)
1365    }
1366
1367    fn input_file(&mut self, func_name: &str, file: String) -> Result<(), Error> {
1368        let function_type = self
1369            .type_info
1370            .get_func_type(func_name)
1371            .unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
1372        let func = self.functions.get_mut(func_name).unwrap();
1373
1374        let mut filename = self.fact_directory.clone().unwrap_or_default();
1375        filename.push(file.as_str());
1376
1377        // check that the function uses supported types
1378
1379        for t in &func.schema.input {
1380            match t.name() {
1381                "i64" | "f64" | "String" => {}
1382                s => panic!("Unsupported type {} for input", s),
1383            }
1384        }
1385
1386        if function_type.subtype != FunctionSubtype::Constructor {
1387            match func.schema.output.name() {
1388                "i64" | "String" | "Unit" => {}
1389                s => panic!("Unsupported type {} for input", s),
1390            }
1391        }
1392
1393        log::info!("Opening file '{:?}'...", filename);
1394        let mut f = File::open(filename).unwrap();
1395        let mut contents = String::new();
1396        f.read_to_string(&mut contents).unwrap();
1397
1398        // Can also do a row-major Vec<Value>
1399        let mut parsed_contents: Vec<Vec<Value>> = Vec::with_capacity(contents.lines().count());
1400
1401        let mut row_schema = func.schema.input.clone();
1402        if function_type.subtype == FunctionSubtype::Custom {
1403            row_schema.push(func.schema.output.clone());
1404        }
1405
1406        log::debug!("{:?}", row_schema);
1407
1408        let unit_val = self.backend.base_values().get(());
1409
1410        for line in contents.lines() {
1411            let mut it = line.split('\t').map(|s| s.trim());
1412
1413            let mut row: Vec<Value> = Vec::with_capacity(row_schema.len());
1414
1415            for sort in row_schema.iter() {
1416                if let Some(raw) = it.next() {
1417                    let val = match sort.name() {
1418                        "i64" => {
1419                            if let Ok(i) = raw.parse::<i64>() {
1420                                self.backend.base_values().get(i)
1421                            } else {
1422                                return Err(Error::InputFileFormatError(file));
1423                            }
1424                        }
1425                        "f64" => {
1426                            if let Ok(f) = raw.parse::<f64>() {
1427                                self.backend
1428                                    .base_values()
1429                                    .get::<F>(core_relations::Boxed::new(f.into()))
1430                            } else {
1431                                return Err(Error::InputFileFormatError(file));
1432                            }
1433                        }
1434                        "String" => self.backend.base_values().get::<S>(raw.to_string().into()),
1435                        "Unit" => unit_val,
1436                        _ => panic!("Unreachable"),
1437                    };
1438                    row.push(val);
1439                } else {
1440                    break;
1441                }
1442            }
1443
1444            if row.is_empty() {
1445                continue;
1446            }
1447
1448            if row.len() != row_schema.len() || it.next().is_some() {
1449                return Err(Error::InputFileFormatError(file));
1450            }
1451
1452            parsed_contents.push(row);
1453        }
1454
1455        log::debug!("Successfully loaded file.");
1456
1457        let num_facts = parsed_contents.len();
1458
1459        let mut table_action = egglog_bridge::TableAction::new(&self.backend, func.backend_id);
1460
1461        if function_type.subtype != FunctionSubtype::Constructor {
1462            self.backend.with_execution_state(|es| {
1463                for row in parsed_contents.iter() {
1464                    table_action.insert(es, row.iter().copied());
1465                }
1466                Some(unit_val)
1467            });
1468        } else {
1469            self.backend.with_execution_state(|es| {
1470                for row in parsed_contents.iter() {
1471                    table_action.lookup(es, row);
1472                }
1473                Some(unit_val)
1474            });
1475        }
1476
1477        self.backend.flush_updates();
1478
1479        log::info!("Read {num_facts} facts into {func_name} from '{file}'.");
1480        Ok(())
1481    }
1482
    /// Desugars, typechecks, and removes globals from a single [`Command`].
    /// Leverages previous type information in the [`EGraph`] to do so, adding new type information.
    ///
    /// There are two paths:
    /// - Without term encoding (`proof_state.original_typechecking` is `None`), the
    ///   desugared commands are typechecked with `self`, globals are removed, and each
    ///   resulting command is checked for shadowing.
    /// - With term encoding, the commands go through these steps in order:
    ///   1. Typecheck against the original type information.
    ///   2. Reject the command if it cannot be proof-encoded.
    ///   3. Convert to proof form and append the result to `self.desugared_commands`.
    ///   4. Remove globals and check shadowing.
    ///   5. Instrument with the term encoding.
    ///   6. Desugar, typecheck (with `self`) and strip globals from each generated command.
    ///
    /// # Errors
    /// - Propagates desugaring, typechecking and shadowing errors.
    /// - Returns [`Error::UnsupportedProofCommand`] for commands that cannot be
    ///   proof-encoded.
    ///
    /// NOTE(review): on the term-encoding path `self.desugared_commands` is extended
    /// before the shadowing check, so a shadowing error still leaves those commands
    /// recorded. Confirm this is intended.
    fn resolve_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
        let desugared = desugar_command(command, &mut self.parser, self.proof_state.proof_testing)?;

        // Add term encoding when it is enabled
        if let Some(original_typechecking) = self.proof_state.original_typechecking.as_mut() {
            // Typecheck using the original egraph
            // TODO this is ugly- we don't need an entire e-graph just for type information.
            let typechecked = original_typechecking.typecheck_program(&desugared)?;

            // Reject the whole command if any part of it cannot be proof-encoded.
            for command in &typechecked {
                if !command_supports_proof_encoding(&command.to_command(), &self.type_info) {
                    let command_text = format!("{}", command.to_command());
                    return Err(Error::UnsupportedProofCommand {
                        command: command_text,
                    });
                }
            }

            let normalized = proof_form(typechecked, &mut self.parser.symbol_gen);

            // Desugared commands are in proof form but globals are still using let bindings.
            self.desugared_commands.extend_from_slice(&normalized);

            // Now remove globals for actual execution (but NOT from desugared_commands)
            let typechecked_no_globals =
                proof_global_remover::remove_globals(normalized, &mut self.parser.symbol_gen);
            for command in &typechecked_no_globals {
                self.names.check_shadowing(command)?;
            }

            // Instrument with the term encoding; the generated surface commands are then
            // resolved one by one below.
            let term_encoding_added =
                ProofInstrumentor::add_term_encoding(self, typechecked_no_globals);
            let mut new_typechecked = vec![];
            for new_cmd in term_encoding_added {
                let desugared =
                    desugar_command(new_cmd, &mut self.parser, self.proof_state.proof_testing)?;
                for cmd in &desugared {
                    log::debug!("Desugared term encoding: {}", cmd.to_command());
                }

                // Now typecheck using self, adding term type information.
                let desugared_typechecked = self.typecheck_program(&desugared)?;
                // remove globals again, but this time allow primitive globals
                let desugared_typechecked = remove_globals::remove_globals(
                    desugared_typechecked,
                    &mut self.parser.symbol_gen,
                );

                new_typechecked.extend(desugared_typechecked);
            }
            Ok(new_typechecked)
        } else {
            // Plain (non-proof) mode: typecheck with our own type info, strip globals,
            // and check shadowing.
            let mut typechecked = self.typecheck_program(&desugared)?;

            typechecked = remove_globals::remove_globals(typechecked, &mut self.parser.symbol_gen);
            for command in &typechecked {
                self.names.check_shadowing(command)?;
            }
            Ok(typechecked)
        }
    }
1546
1547    /// Run a program, returning the desugared outputs as well as the CommandOutputs.
1548    /// Can optionally not run the commands, just adding type information.
1549    fn process_program_internal(
1550        &mut self,
1551        program: Vec<Command>,
1552        run_commands: bool,
1553    ) -> Result<(Vec<CommandOutput>, Vec<ResolvedCommand>), Error> {
1554        let mut outputs = Vec::new();
1555        let mut desugared_commands = Vec::new();
1556
1557        for before_expanded_command in program {
1558            // First do user-provided macro expansion for this command,
1559            // which may rely on type information from previous commands.
1560            let macro_expanded = self.command_macros.apply(
1561                before_expanded_command,
1562                &mut self.parser.symbol_gen,
1563                &self.type_info,
1564            )?;
1565
1566            for command in macro_expanded {
1567                // handle include specially- we keep them as-is for desugaring
1568                if let Command::Include(span, file) = &command {
1569                    let s = std::fs::read_to_string(file)
1570                        .unwrap_or_else(|_| panic!("{span} Failed to read file {file}"));
1571                    let included_program = self
1572                        .parser
1573                        .get_program_from_string(Some(file.clone()), &s)?;
1574                    // run program internal on these include commands
1575                    let (included_outputs, included_desugared) =
1576                        self.process_program_internal(included_program, run_commands)?;
1577                    outputs.extend(included_outputs);
1578                    desugared_commands.extend(included_desugared);
1579                } else {
1580                    for processed in self.resolve_command(command)? {
1581                        desugared_commands.push(processed.to_command());
1582
1583                        // even in desugar mode we still run push and pop
1584                        if run_commands
1585                            || matches!(
1586                                processed,
1587                                ResolvedNCommand::Push(_) | ResolvedNCommand::Pop(_, _)
1588                            )
1589                        {
1590                            let result = self.run_command(processed)?;
1591                            if let Some(output) = result {
1592                                outputs.push(output);
1593                            }
1594                        }
1595                    }
1596                }
1597            }
1598        }
1599
1600        Ok((outputs, desugared_commands))
1601    }
1602
1603    /// Run a program, represented as an AST.
1604    /// Return a list of messages.
1605    pub fn run_program(&mut self, program: Vec<Command>) -> Result<Vec<CommandOutput>, Error> {
1606        let (outputs, _desugared_commands) = self.process_program_internal(program, true)?;
1607        Ok(outputs)
1608    }
1609
1610    /// Desugars an egglog program by parsing and desugaring each command.
1611    /// Outputs a new egglog program without any syntactic sugar, either user provided ([`CommandMacro`]) or built-in (e.g., `rewrite` commands).
1612    pub fn desugar_program(
1613        &mut self,
1614        filename: Option<String>,
1615        input: &str,
1616    ) -> Result<Vec<ResolvedCommand>, Error> {
1617        let parsed = self.parser.get_program_from_string(filename, input)?;
1618        let (_outputs, desugared_commands) = self.process_program_internal(parsed, false)?;
1619        Ok(desugared_commands)
1620    }
1621
1622    /// Takes a source program `input` and parses it into a list of [`Command`]s.
1623    pub fn parse_program(
1624        &mut self,
1625        filename: Option<String>,
1626        input: &str,
1627    ) -> Result<Vec<Command>, Error> {
1628        let parsed = self.parser.get_program_from_string(filename, input)?;
1629        Ok(parsed)
1630    }
1631
1632    /// Takes a source program `input`, parses it, runs it, and returns a list of messages.
1633    ///
1634    /// `filename` is an optional argument to indicate the source of
1635    /// the program for error reporting. If `filename` is `None`,
1636    /// a default name will be used.
1637    pub fn parse_and_run_program(
1638        &mut self,
1639        filename: Option<String>,
1640        input: &str,
1641    ) -> Result<Vec<CommandOutput>, Error> {
1642        let parsed = self.parser.get_program_from_string(filename, input)?;
1643        self.run_program(parsed)
1644    }
1645
1646    /// Get the number of tuples in the database.
1647    ///
1648    pub fn num_tuples(&self) -> usize {
1649        self.functions
1650            .values()
1651            .map(|f| self.backend.table_size(f.backend_id))
1652            .sum()
1653    }
1654
1655    /// Returns a sort based on the type.
1656    pub fn get_sort<S: Sort>(&self) -> Arc<S> {
1657        self.type_info.get_sort()
1658    }
1659
1660    /// Returns a sort that satisfies the type and predicate.
1661    pub fn get_sort_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
1662        self.type_info.get_sort_by(f)
1663    }
1664
1665    /// Returns all sorts based on the type.
1666    pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
1667        self.type_info.get_sorts()
1668    }
1669
1670    /// Returns all sorts that satisfy the type and predicate.
1671    pub fn get_sorts_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
1672        self.type_info.get_sorts_by(f)
1673    }
1674
1675    /// Returns a sort based on the predicate.
1676    pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
1677        self.type_info.get_arcsort_by(f)
1678    }
1679
1680    /// Returns all sorts that satisfy the predicate.
1681    pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
1682        self.type_info.get_arcsorts_by(f)
1683    }
1684
1685    /// Returns the sort with the given name if it exists.
1686    pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
1687        self.type_info.get_sort_by_name(sym)
1688    }
1689
1690    /// Gets the overall run report and returns it.
1691    pub fn get_overall_run_report(&self) -> &RunReport {
1692        &self.overall_run_report
1693    }
1694
1695    /// Convert from an egglog value to a Rust type.
1696    pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
1697        self.backend.base_values().unwrap::<T>(x)
1698    }
1699
1700    /// Convert from a Rust type to an egglog value.
1701    pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
1702        self.backend.base_values().get::<T>(x)
1703    }
1704
    /// Convert from an egglog value to a reference of a Rust container type.
    ///
    /// Returns `None` if the value cannot be converted to the requested container type.
    ///
    /// Warning: The return type of this function may contain lock guards.
    /// Attempts to modify the contents of the containers database may deadlock if the given guard has not been dropped.
    // NOTE(review): the returned `impl Deref` is produced directly from the backend's
    // container store; keep the call chain intact so the guard/borrow lifetimes are unchanged.
    pub fn value_to_container<T: ContainerValue>(
        &self,
        x: Value,
    ) -> Option<impl Deref<Target = T>> {
        self.backend.container_values().get_val::<T>(x)
    }
1717
1718    /// Convert from a Rust container type to an egglog value.
1719    pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
1720        self.backend.with_execution_state(|state| {
1721            self.backend.container_values().register_val::<T>(x, state)
1722        })
1723    }
1724
1725    /// Get the size of a function in the e-graph.
1726    ///
1727    /// `panics` if the function does not exist.
1728    pub fn get_size(&self, func: &str) -> usize {
1729        let function_id = self.functions.get(func).unwrap().backend_id;
1730        self.backend.table_size(function_id)
1731    }
1732
1733    /// Lookup a tuple in afunction in the e-graph.
1734    ///
1735    /// Returns `None` if the tuple does not exist.
1736    /// `panics` if the function does not exist.
1737    pub fn lookup_function(&self, name: &str, key: &[Value]) -> Option<Value> {
1738        let func = self.functions.get(name).unwrap().backend_id;
1739        self.backend.lookup_id(func, key)
1740    }
1741
1742    /// Get a function by name.
1743    ///
1744    /// Returns `None` if the function does not exist.
1745    pub fn get_function(&self, name: &str) -> Option<&Function> {
1746        self.functions.get(name)
1747    }
1748
1749    pub fn set_report_level(&mut self, level: ReportLevel) {
1750        self.backend.set_report_level(level);
1751    }
1752
1753    /// A basic method for dumping the state of the database to `log::info!`.
1754    ///
1755    /// For large tables, this is unlikely to give particularly useful output.
1756    pub fn dump_debug_info(&self) {
1757        self.backend.dump_debug_info();
1758    }
1759
1760    /// Get the canonical representation for `val` based on type.
1761    pub fn get_canonical_value(&self, val: Value, sort: &ArcSort) -> Value {
1762        self.backend
1763            .get_canon_repr(val, sort.column_ty(&self.backend))
1764    }
1765}
1766
/// State used while lowering one core rule into a backend rule through an
/// [`egglog_bridge::RuleBuilder`].
struct BackendRule<'a> {
    /// The backend rule under construction.
    rb: egglog_bridge::RuleBuilder<'a>,
    /// Cache from atom terms already seen to their backend query entries.
    entries: HashMap<core::ResolvedAtomTerm, QueryEntry>,
    /// Egglog functions by name; used to find backend table ids and subsumption support.
    functions: &'a IndexMap<String, Function>,
    /// Type information; used to resolve the target of `unstable-fn`.
    type_info: &'a TypeInfo,
}
1773
/// Translation of core queries and actions into calls on the wrapped rule builder.
impl<'a> BackendRule<'a> {
    /// Wraps `rb` together with an empty term → [`QueryEntry`] cache.
    fn new(
        rb: egglog_bridge::RuleBuilder<'a>,
        functions: &'a IndexMap<String, Function>,
        type_info: &'a TypeInfo,
    ) -> BackendRule<'a> {
        BackendRule {
            rb,
            functions,
            type_info,
            entries: Default::default(),
        }
    }

    /// Returns the [`QueryEntry`] for atom term `x`, creating and caching it on first use.
    ///
    /// Variables become fresh named rule variables of the variable's sort column type;
    /// literals become constants via `literal_to_entry`.
    ///
    /// # Panics
    /// Panics on `Global` terms, which must have been desugared before this point.
    fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry {
        self.entries
            .entry(x.clone())
            .or_insert_with(|| match x {
                core::GenericAtomTerm::Var(_, v) => self
                    .rb
                    .new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name),
                core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l),
                core::GenericAtomTerm::Global(..) => {
                    panic!("Globals should have been desugared")
                }
            })
            .clone()
    }

    /// Returns the backend table id of function `f`.
    ///
    /// Panics (via map indexing) if `f.name` is not a known function.
    fn func(&self, f: &typechecking::FuncType) -> egglog_bridge::FunctionId {
        self.functions[&f.name].backend_id
    }

    /// Lowers a primitive call: returns its external function id, the argument entries,
    /// and the column type of its output.
    ///
    /// `unstable-fn` is special-cased: its first argument must be a string literal naming
    /// either a function (resolved to a table lookup) or a primitive (resolved to the single
    /// overload that accepts the partially-applied argument sorts followed by the inputs and
    /// output of some registered `FunctionSort`). That first argument entry is then replaced
    /// by a constant [`ResolvedFunction`] value.
    ///
    /// # Panics
    /// Panics if the `unstable-fn` name is not a string literal, names nothing callable, or
    /// does not resolve to exactly one primitive overload.
    fn prim(
        &mut self,
        prim: &core::SpecializedPrimitive,
        args: &[core::ResolvedAtomTerm],
    ) -> (ExternalFunctionId, Vec<QueryEntry>, ColumnTy) {
        let mut qe_args = self.args(args);

        if prim.name() == "unstable-fn" {
            let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else {
                panic!("expected string literal after `unstable-fn`")
            };
            // Functions take priority over primitives with the same name.
            let id = if let Some(f) = self.type_info.get_func_type(name) {
                ResolvedFunctionId::Lookup(egglog_bridge::TableAction::new(
                    self.rb.egraph(),
                    self.func(f),
                ))
            } else if let Some(possible) = self.type_info.get_prims(name) {
                // Keep only overloads whose signature matches: the partial arguments
                // (all `unstable-fn` inputs after the name) followed by the inputs and
                // output of at least one registered function sort.
                let mut ps: Vec<_> = possible.iter().collect();
                ps.retain(|p| {
                    self.type_info
                        .get_sorts::<FunctionSort>()
                        .into_iter()
                        .any(|f| {
                            let types: Vec<_> = prim
                                .input()
                                .iter()
                                .skip(1)
                                .chain(f.inputs())
                                .chain([&f.output()])
                                .cloned()
                                .collect();
                            p.accept(&types, self.type_info)
                        })
                });
                assert!(ps.len() == 1, "options for {name}: {ps:?}");
                ResolvedFunctionId::Prim(ps.into_iter().next().unwrap().id)
            } else {
                panic!("no callable for {name}");
            };
            let partial_arcsorts = prim.input().iter().skip(1).cloned().collect();

            // Replace the string-literal name argument with the resolved function value.
            qe_args[0] = self.rb.egraph().base_value_constant(ResolvedFunction {
                id,
                partial_arcsorts,
                name: name.clone(),
            });
        }

        (
            prim.external_id(),
            qe_args,
            prim.output().column_ty(self.rb.egraph()),
        )
    }

    /// Returns the [`QueryEntry`] for each term in `args`, in order.
    fn args<'b>(
        &mut self,
        args: impl IntoIterator<Item = &'b core::ResolvedAtomTerm>,
    ) -> Vec<QueryEntry> {
        args.into_iter().map(|x| self.entry(x)).collect()
    }

    /// Adds every atom of `query` to the rule body.
    ///
    /// Function atoms become table queries. When `include_subsumed` is `false`, the
    /// subsumption argument is `Some(false)` (presumably restricting matches to
    /// non-subsumed rows — confirm against `RuleBuilder::query_table`); otherwise `None`.
    /// Primitive atoms become primitive queries.
    ///
    /// # Panics
    /// Panics if the rule builder rejects an atom.
    fn query(&mut self, query: &core::Query<ResolvedCall, ResolvedVar>, include_subsumed: bool) {
        for atom in &query.atoms {
            match &atom.head {
                ResolvedCall::Func(f) => {
                    let f = self.func(f);
                    let args = self.args(&atom.args);
                    let is_subsumed = match include_subsumed {
                        true => None,
                        false => Some(false),
                    };
                    self.rb.query_table(f, &args, is_subsumed).unwrap();
                }
                ResolvedCall::Primitive(p) => {
                    let (p, args, ty) = self.prim(p, &atom.args);
                    self.rb.query_prim(p, &args, ty).unwrap()
                }
            }
        }
    }

    /// Adds every core action to the rule's right-hand side.
    ///
    /// `Let` binds the result of a function lookup or primitive call to a variable;
    /// `LetAtomTerm` aliases a variable to an existing term; `Set`, `Change`
    /// (delete/subsume), `Union`, and `Panic` map onto the corresponding builder calls.
    ///
    /// # Errors
    /// Returns [`Error::SubsumeMergeError`] when subsuming a row of a function whose
    /// `can_subsume` flag is false.
    ///
    /// # Panics
    /// Panics if `Set` or `Change` targets a primitive.
    fn actions(&mut self, actions: &core::ResolvedCoreActions) -> Result<(), Error> {
        for action in &actions.0 {
            match action {
                core::GenericCoreAction::Let(span, v, f, args) => {
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let y = match f {
                        ResolvedCall::Func(f) => {
                            let name = f.name.clone();
                            let f = self.func(f);
                            let args = self.args(args);
                            let span = span.clone();
                            self.rb.lookup(f, &args, move || {
                                format!("{span}: lookup of function {name} failed")
                            })
                        }
                        ResolvedCall::Primitive(p) => {
                            let name = p.name().to_owned();
                            let (p, args, ty) = self.prim(p, args);
                            let span = span.clone();
                            self.rb.call_external_func(p, &args, ty, move || {
                                format!("{span}: call of primitive {name} failed")
                            })
                        }
                    };
                    // Later references to `v` resolve to the computed result.
                    self.entries.insert(v, y.into());
                }
                core::GenericCoreAction::LetAtomTerm(span, v, x) => {
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let x = self.entry(x);
                    self.entries.insert(v, x);
                }
                core::GenericCoreAction::Set(_, f, xs, y) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive set!"),
                    ResolvedCall::Func(f) => {
                        let f = self.func(f);
                        // Row layout: key columns followed by the output value.
                        let args = self.args(xs.iter().chain([y]));
                        self.rb.set(f, &args)
                    }
                },
                core::GenericCoreAction::Change(span, change, f, args) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive change!"),
                    ResolvedCall::Func(f) => {
                        let name = f.name.clone();
                        let can_subsume = self.functions[&f.name].can_subsume;
                        let f = self.func(f);
                        let args = self.args(args);
                        match change {
                            Change::Delete => self.rb.remove(f, &args),
                            Change::Subsume if can_subsume => self.rb.subsume(f, &args),
                            Change::Subsume => {
                                return Err(Error::SubsumeMergeError(name, span.clone()));
                            }
                        }
                    }
                },
                core::GenericCoreAction::Union(_, x, y) => {
                    let x = self.entry(x);
                    let y = self.entry(y);
                    self.rb.union(x, y)
                }
                core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()),
            }
        }
        Ok(())
    }

    /// Finishes the rule and returns its backend id.
    fn build(self) -> egglog_bridge::RuleId {
        self.rb.build()
    }
}
1959
1960fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry {
1961    match l {
1962        Literal::Int(x) => egraph.base_value_constant::<i64>(*x),
1963        Literal::Float(x) => egraph.base_value_constant::<sort::F>(x.into()),
1964        Literal::String(x) => egraph.base_value_constant::<sort::S>(sort::S::new(x.clone())),
1965        Literal::Bool(x) => egraph.base_value_constant::<bool>(*x),
1966        Literal::Unit => egraph.base_value_constant::<()>(()),
1967    }
1968}
1969
1970fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value {
1971    match l {
1972        Literal::Int(x) => egraph.base_values().get::<i64>(*x),
1973        Literal::Float(x) => egraph.base_values().get::<sort::F>(x.into()),
1974        Literal::String(x) => egraph.base_values().get::<sort::S>(sort::S::new(x.clone())),
1975        Literal::Bool(x) => egraph.base_values().get::<bool>(*x),
1976        Literal::Unit => egraph.base_values().get::<()>(()),
1977    }
1978}
1979
/// Errors produced while parsing, typechecking, or running egglog programs.
#[derive(Debug, Error)]
pub enum Error {
    /// Wraps a [`ParseError`].
    #[error(transparent)]
    ParseError(#[from] ParseError),
    /// Wraps a [`NotFoundError`].
    #[error(transparent)]
    NotFoundError(#[from] NotFoundError),
    /// Wraps a single [`TypeError`].
    #[error(transparent)]
    TypeError(#[from] TypeError),
    /// Several type errors, displayed one per line.
    #[error("Errors:\n{}", ListDisplay(.0, "\n"))]
    TypeErrors(Vec<TypeError>),
    /// A `check` failed; holds the checked facts and the span of the check.
    #[error("{}\nCheck failed: \n{}", .1, ListDisplay(.0, "\n"))]
    CheckError(Vec<Fact>, Span),
    /// A ruleset with the given name does not exist.
    #[error("{1}\nNo such ruleset: {0}")]
    NoSuchRuleset(String, Span),
    /// A rule was added to a combined ruleset, which may only depend on other rulesets.
    #[error(
        "{1}\nAttempted to add a rule to combined ruleset {0}. Combined rulesets may only depend on other rulesets."
    )]
    CombinedRulesetError(String, Span),
    /// An error described only by a message string (by its name, reported by the backend).
    #[error("{0}")]
    BackendError(String),
    /// A `pop` removed more levels than had been pushed.
    #[error("{0}\nTried to pop too much")]
    Pop(Span),
    /// A command that was expected to fail succeeded.
    #[error("{0}\nCommand should have failed.")]
    ExpectFail(Span),
    /// An I/O operation on the given path failed.
    #[error("{2}\nIO error: {0}: {1}")]
    IoError(PathBuf, std::io::Error, Span),
    /// A subsume targeted a function that cannot be subsumed (e.g. one with a merge).
    #[error("{1}\nCannot subsume function with merge: {0}")]
    SubsumeMergeError(String, Span),
    /// Extraction failed; note the message is rendered with `Debug` formatting.
    #[error("extraction failure: {:?}", .0)]
    ExtractError(String),
    /// A proof-related failure at `span`.
    #[error("{span}\n{error}")]
    ProofError {
        /// Location of the offending command.
        span: Span,
        /// The underlying proof error, also exposed as the error source.
        #[source]
        error: ProveExistsError,
    },
    /// A name was defined more than once; holds the name and both spans.
    #[error("{1}\n{2}\nShadowing is not allowed, but found {0}")]
    Shadowing(String, Span, Span),
    /// A command with this name already exists.
    #[error("{1}\nCommand already exists: {0}")]
    CommandAlreadyExists(String, Span),
    /// An input file had an incorrect format.
    #[error("Incorrect format in file '{0}'.")]
    InputFileFormatError(String),
    /// A command cannot be handled while proof term encoding is enabled.
    #[error(
        "Command is not supported by the current proof term encoding implementation.\n\
         This typically means the command uses constructs that cannot yet be represented as proof terms.\n\
         Consider disabling proof term encoding for this run or rewriting the command to avoid unsupported features.\n\
         Offending command: {command}"
    )]
    UnsupportedProofCommand { command: String },
}
2030
#[cfg(test)]
mod tests {
    use crate::constraint::SimpleTypeConstraint;
    use crate::sort::*;
    use crate::*;

    /// Test primitive `inner-product`: the dot product of two integer vectors of sort `vec`.
    #[derive(Clone)]
    struct InnerProduct {
        vec: ArcSort,
    }

    impl Primitive for InnerProduct {
        fn name(&self) -> &str {
            "inner-product"
        }

        fn get_type_constraints(&self, span: &Span) -> Box<dyn crate::constraint::TypeConstraint> {
            SimpleTypeConstraint::new(
                self.name(),
                vec![self.vec.clone(), self.vec.clone(), I64Sort.to_arcsort()],
                span.clone(),
            )
            .into_box()
        }

        fn apply(&self, exec_state: &mut ExecutionState<'_>, args: &[Value]) -> Option<Value> {
            let mut sum = 0;
            let vec1 = exec_state
                .container_values()
                .get_val::<VecContainer>(args[0])
                .unwrap();
            let vec2 = exec_state
                .container_values()
                .get_val::<VecContainer>(args[1])
                .unwrap();
            assert_eq!(vec1.data.len(), vec2.data.len());
            for (a, b) in vec1.data.iter().zip(vec2.data.iter()) {
                let a = exec_state.base_values().unwrap::<i64>(*a);
                let b = exec_state.base_values().unwrap::<i64>(*b);
                sum += a * b;
            }
            Some(exec_state.base_values().get::<i64>(sum))
        }
    }

    #[test]
    fn test_user_defined_primitive() {
        let mut egraph = EGraph::default();
        egraph
            .parse_and_run_program(None, "(sort IntVec (Vec i64))")
            .unwrap();

        let int_vec_sort = egraph.get_arcsort_by(|s| {
            s.value_type() == Some(std::any::TypeId::of::<VecContainer>())
                && s.inner_sorts()[0].name() == I64Sort.name()
        });

        egraph.add_primitive(InnerProduct { vec: int_vec_sort });

        egraph
            .parse_and_run_program(
                None,
                "
                (let a (vec-of 1 2 3 4 5 6))
                (let b (vec-of 6 5 4 3 2 1))
                (check (= (inner-product a b) 56))
            ",
            )
            .unwrap();
    }

    // Test that an `EGraph` is `Send` & `Sync`
    #[test]
    fn test_egraph_send_sync() {
        fn is_send<T: Send>(_t: &T) -> bool {
            true
        }
        fn is_sync<T: Sync>(_t: &T) -> bool {
            true
        }
        let egraph = EGraph::default();
        assert!(is_send(&egraph) && is_sync(&egraph));
    }

    /// Returns a clone of the function named `name`; panics if it does not exist.
    fn get_function(egraph: &EGraph, name: &str) -> Function {
        egraph.functions.get(name).unwrap().clone()
    }

    /// Returns `vals[0]` of the last row visited in `name`'s table.
    ///
    /// Intended for nullary constructors, whose table holds a single row.
    /// Panics if the table is empty.
    fn get_value(egraph: &EGraph, name: &str) -> Value {
        let mut out = None;
        let id = get_function(egraph, name).backend_id;
        egraph.backend.for_each(id, |row| out = Some(row.vals[0]));
        out.unwrap()
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_arg() {
        // Tests that a term stays unextractable even after a rebuild triggered by a union
        // that changes the value of one of its args.
        let mut egraph = EGraph::default();

        egraph
            .parse_and_run_program(
                None,
                r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (constructor cheap-1 () Math)
                ; we make the container cheap so that it will be extracted if possible, but then we mark it as subsumed
                ; so the (exp) expr should be extracted instead
                (let res (container (cheap)))
                (union res (exp))
                (cheap)
                (cheap-1)
                (subsume (container (cheap)))
                "#,
            )
            .unwrap();
        // At this point (cheap) and (cheap-1) should have different values, because they aren't unioned
        let orig_cheap_value = get_value(&egraph, "cheap");
        let orig_cheap_1_value = get_value(&egraph, "cheap-1");
        assert_ne!(orig_cheap_value, orig_cheap_1_value);
        // Then we can union them
        egraph
            .parse_and_run_program(
                None,
                r#"
                (union (cheap-1) (cheap))
                "#,
            )
            .unwrap();
        // And verify that their values are now the same and different from the original (cheap) value.
        let new_cheap_value = get_value(&egraph, "cheap");
        let new_cheap_1_value = get_value(&egraph, "cheap-1");
        assert_eq!(new_cheap_value, new_cheap_1_value);
        assert!(new_cheap_value != orig_cheap_value || new_cheap_1_value != orig_cheap_1_value);
        // Now verify that extraction still respects the subsumption, even though the value has changed
        let outputs = egraph
            .parse_and_run_program(
                None,
                r#"
                (extract res)
                "#,
            )
            .unwrap();
        assert_eq!(outputs[0].to_string(), "(exp)\n");
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_self() {
        // Tests that a term stays unextractable even after a rebuild after a union changes its output value.
        let mut egraph = EGraph::default();

        egraph
            .parse_and_run_program(
                None,
                r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (exp)
                (let x (cheap))
                (subsume (cheap))
                "#,
            )
            .unwrap();

        let orig_cheap_value = get_value(&egraph, "cheap");
        // Then we can union them
        egraph
            .parse_and_run_program(
                None,
                r#"
                (union (exp) x)
                "#,
            )
            .unwrap();
        // And verify that the cheap value is now different
        let new_cheap_value = get_value(&egraph, "cheap");
        assert_ne!(new_cheap_value, orig_cheap_value);

        // Now verify that extraction still respects the subsumption, even though the value has changed
        let res = egraph
            .parse_and_run_program(
                None,
                r#"
                (extract x)
                "#,
            )
            .unwrap();
        assert_eq!(res[0].to_string(), "(exp)\n");
    }
}