egglog/
lib.rs

1//! # egglog
2//! egglog is a language specialized for writing equality saturation
//! applications. It is the successor to the Rust library [egg](https://github.com/egraphs-good/egg).
4//! egglog is faster and more general than egg.
5//!
6//! # Documentation
7//! Documentation for the egglog language can be found
8//! here: [`Command`]
9//!
10//! # Tutorial
11//! [Here](https://www.youtube.com/watch?v=N2RDQGRBrSY) is the video tutorial on what egglog is and how to use it.
12//! We plan to have a text tutorial here soon, PRs welcome!
13//!
14pub mod ast;
15#[cfg(feature = "bin")]
16mod cli;
17pub mod constraint;
18mod core;
19pub mod extract;
20pub mod prelude;
21pub mod scheduler;
22mod serialize;
23pub mod sort;
24mod termdag;
25mod typechecking;
26pub mod util;
27
28// This is used to allow the `add_primitive` macro to work in
29// both this crate and other crates by referring to `::egglog`.
30extern crate self as egglog;
31use ast::*;
32#[cfg(feature = "bin")]
33pub use cli::*;
34use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint};
35use core::{AtomTerm, ResolvedAtomTerm, ResolvedCall};
36pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value};
37use core_relations::{ExternalFunctionId, make_external_func};
38use csv::Writer;
39pub use egglog_add_primitive::add_primitive;
40use egglog_ast::generic_ast::{Change, GenericExpr, Literal};
41use egglog_ast::span::Span;
42use egglog_ast::util::ListDisplay;
43pub use egglog_bridge::FunctionRow;
44use egglog_bridge::{ColumnTy, QueryEntry};
45use egglog_core_relations as core_relations;
46use egglog_numeric_id as numeric_id;
47use egglog_reports::{ReportLevel, RunReport};
48use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel};
49use indexmap::map::Entry;
50use log::{Level, log_enabled};
51use numeric_id::DenseIdMap;
52use prelude::*;
53use scheduler::{SchedulerId, SchedulerRecord};
54pub use serialize::{SerializeConfig, SerializeOutput, SerializedNode};
55use sort::*;
56use std::fmt::{Debug, Display, Formatter};
57use std::fs::File;
58use std::hash::Hash;
59use std::io::{Read, Write as _};
60use std::iter::once;
61use std::ops::Deref;
62use std::path::PathBuf;
63use std::str::FromStr;
64use std::sync::Arc;
65pub use termdag::{Term, TermDag, TermId};
66use thiserror::Error;
67pub use typechecking::TypeError;
68use typechecking::TypeInfo;
69use util::*;
70
71use crate::core::{GenericActionsExt, ResolvedRuleExt};
72
/// A shared, reference-counted handle to a [`Sort`] trait object.
pub type ArcSort = Arc<dyn Sort>;
74
/// A trait for implementing custom primitive operations in egglog.
///
/// Primitives are built-in functions that can be called in both rule queries and actions.
/// An implementation supplies a name, the type constraints used while type checking a
/// call, and the function that is run on concrete values.
pub trait Primitive {
    /// Returns the name of this primitive operation.
    fn name(&self) -> &str;

    /// Constructs a type constraint for the primitive that uses the span information
    /// for error localization.
    ///
    /// `span` is the source location that the returned constraint should report on failure.
    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint>;

    /// Applies the primitive operation to the given arguments.
    ///
    /// `exec_state` is the execution state handed to the primitive by the caller and
    /// `args` holds the argument values of the call.
    ///
    /// Returns `Some(value)` if the operation succeeds, or `None` if it fails.
    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value>;
}
91
/// Marker trait for values that a user-defined command can hand back as its output.
///
/// It is never implemented by hand: the blanket impl below covers every type that is
/// `Debug + Display + Send + Sync`.
pub trait UserDefinedCommandOutput: Debug + Display + Send + Sync {}

impl<T: Debug + Display + Send + Sync> UserDefinedCommandOutput for T {}
95
/// Output from a command.
///
/// The `Display` implementation of this type turns each variant into the text
/// that is printed for the command.
#[derive(Clone, Debug)]
pub enum CommandOutput {
    /// The size of a function
    PrintFunctionSize(usize),
    /// The name of all functions and their sizes, as `(name, size)` pairs
    PrintAllFunctionsSize(Vec<(String, usize)>),
    /// The best function found after extracting: the term DAG that owns the
    /// term, the cost of the term, and the term itself
    ExtractBest(TermDag, DefaultCost, Term),
    /// The variants of a function found after extracting, together with the
    /// term DAG that owns them
    ExtractVariants(TermDag, Vec<Term>),
    /// The report from all runs
    OverallStatistics(RunReport),
    /// A printed function and all its values: the function, the term DAG that
    /// owns the terms, one `(row term, output term)` pair per printed row, and
    /// the requested print mode
    PrintFunction(Function, TermDag, Vec<(Term, Term)>, PrintFunctionMode),
    /// The report from a single run
    RunSchedule(RunReport),
    /// A user defined output
    UserDefined(Arc<dyn UserDefinedCommandOutput>),
}
116
117impl std::fmt::Display for CommandOutput {
118    /// Format the command output for display, ending with a newline.
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        match self {
121            CommandOutput::PrintFunctionSize(size) => writeln!(f, "{}", size),
122            CommandOutput::PrintAllFunctionsSize(names_and_sizes) => {
123                for name in names_and_sizes {
124                    writeln!(f, "{}: {}", name.0, name.1)?;
125                }
126                Ok(())
127            }
128            CommandOutput::ExtractBest(termdag, _cost, term) => {
129                writeln!(f, "{}", termdag.to_string(term))
130            }
131            CommandOutput::ExtractVariants(termdag, terms) => {
132                writeln!(f, "(")?;
133                for expr in terms {
134                    writeln!(f, "   {}", termdag.to_string(expr))?;
135                }
136                writeln!(f, ")")
137            }
138            CommandOutput::OverallStatistics(run_report) => {
139                write!(f, "Overall statistics:\n{}", run_report)
140            }
141            CommandOutput::PrintFunction(function, termdag, terms_and_outputs, mode) => {
142                let out_is_unit = function.schema.output.name() == UnitSort.name();
143                if *mode == PrintFunctionMode::CSV {
144                    let mut wtr = Writer::from_writer(vec![]);
145                    for (term, output) in terms_and_outputs {
146                        match term {
147                            Term::App(name, children) => {
148                                let mut values = vec![name.clone()];
149                                for child_id in children {
150                                    values.push(termdag.to_string(termdag.get(*child_id)));
151                                }
152
153                                if !out_is_unit {
154                                    values.push(termdag.to_string(output));
155                                }
156                                wtr.write_record(&values).map_err(|_| std::fmt::Error)?;
157                            }
158                            _ => panic!("Expect function_to_dag to return a list of apps."),
159                        }
160                    }
161                    let csv_bytes = wtr.into_inner().map_err(|_| std::fmt::Error)?;
162                    f.write_str(&String::from_utf8(csv_bytes).map_err(|_| std::fmt::Error)?)
163                } else {
164                    writeln!(f, "(")?;
165                    for (term, output) in terms_and_outputs.iter() {
166                        write!(f, "   {}", termdag.to_string(term))?;
167                        if !out_is_unit {
168                            write!(f, " -> {}", termdag.to_string(output))?;
169                        }
170                        writeln!(f)?;
171                    }
172                    writeln!(f, ")")
173                }
174            }
175            CommandOutput::RunSchedule(_report) => Ok(()),
176            CommandOutput::UserDefined(output) => {
177                write!(f, "{}", *output)
178            }
179        }
180    }
181}
182
/// The main interface for an e-graph in egglog.
///
/// An [`EGraph`] maintains a collection of equivalence classes of terms and provides
/// operations for adding facts, running rules, and extracting optimal terms.
///
/// # Examples
///
/// ```
/// use egglog::*;
///
/// let mut egraph = EGraph::default();
/// egraph.parse_and_run_program(None, "(datatype Math (Num i64) (Add Math Math))").unwrap();
/// ```
#[derive(Clone)]
pub struct EGraph {
    /// The backend e-graph; tables are created in it and rules are run through it.
    backend: egglog_bridge::EGraph,
    /// The parser for egglog programs. User-defined commands are registered with it.
    pub parser: Parser,
    // NOTE(review): presumably the set of declared names used for shadowing checks — confirm in `check_shadowing`.
    names: check_shadowing::Names,
    /// pushed_egraph forms a linked list of pushed egraphs.
    /// Pop reverts the egraph to the last pushed egraph.
    pushed_egraph: Option<Box<Self>>,
    /// Declared functions, keyed by function name.
    functions: IndexMap<String, Function>,
    /// Declared rulesets, keyed by ruleset name. The default ruleset is named `""`.
    rulesets: IndexMap<String, Ruleset>,
    // NOTE(review): looks like the base directory for fact files read by input commands — confirm against the users of this field.
    pub fact_directory: Option<PathBuf>,
    /// Passed to the backend whenever a rule is created; defaults to `true`.
    pub seminaive: bool,
    /// Sorts and primitives known to the type checker.
    type_info: TypeInfo,
    /// The run report unioned over all runs so far.
    overall_run_report: RunReport,
    // NOTE(review): presumably the registered custom schedulers — confirm in the `scheduler` module.
    schedulers: DenseIdMap<SchedulerId, SchedulerRecord>,
    /// User-defined commands, keyed by command name.
    commands: IndexMap<String, Arc<dyn UserDefinedCommand>>,
}
214
/// A user-defined command allows users to inject a custom command that can be called
/// in an egglog program.
///
/// Compared to an external function, a user-defined command is more powerful because
/// it has exclusive access to the e-graph.
pub trait UserDefinedCommand: Send + Sync {
    /// Run the command with the given arguments.
    ///
    /// `egraph` is the e-graph the command operates on and `args` are the argument
    /// expressions of the command. The command may return an output to report,
    /// or `None` when it has nothing to report.
    fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<Option<CommandOutput>, Error>;
}
224
/// A function in the e-graph.
///
/// This contains the schema information of the function and
/// the backend id of the function in the e-graph.
#[derive(Clone)]
pub struct Function {
    /// The resolved declaration this function was created from.
    decl: ResolvedFunctionDecl,
    /// The input and output sorts of the function.
    schema: ResolvedSchema,
    /// Whether this function supports subsumption.
    can_subsume: bool,
    /// The id of the backend table that stores this function.
    backend_id: egglog_bridge::FunctionId,
}
236
237impl Function {
238    /// Get the name of the function.
239    pub fn name(&self) -> &str {
240        &self.decl.name
241    }
242
243    /// Get the schema of the function.
244    pub fn schema(&self) -> &ResolvedSchema {
245        &self.schema
246    }
247
248    /// Whether this function supports subsumption.
249    pub fn can_subsume(&self) -> bool {
250        self.can_subsume
251    }
252}
253
/// The schema of a function with every sort name resolved to its sort.
#[derive(Clone, Debug)]
pub struct ResolvedSchema {
    /// The sorts of the function's arguments, in order.
    pub input: Vec<ArcSort>,
    /// The sort of the function's output.
    pub output: ArcSort,
}
259
260impl ResolvedSchema {
261    /// Get the type at position `index`, counting the `output` sort as at position `input.len()`.
262    pub fn get_by_pos(&self, index: usize) -> Option<&ArcSort> {
263        if self.input.len() == index {
264            Some(&self.output)
265        } else {
266            self.input.get(index)
267        }
268    }
269}
270
271impl Debug for Function {
272    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
273        f.debug_struct("Function")
274            .field("decl", &self.decl)
275            .field("schema", &self.schema)
276            .finish()
277    }
278}
279
280impl Default for EGraph {
281    fn default() -> Self {
282        let mut eg = Self {
283            backend: Default::default(),
284            parser: Default::default(),
285            names: Default::default(),
286            pushed_egraph: Default::default(),
287            functions: Default::default(),
288            rulesets: Default::default(),
289            fact_directory: None,
290            seminaive: true,
291            overall_run_report: Default::default(),
292            type_info: Default::default(),
293            schedulers: Default::default(),
294            commands: Default::default(),
295        };
296
297        add_base_sort(&mut eg, UnitSort, span!()).unwrap();
298        add_base_sort(&mut eg, StringSort, span!()).unwrap();
299        add_base_sort(&mut eg, BoolSort, span!()).unwrap();
300        add_base_sort(&mut eg, I64Sort, span!()).unwrap();
301        add_base_sort(&mut eg, F64Sort, span!()).unwrap();
302        add_base_sort(&mut eg, BigIntSort, span!()).unwrap();
303        add_base_sort(&mut eg, BigRatSort, span!()).unwrap();
304        eg.type_info.add_presort::<MapSort>(span!()).unwrap();
305        eg.type_info.add_presort::<SetSort>(span!()).unwrap();
306        eg.type_info.add_presort::<VecSort>(span!()).unwrap();
307        eg.type_info.add_presort::<FunctionSort>(span!()).unwrap();
308        eg.type_info.add_presort::<MultiSetSort>(span!()).unwrap();
309
310        add_primitive!(&mut eg, "!=" = |a: #, b: #| -?> () {
311            (a != b).then_some(())
312        });
313        add_primitive!(&mut eg, "value-eq" = |a: #, b: #| -?> () {
314            (a == b).then_some(())
315        });
316        add_primitive!(&mut eg, "ordering-min" = |a: #, b: #| -> # {
317            if a < b { a } else { b }
318        });
319        add_primitive!(&mut eg, "ordering-max" = |a: #, b: #| -> # {
320            if a > b { a } else { b }
321        });
322
323        eg.rulesets
324            .insert("".into(), Ruleset::Rules(Default::default()));
325
326        eg
327    }
328}
329
/// Error reporting that something could not be found.
///
/// The wrapped string is inserted into the message `Not found: {0}`.
#[derive(Debug, Error)]
#[error("Not found: {0}")]
pub struct NotFoundError(String);
333
334impl EGraph {
335    /// Add a user-defined command to the e-graph
336    pub fn add_command(
337        &mut self,
338        name: String,
339        command: Arc<dyn UserDefinedCommand>,
340    ) -> Result<(), Error> {
341        if self.commands.contains_key(&name)
342            || self.functions.contains_key(&name)
343            || self.type_info.get_prims(&name).is_some()
344        {
345            return Err(Error::CommandAlreadyExists(name, span!()));
346        }
347        self.commands.insert(name.clone(), command);
348        self.parser.add_user_defined(name)?;
349        Ok(())
350    }
351
352    /// Push a snapshot of the e-graph into the stack.
353    ///
354    /// See [`EGraph::pop`].
355    pub fn push(&mut self) {
356        let prev_prev: Option<Box<Self>> = self.pushed_egraph.take();
357        let mut prev = self.clone();
358        prev.pushed_egraph = prev_prev;
359        self.pushed_egraph = Some(Box::new(prev));
360    }
361
362    /// Pop the current egraph off the stack, replacing
363    /// it with the previously pushed egraph.
364    /// It preserves the run report and messages from the popped
365    /// egraph.
366    pub fn pop(&mut self) -> Result<(), Error> {
367        match self.pushed_egraph.take() {
368            Some(e) => {
369                // Copy the overall report from the popped egraph
370                let overall_run_report = self.overall_run_report.clone();
371                *self = *e;
372                self.overall_run_report = overall_run_report;
373                Ok(())
374            }
375            None => Err(Error::Pop(span!())),
376        }
377    }
378
379    fn translate_expr_to_mergefn(
380        &self,
381        expr: &ResolvedExpr,
382    ) -> Result<egglog_bridge::MergeFn, Error> {
383        match expr {
384            GenericExpr::Lit(_, literal) => {
385                let val = literal_to_value(&self.backend, literal);
386                Ok(egglog_bridge::MergeFn::Const(val))
387            }
388            GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() {
389                "old" => Ok(egglog_bridge::MergeFn::Old),
390                "new" => Ok(egglog_bridge::MergeFn::New),
391                // NB: type-checking should already catch unbound variables here.
392                _ => Err(TypeError::Unbound(resolved_var.name.clone(), span.clone()).into()),
393            },
394            GenericExpr::Call(_, ResolvedCall::Func(f), args) => {
395                let translated_args = args
396                    .iter()
397                    .map(|arg| self.translate_expr_to_mergefn(arg))
398                    .collect::<Result<Vec<_>, _>>()?;
399                Ok(egglog_bridge::MergeFn::Function(
400                    self.functions[&f.name].backend_id,
401                    translated_args,
402                ))
403            }
404            GenericExpr::Call(_, ResolvedCall::Primitive(p), args) => {
405                let translated_args = args
406                    .iter()
407                    .map(|arg| self.translate_expr_to_mergefn(arg))
408                    .collect::<Result<Vec<_>, _>>()?;
409                Ok(egglog_bridge::MergeFn::Primitive(
410                    p.primitive.1,
411                    translated_args,
412                ))
413            }
414        }
415    }
416
417    fn declare_function(&mut self, decl: &ResolvedFunctionDecl) -> Result<(), Error> {
418        let get_sort = |name: &String| match self.type_info.get_sort_by_name(name) {
419            Some(sort) => Ok(sort.clone()),
420            None => Err(Error::TypeError(TypeError::UndefinedSort(
421                name.to_owned(),
422                decl.span.clone(),
423            ))),
424        };
425
426        let input = decl
427            .schema
428            .input
429            .iter()
430            .map(get_sort)
431            .collect::<Result<Vec<_>, _>>()?;
432        let output = get_sort(&decl.schema.output)?;
433
434        let can_subsume = match decl.subtype {
435            FunctionSubtype::Constructor => true,
436            FunctionSubtype::Relation => true,
437            FunctionSubtype::Custom => false,
438        };
439
440        use egglog_bridge::{DefaultVal, MergeFn};
441        let backend_id = self.backend.add_table(egglog_bridge::FunctionConfig {
442            schema: input
443                .iter()
444                .chain([&output])
445                .map(|sort| sort.column_ty(&self.backend))
446                .collect(),
447            default: match decl.subtype {
448                FunctionSubtype::Constructor => DefaultVal::FreshId,
449                FunctionSubtype::Custom => DefaultVal::Fail,
450                FunctionSubtype::Relation => DefaultVal::Const(self.backend.base_values().get(())),
451            },
452            merge: match decl.subtype {
453                FunctionSubtype::Constructor => MergeFn::UnionId,
454                FunctionSubtype::Relation => MergeFn::AssertEq,
455                FunctionSubtype::Custom => match &decl.merge {
456                    None => MergeFn::AssertEq,
457                    Some(expr) => self.translate_expr_to_mergefn(expr)?,
458                },
459            },
460            name: decl.name.to_string(),
461            can_subsume,
462        });
463
464        let function = Function {
465            decl: decl.clone(),
466            schema: ResolvedSchema { input, output },
467            can_subsume,
468            backend_id,
469        };
470
471        let old = self.functions.insert(decl.name.clone(), function);
472        if old.is_some() {
473            panic!(
474                "Typechecking should have caught function already bound: {}",
475                decl.name
476            );
477        }
478
479        Ok(())
480    }
481
482    /// Extract rows of a table using the default cost model with name sym
483    /// The `include_output` parameter controls whether the output column is always extracted
484    /// For functions, the output column is usually useful
485    /// For constructors and relations, the output column can be ignored
486    pub fn function_to_dag(
487        &self,
488        sym: &str,
489        n: usize,
490        include_output: bool,
491    ) -> Result<(Vec<Term>, Option<Vec<Term>>, TermDag), Error> {
492        let func = self
493            .functions
494            .get(sym)
495            .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
496        let mut rootsorts = func.schema.input.clone();
497        if include_output {
498            rootsorts.push(func.schema.output.clone());
499        }
500        let extractor = Extractor::compute_costs_from_rootsorts(
501            Some(rootsorts),
502            self,
503            TreeAdditiveCostModel::default(),
504        );
505
506        let mut termdag = TermDag::default();
507        let mut inputs: Vec<Term> = Vec::new();
508        let mut output: Option<Vec<Term>> = if include_output {
509            Some(Vec::new())
510        } else {
511            None
512        };
513
514        let extract_row = |row: egglog_bridge::FunctionRow| {
515            if inputs.len() < n {
516                // include subsumed rows
517                let mut children: Vec<Term> = Vec::new();
518                for (value, sort) in row.vals.iter().zip(&func.schema.input) {
519                    let (_, term) = extractor
520                        .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
521                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
522                    children.push(term);
523                }
524                inputs.push(termdag.app(sym.to_owned(), children));
525                if include_output {
526                    let value = row.vals[func.schema.input.len()];
527                    let sort = &func.schema.output;
528                    let (_, term) = extractor
529                        .extract_best_with_sort(self, &mut termdag, value, sort.clone())
530                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
531                    output.as_mut().unwrap().push(term);
532                }
533                true
534            } else {
535                false
536            }
537        };
538
539        self.backend.for_each_while(func.backend_id, extract_row);
540
541        Ok((inputs, output, termdag))
542    }
543
544    /// Print up to `n` the tuples in a given function.
545    /// Print all tuples if `n` is not provided.
546    pub fn print_function(
547        &mut self,
548        sym: &str,
549        n: Option<usize>,
550        file: Option<File>,
551        mode: PrintFunctionMode,
552    ) -> Result<Option<CommandOutput>, Error> {
553        let n = match n {
554            Some(n) => {
555                log::info!("Printing up to {n} tuples of function {sym} as {mode}");
556                n
557            }
558            None => {
559                log::info!("Printing all tuples of function {sym} as {mode}");
560                usize::MAX
561            }
562        };
563
564        let (terms, outputs, termdag) = self.function_to_dag(sym, n, true)?;
565        let f = self
566            .functions
567            .get(sym)
568            // function_to_dag should have checked this
569            .unwrap();
570        let terms_and_outputs: Vec<_> = terms.into_iter().zip(outputs.unwrap()).collect();
571        let output = CommandOutput::PrintFunction(f.clone(), termdag, terms_and_outputs, mode);
572        match file {
573            Some(mut file) => {
574                log::info!("Writing output to file");
575                file.write_all(output.to_string().as_bytes())
576                    .expect("Error writing to file");
577                Ok(None)
578            }
579            None => Ok(Some(output)),
580        }
581    }
582
583    /// Print the size of a function. If no function name is provided,
584    /// print the size of all functions in "name: len" pairs.
585    pub fn print_size(&mut self, sym: Option<&str>) -> Result<CommandOutput, Error> {
586        if let Some(sym) = sym {
587            let f = self
588                .functions
589                .get(sym)
590                .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
591            let size = self.backend.table_size(f.backend_id);
592            log::info!("Function {} has size {}", sym, size);
593            Ok(CommandOutput::PrintFunctionSize(size))
594        } else {
595            // Print size of all functions
596            let mut lens = self
597                .functions
598                .iter()
599                .map(|(sym, f)| (sym.clone(), self.backend.table_size(f.backend_id)))
600                .collect::<Vec<_>>();
601
602            // Function name's alphabetical order
603            lens.sort_by_key(|(name, _)| name.clone());
604            if log_enabled!(Level::Info) {
605                for (sym, len) in &lens {
606                    log::info!("Function {} has size {}", sym, len);
607                }
608            }
609            Ok(CommandOutput::PrintAllFunctionsSize(lens))
610        }
611    }
612
613    // returns whether the egraph was updated
614    fn run_schedule(&mut self, sched: &ResolvedSchedule) -> Result<RunReport, Error> {
615        match sched {
616            ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
617            ResolvedSchedule::Repeat(_span, limit, sched) => {
618                let mut report = RunReport::default();
619                for _i in 0..*limit {
620                    let rec = self.run_schedule(sched)?;
621                    let updated = rec.updated;
622                    report.union(rec);
623                    if !updated {
624                        break;
625                    }
626                }
627                Ok(report)
628            }
629            ResolvedSchedule::Saturate(_span, sched) => {
630                let mut report = RunReport::default();
631                loop {
632                    let rec = self.run_schedule(sched)?;
633                    let updated = rec.updated;
634                    report.union(rec);
635                    if !updated {
636                        break;
637                    }
638                }
639                Ok(report)
640            }
641            ResolvedSchedule::Sequence(_span, scheds) => {
642                let mut report = RunReport::default();
643                for sched in scheds {
644                    report.union(self.run_schedule(sched)?);
645                }
646                Ok(report)
647            }
648        }
649    }
650
651    /// Extract a value to a [`TermDag`] and [`Term`] in the [`TermDag`] using the default cost model.
652    /// See also [`EGraph::extract_value_with_cost_model`] for more control.
653    pub fn extract_value(
654        &self,
655        sort: &ArcSort,
656        value: Value,
657    ) -> Result<(TermDag, Term, DefaultCost), Error> {
658        self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
659    }
660
661    /// Extract a value to a [`TermDag`] and [`Term`] in the [`TermDag`].
662    /// Note that the `TermDag` may contain a superset of the nodes in the `Term`.
663    /// See also [`EGraph::extract_value_to_string`] for convenience.
664    pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
665        &self,
666        sort: &ArcSort,
667        value: Value,
668        cost_model: CM,
669    ) -> Result<(TermDag, Term, DefaultCost), Error> {
670        let extractor =
671            Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
672        let mut termdag = TermDag::default();
673        let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
674        Ok((termdag, term, cost))
675    }
676
677    /// Extract a value to a string for printing.
678    /// See also [`EGraph::extract_value`] for more control.
679    pub fn extract_value_to_string(
680        &self,
681        sort: &ArcSort,
682        value: Value,
683    ) -> Result<(String, DefaultCost), Error> {
684        let (termdag, term, cost) = self.extract_value(sort, value)?;
685        Ok((termdag.to_string(&term), cost))
686    }
687
688    fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> Result<RunReport, Error> {
689        let mut report: RunReport = Default::default();
690
691        let GenericRunConfig { ruleset, until } = config;
692
693        if let Some(facts) = until {
694            if self.check_facts(span, facts).is_ok() {
695                log::info!(
696                    "Breaking early because of facts:\n {}!",
697                    ListDisplay(facts, "\n")
698                );
699                return Ok(report);
700            }
701        }
702
703        let subreport = self.step_rules(ruleset)?;
704        report.union(subreport);
705
706        if log_enabled!(Level::Debug) {
707            log::debug!("database size: {}", self.num_tuples());
708        }
709
710        Ok(report)
711    }
712
713    /// Runs a ruleset for an iteration.
714    ///
715    /// This applies every match it finds (under semi-naive).
716    /// See [`EGraph::step_rules_with_scheduler`] for more fine-grained control.
717    ///
718    /// This will return an error if an egglog primitive returns None in an action.
719    pub fn step_rules(&mut self, ruleset: &str) -> Result<RunReport, Error> {
720        fn collect_rule_ids(
721            ruleset: &str,
722            rulesets: &IndexMap<String, Ruleset>,
723            ids: &mut Vec<egglog_bridge::RuleId>,
724        ) {
725            match &rulesets[ruleset] {
726                Ruleset::Rules(rules) => {
727                    for (_, id) in rules.values() {
728                        ids.push(*id);
729                    }
730                }
731                Ruleset::Combined(sub_rulesets) => {
732                    for sub_ruleset in sub_rulesets {
733                        collect_rule_ids(sub_ruleset, rulesets, ids);
734                    }
735                }
736            }
737        }
738
739        let mut rule_ids = Vec::new();
740        collect_rule_ids(ruleset, &self.rulesets, &mut rule_ids);
741
742        let iteration_report = self
743            .backend
744            .run_rules(&rule_ids)
745            .map_err(|e| Error::BackendError(e.to_string()))?;
746
747        Ok(RunReport::singleton(ruleset, iteration_report))
748    }
749
750    fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result<String, Error> {
751        let core_rule =
752            rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?;
753        let (query, actions) = (&core_rule.body, &core_rule.head);
754
755        let rule_id = {
756            let mut translator = BackendRule::new(
757                self.backend.new_rule(&rule.name, self.seminaive),
758                &self.functions,
759                &self.type_info,
760            );
761            translator.query(query, false);
762            translator.actions(actions)?;
763            translator.build()
764        };
765
766        if let Some(rules) = self.rulesets.get_mut(&rule.ruleset) {
767            match rules {
768                Ruleset::Rules(rules) => {
769                    match rules.entry(rule.name.clone()) {
770                        indexmap::map::Entry::Occupied(_) => {
771                            let name = rule.name;
772                            panic!("Rule '{name}' was already present")
773                        }
774                        indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)),
775                    };
776                    Ok(rule.name)
777                }
778                Ruleset::Combined(_) => Err(Error::CombinedRulesetError(rule.ruleset, rule.span)),
779            }
780        } else {
781            Err(Error::NoSuchRuleset(rule.ruleset, rule.span))
782        }
783    }
784
785    fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
786        let (actions, _) = actions.to_core_actions(
787            &self.type_info,
788            &mut Default::default(),
789            &mut self.parser.symbol_gen,
790        )?;
791
792        let mut translator = BackendRule::new(
793            self.backend.new_rule("eval_actions", false),
794            &self.functions,
795            &self.type_info,
796        );
797        translator.actions(&actions)?;
798        let id = translator.build();
799        let result = self.backend.run_rules(&[id]);
800        self.backend.free_rule(id);
801
802        match result {
803            Ok(_) => Ok(()),
804            Err(e) => Err(Error::BackendError(e.to_string())),
805        }
806    }
807
808    /// Evaluates an expression, returns the sort of the expression and the evaluation result.
809    pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
810        let span = expr.span();
811        let command = Command::Action(Action::Expr(span.clone(), expr.clone()));
812        let resolved_commands = self.process_command(command)?;
813        assert_eq!(resolved_commands.len(), 1);
814        let resolved_command = resolved_commands.into_iter().next().unwrap();
815        let resolved_expr = match resolved_command {
816            ResolvedNCommand::CoreAction(ResolvedAction::Expr(_, resolved_expr)) => resolved_expr,
817            _ => unreachable!(),
818        };
819        let sort = resolved_expr.output_type();
820        let value = self.eval_resolved_expr(span, &resolved_expr)?;
821        Ok((sort, value))
822    }
823
824    fn eval_resolved_expr(&mut self, span: Span, expr: &ResolvedExpr) -> Result<Value, Error> {
825        let unit_id = self.backend.base_values().get_ty::<()>();
826        let unit_val = self.backend.base_values().get(());
827
828        let result: egglog_bridge::SideChannel<Value> = Default::default();
829        let result_ref = result.clone();
830        let ext_id = self
831            .backend
832            .register_external_func(make_external_func(move |_es, vals| {
833                debug_assert!(vals.len() == 1);
834                *result_ref.lock().unwrap() = Some(vals[0]);
835                Some(unit_val)
836            }));
837
838        let mut translator = BackendRule::new(
839            self.backend.new_rule("eval_resolved_expr", false),
840            &self.functions,
841            &self.type_info,
842        );
843
844        let result_var = ResolvedVar {
845            name: self.parser.symbol_gen.fresh("eval_resolved_expr"),
846            sort: expr.output_type(),
847            is_global_ref: false,
848        };
849        let actions = ResolvedActions::singleton(ResolvedAction::Let(
850            span.clone(),
851            result_var.clone(),
852            expr.clone(),
853        ));
854        let actions = actions
855            .to_core_actions(
856                &self.type_info,
857                &mut Default::default(),
858                &mut self.parser.symbol_gen,
859            )?
860            .0;
861        translator.actions(&actions)?;
862
863        let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var));
864        translator.rb.call_external_func(
865            ext_id,
866            &[arg],
867            egglog_bridge::ColumnTy::Base(unit_id),
868            || "this function will never panic".to_string(),
869        );
870
871        let id = translator.build();
872        let rule_result = self.backend.run_rules(&[id]);
873        self.backend.free_rule(id);
874        self.backend.free_external_func(ext_id);
875        let _ = rule_result.map_err(|e| {
876            Error::BackendError(format!("Failed to evaluate expression '{}': {}", expr, e))
877        })?;
878
879        let result = result.lock().unwrap().unwrap();
880        Ok(result)
881    }
882
883    fn add_combined_ruleset(&mut self, name: String, rulesets: Vec<String>) {
884        match self.rulesets.entry(name.clone()) {
885            Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
886            Entry::Vacant(e) => e.insert(Ruleset::Combined(rulesets)),
887        };
888    }
889
890    fn add_ruleset(&mut self, name: String) {
891        match self.rulesets.entry(name.clone()) {
892            Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
893            Entry::Vacant(e) => e.insert(Ruleset::Rules(Default::default())),
894        };
895    }
896
897    fn check_facts(&mut self, span: &Span, facts: &[ResolvedFact]) -> Result<(), Error> {
898        let fresh_name = self.parser.symbol_gen.fresh("check_facts");
899        let fresh_ruleset = self.parser.symbol_gen.fresh("check_facts_ruleset");
900        let rule = ast::ResolvedRule {
901            span: span.clone(),
902            head: ResolvedActions::default(),
903            body: facts.to_vec(),
904            name: fresh_name.clone(),
905            ruleset: fresh_ruleset.clone(),
906        };
907        let core_rule =
908            rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?;
909        let query = core_rule.body;
910
911        let ext_sc = egglog_bridge::SideChannel::default();
912        let ext_sc_ref = ext_sc.clone();
913        let ext_id = self
914            .backend
915            .register_external_func(make_external_func(move |_, _| {
916                *ext_sc_ref.lock().unwrap() = Some(());
917                Some(Value::new_const(0))
918            }));
919
920        let mut translator = BackendRule::new(
921            self.backend.new_rule("check_facts", false),
922            &self.functions,
923            &self.type_info,
924        );
925        translator.query(&query, true);
926        translator
927            .rb
928            .call_external_func(ext_id, &[], egglog_bridge::ColumnTy::Id, || {
929                "this function will never panic".to_string()
930            });
931        let id = translator.build();
932        let _ = self.backend.run_rules(&[id]).unwrap();
933        self.backend.free_rule(id);
934
935        self.backend.free_external_func(ext_id);
936
937        let ext_sc_val = ext_sc.lock().unwrap().take();
938        let matched = matches!(ext_sc_val, Some(()));
939
940        if !matched {
941            Err(Error::CheckError(
942                facts.iter().map(|f| f.clone().make_unresolved()).collect(),
943                span.clone(),
944            ))
945        } else {
946            Ok(())
947        }
948    }
949
950    fn run_command(&mut self, command: ResolvedNCommand) -> Result<Option<CommandOutput>, Error> {
951        match command {
952            // Sorts are already declared during typechecking
953            ResolvedNCommand::Sort(_span, name, _presort_and_args) => {
954                log::info!("Declared sort {}.", name)
955            }
956            ResolvedNCommand::Function(fdecl) => {
957                self.declare_function(&fdecl)?;
958                log::info!("Declared {} {}.", fdecl.subtype, fdecl.name)
959            }
960            ResolvedNCommand::AddRuleset(_span, name) => {
961                self.add_ruleset(name.clone());
962                log::info!("Declared ruleset {name}.");
963            }
964            ResolvedNCommand::UnstableCombinedRuleset(_span, name, others) => {
965                self.add_combined_ruleset(name.clone(), others);
966                log::info!("Declared ruleset {name}.");
967            }
968            ResolvedNCommand::NormRule { rule } => {
969                let name = rule.name.clone();
970                self.add_rule(rule)?;
971                log::info!("Declared rule {name}.")
972            }
973            ResolvedNCommand::RunSchedule(sched) => {
974                let report = self.run_schedule(&sched)?;
975                log::info!("Ran schedule {}.", sched);
976                log::info!("Report: {}", report);
977                self.overall_run_report.union(report.clone());
978                return Ok(Some(CommandOutput::RunSchedule(report)));
979            }
980            ResolvedNCommand::PrintOverallStatistics(span, file) => match file {
981                None => {
982                    log::info!("Printed overall statistics");
983                    return Ok(Some(CommandOutput::OverallStatistics(
984                        self.overall_run_report.clone(),
985                    )));
986                }
987                Some(path) => {
988                    let mut file = std::fs::File::create(&path)
989                        .map_err(|e| Error::IoError(path.clone().into(), e, span.clone()))?;
990                    log::info!("Printed overall statistics to json file {}", path);
991
992                    serde_json::to_writer(&mut file, &self.overall_run_report)
993                        .expect("error serializing to json");
994                }
995            },
996            ResolvedNCommand::Check(span, facts) => {
997                self.check_facts(&span, &facts)?;
998                log::info!("Checked fact {:?}.", facts);
999            }
1000            ResolvedNCommand::CoreAction(action) => match &action {
1001                ResolvedAction::Let(_, name, contents) => {
1002                    panic!("Globals should have been desugared away: {name} = {contents}")
1003                }
1004                _ => {
1005                    self.eval_actions(&ResolvedActions::new(vec![action.clone()]))?;
1006                }
1007            },
1008            ResolvedNCommand::Extract(span, expr, variants) => {
1009                let sort = expr.output_type();
1010
1011                let x = self.eval_resolved_expr(span.clone(), &expr)?;
1012                let n = self.eval_resolved_expr(span, &variants)?;
1013                let n: i64 = self.backend.base_values().unwrap(n);
1014
1015                let mut termdag = TermDag::default();
1016
1017                let extractor = Extractor::compute_costs_from_rootsorts(
1018                    Some(vec![sort]),
1019                    self,
1020                    TreeAdditiveCostModel::default(),
1021                );
1022                return if n == 0 {
1023                    if let Some((cost, term)) = extractor.extract_best(self, &mut termdag, x) {
1024                        // dont turn termdag into a string if we have messages disabled for performance reasons
1025                        if log_enabled!(Level::Info) {
1026                            log::info!("extracted with cost {cost}: {}", termdag.to_string(&term));
1027                        }
1028                        Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
1029                    } else {
1030                        Err(Error::ExtractError(
1031                            "Unable to find any valid extraction (likely due to subsume or delete)"
1032                                .to_string(),
1033                        ))
1034                    }
1035                } else {
1036                    if n < 0 {
1037                        panic!("Cannot extract negative number of variants");
1038                    }
1039                    let terms: Vec<Term> = extractor
1040                        .extract_variants(self, &mut termdag, x, n as usize)
1041                        .iter()
1042                        .map(|e| e.1.clone())
1043                        .collect();
1044                    if log_enabled!(Level::Info) {
1045                        let expr_str = expr.to_string();
1046                        log::info!("extracted {} variants for {expr_str}", terms.len());
1047                    }
1048                    Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
1049                };
1050            }
1051            ResolvedNCommand::Push(n) => {
1052                (0..n).for_each(|_| self.push());
1053                log::info!("Pushed {n} levels.")
1054            }
1055            ResolvedNCommand::Pop(span, n) => {
1056                for _ in 0..n {
1057                    self.pop().map_err(|err| {
1058                        if let Error::Pop(_) = err {
1059                            Error::Pop(span.clone())
1060                        } else {
1061                            err
1062                        }
1063                    })?;
1064                }
1065                log::info!("Popped {n} levels.")
1066            }
1067            ResolvedNCommand::PrintFunction(span, f, n, file, mode) => {
1068                let file = file
1069                    .map(|file| {
1070                        std::fs::File::create(&file)
1071                            .map_err(|e| Error::IoError(file.into(), e, span.clone()))
1072                    })
1073                    .transpose()?;
1074                return self.print_function(&f, n, file, mode).map_err(|e| match e {
1075                    Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1076                        Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1077                    }
1078                    // This case is currently impossible
1079                    _ => e,
1080                });
1081            }
1082            ResolvedNCommand::PrintSize(span, f) => {
1083                let res = self.print_size(f.as_deref()).map_err(|e| match e {
1084                    Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1085                        Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1086                    }
1087                    // This case is currently impossible
1088                    _ => e,
1089                })?;
1090                return Ok(Some(res));
1091            }
1092            ResolvedNCommand::Fail(span, c) => {
1093                let result = self.run_command(*c);
1094                if let Err(e) = result {
1095                    log::info!("Command failed as expected: {e}");
1096                } else {
1097                    return Err(Error::ExpectFail(span));
1098                }
1099            }
1100            ResolvedNCommand::Input {
1101                span: _,
1102                name,
1103                file,
1104            } => {
1105                self.input_file(&name, file)?;
1106            }
1107            ResolvedNCommand::Output { span, file, exprs } => {
1108                let mut filename = self.fact_directory.clone().unwrap_or_default();
1109                filename.push(file.as_str());
1110                // append to file
1111                let mut f = File::options()
1112                    .append(true)
1113                    .create(true)
1114                    .open(&filename)
1115                    .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1116
1117                let extractor = Extractor::compute_costs_from_rootsorts(
1118                    None,
1119                    self,
1120                    TreeAdditiveCostModel::default(),
1121                );
1122                let mut termdag: TermDag = Default::default();
1123
1124                use std::io::Write;
1125                for expr in exprs {
1126                    let value = self.eval_resolved_expr(span.clone(), &expr)?;
1127                    let expr_type = expr.output_type();
1128
1129                    let term = extractor
1130                        .extract_best_with_sort(self, &mut termdag, value, expr_type)
1131                        .unwrap()
1132                        .1;
1133                    writeln!(f, "{}", termdag.to_string(&term))
1134                        .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1135                }
1136
1137                log::info!("Output to '{filename:?}'.")
1138            }
1139            ResolvedNCommand::UserDefined(_span, name, exprs) => {
1140                let command = self.commands.swap_remove(&name).unwrap_or_else(|| {
1141                    panic!("Unrecognized user-defined command: {}", name);
1142                });
1143                let res = command.update(self, &exprs);
1144                self.commands.insert(name, command);
1145                return res;
1146            }
1147        };
1148
1149        Ok(None)
1150    }
1151
1152    fn input_file(&mut self, func_name: &str, file: String) -> Result<(), Error> {
1153        let function_type = self
1154            .type_info
1155            .get_func_type(func_name)
1156            .unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
1157        let func = self.functions.get_mut(func_name).unwrap();
1158
1159        let mut filename = self.fact_directory.clone().unwrap_or_default();
1160        filename.push(file.as_str());
1161
1162        // check that the function uses supported types
1163
1164        for t in &func.schema.input {
1165            match t.name() {
1166                "i64" | "f64" | "String" => {}
1167                s => panic!("Unsupported type {} for input", s),
1168            }
1169        }
1170
1171        if function_type.subtype != FunctionSubtype::Constructor {
1172            match func.schema.output.name() {
1173                "i64" | "String" | "Unit" => {}
1174                s => panic!("Unsupported type {} for input", s),
1175            }
1176        }
1177
1178        log::info!("Opening file '{:?}'...", filename);
1179        let mut f = File::open(filename).unwrap();
1180        let mut contents = String::new();
1181        f.read_to_string(&mut contents).unwrap();
1182
1183        // Can also do a row-major Vec<Value>
1184        let mut parsed_contents: Vec<Vec<Value>> = Vec::with_capacity(contents.lines().count());
1185
1186        let mut row_schema = func.schema.input.clone();
1187        if function_type.subtype == FunctionSubtype::Custom {
1188            row_schema.push(func.schema.output.clone());
1189        }
1190
1191        log::debug!("{:?}", row_schema);
1192
1193        let unit_val = self.backend.base_values().get(());
1194
1195        for line in contents.lines() {
1196            let mut it = line.split('\t').map(|s| s.trim());
1197
1198            let mut row: Vec<Value> = Vec::with_capacity(row_schema.len());
1199
1200            for sort in row_schema.iter() {
1201                if let Some(raw) = it.next() {
1202                    let val = match sort.name() {
1203                        "i64" => {
1204                            if let Ok(i) = raw.parse::<i64>() {
1205                                self.backend.base_values().get(i)
1206                            } else {
1207                                return Err(Error::InputFileFormatError(file));
1208                            }
1209                        }
1210                        "f64" => {
1211                            if let Ok(f) = raw.parse::<f64>() {
1212                                self.backend
1213                                    .base_values()
1214                                    .get::<F>(core_relations::Boxed::new(f.into()))
1215                            } else {
1216                                return Err(Error::InputFileFormatError(file));
1217                            }
1218                        }
1219                        "String" => self.backend.base_values().get::<S>(raw.to_string().into()),
1220                        "Unit" => unit_val,
1221                        _ => panic!("Unreachable"),
1222                    };
1223                    row.push(val);
1224                } else {
1225                    break;
1226                }
1227            }
1228
1229            if row.is_empty() {
1230                continue;
1231            }
1232
1233            if row.len() != row_schema.len() || it.next().is_some() {
1234                return Err(Error::InputFileFormatError(file));
1235            }
1236
1237            parsed_contents.push(row);
1238        }
1239
1240        log::debug!("Successfully loaded file.");
1241
1242        let num_facts = parsed_contents.len();
1243
1244        let mut table_action = egglog_bridge::TableAction::new(&self.backend, func.backend_id);
1245
1246        if function_type.subtype != FunctionSubtype::Constructor {
1247            self.backend.with_execution_state(|es| {
1248                for row in parsed_contents.iter() {
1249                    table_action.insert(es, row.iter().copied());
1250                }
1251                Some(unit_val)
1252            });
1253        } else {
1254            self.backend.with_execution_state(|es| {
1255                for row in parsed_contents.iter() {
1256                    table_action.lookup(es, row);
1257                }
1258                Some(unit_val)
1259            });
1260        }
1261
1262        self.backend.flush_updates();
1263
1264        log::info!("Read {num_facts} facts into {func_name} from '{file}'.");
1265        Ok(())
1266    }
1267
1268    fn process_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
1269        let mut program = self.resolve_command(command)?;
1270
1271        program = remove_globals::remove_globals(program, &mut self.parser.symbol_gen);
1272        for command in &program {
1273            self.names.check_shadowing(command)?;
1274        }
1275
1276        Ok(program)
1277    }
1278
1279    fn resolve_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
1280        let program = desugar::desugar_program(vec![command], &mut self.parser, self.seminaive)?;
1281        Ok(self.typecheck_program(&program)?)
1282    }
1283
1284    /// Run a program, represented as an AST.
1285    /// Return a list of messages.
1286    pub fn run_program(&mut self, program: Vec<Command>) -> Result<Vec<CommandOutput>, Error> {
1287        let mut outputs = Vec::new();
1288        for command in program {
1289            // Important to process each command individually
1290            // because push and pop create new scopes
1291            for processed in self.process_command(command)? {
1292                let result = self.run_command(processed)?;
1293                if let Some(output) = result {
1294                    outputs.push(output);
1295                }
1296            }
1297        }
1298
1299        Ok(outputs)
1300    }
1301
1302    pub fn resugar_program(
1303        &mut self,
1304        filename: Option<String>,
1305        input: &str,
1306    ) -> Result<Vec<String>, Error> {
1307        let parsed = self.parser.get_program_from_string(filename, input)?;
1308        let mut outputs = Vec::new();
1309        for command in parsed {
1310            for processed in self.resolve_command(command)? {
1311                // When re-suggaring, we still need to run scope-related commands (Push/Pop) to make
1312                // the program well-scoped.
1313                if let GenericNCommand::Push(..) | GenericNCommand::Pop(..) = &processed {
1314                    self.run_command(processed.clone())?;
1315                }
1316                outputs.push(processed.to_command().to_string());
1317            }
1318        }
1319        Ok(outputs)
1320    }
1321
1322    /// Takes a source program `input`, parses it, runs it, and returns a list of messages.
1323    ///
1324    /// `filename` is an optional argument to indicate the source of
1325    /// the program for error reporting. If `filename` is `None`,
1326    /// a default name will be used.
1327    pub fn parse_and_run_program(
1328        &mut self,
1329        filename: Option<String>,
1330        input: &str,
1331    ) -> Result<Vec<CommandOutput>, Error> {
1332        let parsed = self.parser.get_program_from_string(filename, input)?;
1333        self.run_program(parsed)
1334    }
1335
1336    /// Get the number of tuples in the database.
1337    ///
1338    pub fn num_tuples(&self) -> usize {
1339        self.functions
1340            .values()
1341            .map(|f| self.backend.table_size(f.backend_id))
1342            .sum()
1343    }
1344
1345    /// Returns a sort based on the type.
1346    pub fn get_sort<S: Sort>(&self) -> Arc<S> {
1347        self.type_info.get_sort()
1348    }
1349
1350    /// Returns a sort that satisfies the type and predicate.
1351    pub fn get_sort_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
1352        self.type_info.get_sort_by(f)
1353    }
1354
1355    /// Returns all sorts based on the type.
1356    pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
1357        self.type_info.get_sorts()
1358    }
1359
1360    /// Returns all sorts that satisfy the type and predicate.
1361    pub fn get_sorts_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
1362        self.type_info.get_sorts_by(f)
1363    }
1364
1365    /// Returns a sort based on the predicate.
1366    pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
1367        self.type_info.get_arcsort_by(f)
1368    }
1369
1370    /// Returns all sorts that satisfy the predicate.
1371    pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
1372        self.type_info.get_arcsorts_by(f)
1373    }
1374
1375    /// Returns the sort with the given name if it exists.
1376    pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
1377        self.type_info.get_sort_by_name(sym)
1378    }
1379
1380    /// Gets the overall run report and returns it.
1381    pub fn get_overall_run_report(&self) -> &RunReport {
1382        &self.overall_run_report
1383    }
1384
1385    /// Convert from an egglog value to a Rust type.
1386    pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
1387        self.backend.base_values().unwrap::<T>(x)
1388    }
1389
1390    /// Convert from a Rust type to an egglog value.
1391    pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
1392        self.backend.base_values().get::<T>(x)
1393    }
1394
1395    /// Convert from an egglog value to a reference of a Rust container type.
1396    ///
1397    /// Returns `None` if the value cannot be converted to the requested container type.
1398    ///
1399    /// Warning: The return type of this function may contain lock guards.
1400    /// Attempts to modify the contents of the containers database may deadlock if the given guard has not been dropped.
1401    pub fn value_to_container<T: ContainerValue>(
1402        &self,
1403        x: Value,
1404    ) -> Option<impl Deref<Target = T>> {
1405        self.backend.container_values().get_val::<T>(x)
1406    }
1407
1408    /// Convert from a Rust container type to an egglog value.
1409    pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
1410        self.backend.with_execution_state(|state| {
1411            self.backend.container_values().register_val::<T>(x, state)
1412        })
1413    }
1414
1415    /// Get the size of a function in the e-graph.
1416    ///
1417    /// `panics` if the function does not exist.
1418    pub fn get_size(&self, func: &str) -> usize {
1419        let function_id = self.functions.get(func).unwrap().backend_id;
1420        self.backend.table_size(function_id)
1421    }
1422
1423    /// Lookup a tuple in afunction in the e-graph.
1424    ///
1425    /// Returns `None` if the tuple does not exist.
1426    /// `panics` if the function does not exist.
1427    pub fn lookup_function(&self, name: &str, key: &[Value]) -> Option<Value> {
1428        let func = self.functions.get(name).unwrap().backend_id;
1429        self.backend.lookup_id(func, key)
1430    }
1431
1432    /// Get a function by name.
1433    ///
1434    /// Returns `None` if the function does not exist.
1435    pub fn get_function(&self, name: &str) -> Option<&Function> {
1436        self.functions.get(name)
1437    }
1438
1439    pub fn set_report_level(&mut self, level: ReportLevel) {
1440        self.backend.set_report_level(level);
1441    }
1442
1443    /// A basic method for dumping the state of the database to `log::info!`.
1444    ///
1445    /// For large tables, this is unlikely to give particularly useful output.
1446    pub fn dump_debug_info(&self) {
1447        self.backend.dump_debug_info();
1448    }
1449
1450    /// Get the canonical representation for `val` based on type.
1451    pub fn get_canonical_value(&self, val: Value, sort: &ArcSort) -> Value {
1452        self.backend
1453            .get_canon_repr(val, sort.column_ty(&self.backend))
1454    }
1455}
1456
1457struct BackendRule<'a> {
1458    rb: egglog_bridge::RuleBuilder<'a>,
1459    entries: HashMap<core::ResolvedAtomTerm, QueryEntry>,
1460    functions: &'a IndexMap<String, Function>,
1461    type_info: &'a TypeInfo,
1462}
1463
1464impl<'a> BackendRule<'a> {
1465    fn new(
1466        rb: egglog_bridge::RuleBuilder<'a>,
1467        functions: &'a IndexMap<String, Function>,
1468        type_info: &'a TypeInfo,
1469    ) -> BackendRule<'a> {
1470        BackendRule {
1471            rb,
1472            functions,
1473            type_info,
1474            entries: Default::default(),
1475        }
1476    }
1477
1478    fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry {
1479        self.entries
1480            .entry(x.clone())
1481            .or_insert_with(|| match x {
1482                core::GenericAtomTerm::Var(_, v) => self
1483                    .rb
1484                    .new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name),
1485                core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l),
1486                core::GenericAtomTerm::Global(..) => {
1487                    panic!("Globals should have been desugared")
1488                }
1489            })
1490            .clone()
1491    }
1492
1493    fn func(&self, f: &typechecking::FuncType) -> egglog_bridge::FunctionId {
1494        self.functions[&f.name].backend_id
1495    }
1496
1497    fn prim(
1498        &mut self,
1499        prim: &core::SpecializedPrimitive,
1500        args: &[core::ResolvedAtomTerm],
1501    ) -> (ExternalFunctionId, Vec<QueryEntry>, ColumnTy) {
1502        let mut qe_args = self.args(args);
1503
1504        if prim.primitive.0.name() == "unstable-fn" {
1505            let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else {
1506                panic!("expected string literal after `unstable-fn`")
1507            };
1508            let id = if let Some(f) = self.type_info.get_func_type(name) {
1509                ResolvedFunctionId::Lookup(egglog_bridge::TableAction::new(
1510                    self.rb.egraph(),
1511                    self.func(f),
1512                ))
1513            } else if let Some(possible) = self.type_info.get_prims(name) {
1514                let mut ps: Vec<_> = possible.iter().collect();
1515                ps.retain(|p| {
1516                    self.type_info
1517                        .get_sorts::<FunctionSort>()
1518                        .into_iter()
1519                        .any(|f| {
1520                            let types: Vec<_> = prim
1521                                .input
1522                                .iter()
1523                                .skip(1)
1524                                .chain(f.inputs())
1525                                .chain([&f.output()])
1526                                .cloned()
1527                                .collect();
1528                            p.accept(&types, self.type_info)
1529                        })
1530                });
1531                assert!(ps.len() == 1, "options for {name}: {ps:?}");
1532                ResolvedFunctionId::Prim(ps.into_iter().next().unwrap().1)
1533            } else {
1534                panic!("no callable for {name}");
1535            };
1536            let do_rebuild = prim
1537                .input
1538                .iter()
1539                .skip(1)
1540                .map(|s| s.is_eq_sort() || s.is_eq_container_sort())
1541                .collect();
1542
1543            qe_args[0] = self.rb.egraph().base_value_constant(ResolvedFunction {
1544                id,
1545                do_rebuild,
1546                name: name.clone(),
1547            });
1548        }
1549
1550        (
1551            prim.primitive.1,
1552            qe_args,
1553            prim.output.column_ty(self.rb.egraph()),
1554        )
1555    }
1556
1557    fn args<'b>(
1558        &mut self,
1559        args: impl IntoIterator<Item = &'b core::ResolvedAtomTerm>,
1560    ) -> Vec<QueryEntry> {
1561        args.into_iter().map(|x| self.entry(x)).collect()
1562    }
1563
1564    fn query(&mut self, query: &core::Query<ResolvedCall, ResolvedVar>, include_subsumed: bool) {
1565        for atom in &query.atoms {
1566            match &atom.head {
1567                ResolvedCall::Func(f) => {
1568                    let f = self.func(f);
1569                    let args = self.args(&atom.args);
1570                    let is_subsumed = match include_subsumed {
1571                        true => None,
1572                        false => Some(false),
1573                    };
1574                    self.rb.query_table(f, &args, is_subsumed).unwrap();
1575                }
1576                ResolvedCall::Primitive(p) => {
1577                    let (p, args, ty) = self.prim(p, &atom.args);
1578                    self.rb.query_prim(p, &args, ty).unwrap()
1579                }
1580            }
1581        }
1582    }
1583
1584    fn actions(&mut self, actions: &core::ResolvedCoreActions) -> Result<(), Error> {
1585        for action in &actions.0 {
1586            match action {
1587                core::GenericCoreAction::Let(span, v, f, args) => {
1588                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
1589                    let y = match f {
1590                        ResolvedCall::Func(f) => {
1591                            let name = f.name.clone();
1592                            let f = self.func(f);
1593                            let args = self.args(args);
1594                            let span = span.clone();
1595                            self.rb.lookup(f, &args, move || {
1596                                format!("{span}: lookup of function {name} failed")
1597                            })
1598                        }
1599                        ResolvedCall::Primitive(p) => {
1600                            let name = p.primitive.0.name().to_owned();
1601                            let (p, args, ty) = self.prim(p, args);
1602                            let span = span.clone();
1603                            self.rb.call_external_func(p, &args, ty, move || {
1604                                format!("{span}: call of primitive {name} failed")
1605                            })
1606                        }
1607                    };
1608                    self.entries.insert(v, y.into());
1609                }
1610                core::GenericCoreAction::LetAtomTerm(span, v, x) => {
1611                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
1612                    let x = self.entry(x);
1613                    self.entries.insert(v, x);
1614                }
1615                core::GenericCoreAction::Set(_, f, xs, y) => match f {
1616                    ResolvedCall::Primitive(..) => panic!("runtime primitive set!"),
1617                    ResolvedCall::Func(f) => {
1618                        let f = self.func(f);
1619                        let args = self.args(xs.iter().chain([y]));
1620                        self.rb.set(f, &args)
1621                    }
1622                },
1623                core::GenericCoreAction::Change(span, change, f, args) => match f {
1624                    ResolvedCall::Primitive(..) => panic!("runtime primitive change!"),
1625                    ResolvedCall::Func(f) => {
1626                        let name = f.name.clone();
1627                        let can_subsume = self.functions[&f.name].can_subsume;
1628                        let f = self.func(f);
1629                        let args = self.args(args);
1630                        match change {
1631                            Change::Delete => self.rb.remove(f, &args),
1632                            Change::Subsume if can_subsume => self.rb.subsume(f, &args),
1633                            Change::Subsume => {
1634                                return Err(Error::SubsumeMergeError(name, span.clone()));
1635                            }
1636                        }
1637                    }
1638                },
1639                core::GenericCoreAction::Union(_, x, y) => {
1640                    let x = self.entry(x);
1641                    let y = self.entry(y);
1642                    self.rb.union(x, y)
1643                }
1644                core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()),
1645            }
1646        }
1647        Ok(())
1648    }
1649
1650    fn build(self) -> egglog_bridge::RuleId {
1651        self.rb.build()
1652    }
1653}
1654
1655fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry {
1656    match l {
1657        Literal::Int(x) => egraph.base_value_constant::<i64>(*x),
1658        Literal::Float(x) => egraph.base_value_constant::<sort::F>(x.into()),
1659        Literal::String(x) => egraph.base_value_constant::<sort::S>(sort::S::new(x.clone())),
1660        Literal::Bool(x) => egraph.base_value_constant::<bool>(*x),
1661        Literal::Unit => egraph.base_value_constant::<()>(()),
1662    }
1663}
1664
1665fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value {
1666    match l {
1667        Literal::Int(x) => egraph.base_values().get::<i64>(*x),
1668        Literal::Float(x) => egraph.base_values().get::<sort::F>(x.into()),
1669        Literal::String(x) => egraph.base_values().get::<sort::S>(sort::S::new(x.clone())),
1670        Literal::Bool(x) => egraph.base_values().get::<bool>(*x),
1671        Literal::Unit => egraph.base_values().get::<()>(()),
1672    }
1673}
1674
1675#[derive(Debug, Error)]
1676pub enum Error {
1677    #[error(transparent)]
1678    ParseError(#[from] ParseError),
1679    #[error(transparent)]
1680    NotFoundError(#[from] NotFoundError),
1681    #[error(transparent)]
1682    TypeError(#[from] TypeError),
1683    #[error("Errors:\n{}", ListDisplay(.0, "\n"))]
1684    TypeErrors(Vec<TypeError>),
1685    #[error("{}\nCheck failed: \n{}", .1, ListDisplay(.0, "\n"))]
1686    CheckError(Vec<Fact>, Span),
1687    #[error("{1}\nNo such ruleset: {0}")]
1688    NoSuchRuleset(String, Span),
1689    #[error(
1690        "{1}\nAttempted to add a rule to combined ruleset {0}. Combined rulesets may only depend on other rulesets."
1691    )]
1692    CombinedRulesetError(String, Span),
1693    #[error("{0}")]
1694    BackendError(String),
1695    #[error("{0}\nTried to pop too much")]
1696    Pop(Span),
1697    #[error("{0}\nCommand should have failed.")]
1698    ExpectFail(Span),
1699    #[error("{2}\nIO error: {0}: {1}")]
1700    IoError(PathBuf, std::io::Error, Span),
1701    #[error("{1}\nCannot subsume function with merge: {0}")]
1702    SubsumeMergeError(String, Span),
1703    #[error("extraction failure: {:?}", .0)]
1704    ExtractError(String),
1705    #[error("{1}\n{2}\nShadowing is not allowed, but found {0}")]
1706    Shadowing(String, Span, Span),
1707    #[error("{1}\nCommand already exists: {0}")]
1708    CommandAlreadyExists(String, Span),
1709    #[error("Incorrect format in file '{0}'.")]
1710    InputFileFormatError(String),
1711}
1712
#[cfg(test)]
mod tests {
    use crate::constraint::SimpleTypeConstraint;
    use crate::sort::*;
    use crate::*;

    /// Test-only primitive: the inner product of two vectors of `i64`.
    #[derive(Clone)]
    struct InnerProduct {
        /// Sort used for both vector arguments.
        vec: ArcSort,
    }

    impl Primitive for InnerProduct {
        fn name(&self) -> &str {
            "inner-product"
        }

        fn get_type_constraints(&self, span: &Span) -> Box<dyn crate::constraint::TypeConstraint> {
            // Two vector sorts followed by `i64`.
            let sorts = vec![self.vec.clone(), self.vec.clone(), I64Sort.to_arcsort()];
            SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
        }

        fn apply(&self, exec_state: &mut ExecutionState<'_>, args: &[Value]) -> Option<Value> {
            let containers = exec_state.container_values();
            let lhs = containers.get_val::<VecContainer>(args[0]).unwrap();
            let rhs = containers.get_val::<VecContainer>(args[1]).unwrap();
            assert_eq!(lhs.data.len(), rhs.data.len());

            let base_values = exec_state.base_values();
            let dot: i64 = lhs
                .data
                .iter()
                .zip(rhs.data.iter())
                .map(|(a, b)| base_values.unwrap::<i64>(*a) * base_values.unwrap::<i64>(*b))
                .sum();
            Some(base_values.get::<i64>(dot))
        }
    }

    /// Registers `InnerProduct` on an e-graph and checks its result from an
    /// egglog program.
    #[test]
    fn test_user_defined_primitive() {
        let mut egraph = EGraph::default();
        egraph
            .parse_and_run_program(None, "(sort IntVec (Vec i64))")
            .unwrap();

        // Pick the sort backed by `VecContainer` whose first inner sort is `i64`.
        let int_vec_sort = egraph.get_arcsort_by(|sort| {
            let holds_vec = sort.value_type() == Some(std::any::TypeId::of::<VecContainer>());
            holds_vec && sort.inner_sorts()[0].name() == I64Sort.name()
        });

        let inner_product = InnerProduct { vec: int_vec_sort };
        egraph.add_primitive(inner_product);

        egraph
            .parse_and_run_program(
                None,
                "
                (let a (vec-of 1 2 3 4 5 6))
                (let b (vec-of 6 5 4 3 2 1))
                (check (= (inner-product a b) 56))
            ",
            )
            .unwrap();
    }

    /// Checks that an `EGraph` is both `Send` and `Sync`; the trait bounds
    /// on the helpers make this a compile-time requirement.
    #[test]
    fn test_egraph_send_sync() {
        fn is_send<T: Send>(_value: &T) -> bool {
            true
        }
        fn is_sync<T: Sync>(_value: &T) -> bool {
            true
        }

        let egraph = EGraph::default();
        assert!(is_send(&egraph));
        assert!(is_sync(&egraph));
    }

    /// Returns a clone of the function registered under `name`, panicking
    /// if there is none.
    fn get_function(egraph: &EGraph, name: &str) -> Function {
        let function = egraph.functions.get(name).unwrap();
        function.clone()
    }

    /// Returns the first column of the last row visited in the backend
    /// table of function `name`, panicking if no row is visited.
    fn get_value(egraph: &EGraph, name: &str) -> Value {
        let table = get_function(egraph, name).backend_id;
        let mut last_seen = None;
        egraph
            .backend
            .for_each(table, |row| last_seen = Some(row.vals[0]));
        last_seen.unwrap()
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_arg() {
        // A subsumed term must stay unextractable even when a union followed
        // by a rebuild changes the value of one of its arguments.
        let setup = r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (constructor cheap-1 () Math)
                ; we make the container cheap so that it will be extracted if possible, but then we mark it as subsumed
                ; so the (exp) expr should be extracted instead
                (let res (container (cheap)))
                (union res (exp))
                (cheap)
                (cheap-1)
                (subsume (container (cheap)))
                "#;
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(None, setup).unwrap();

        // (cheap) and (cheap-1) have not been unioned yet, so their values differ.
        let cheap_before = get_value(&egraph, "cheap");
        let cheap_1_before = get_value(&egraph, "cheap-1");
        assert_ne!(cheap_before, cheap_1_before);

        // Union the two terms.
        let union_terms = r#"
                (union (cheap-1) (cheap))
                "#;
        egraph.parse_and_run_program(None, union_terms).unwrap();

        // Both terms now share one value, and at least one of them moved
        // away from the value it had before the union.
        let cheap_after = get_value(&egraph, "cheap");
        let cheap_1_after = get_value(&egraph, "cheap-1");
        assert_eq!(cheap_after, cheap_1_after);
        assert!(cheap_after != cheap_before || cheap_1_after != cheap_1_before);

        // Extraction must still skip the subsumed term, even though the
        // value of its argument has changed.
        let extract_res = r#"
                (extract res)
                "#;
        let outputs = egraph.parse_and_run_program(None, extract_res).unwrap();
        assert_eq!(outputs[0].to_string(), "(exp)\n");
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_self() {
        // A subsumed term must stay unextractable even when a union followed
        // by a rebuild changes its own output value.
        let setup = r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (exp)
                (let x (cheap))
                (subsume (cheap))
                "#;
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(None, setup).unwrap();

        let cheap_before = get_value(&egraph, "cheap");

        // Union (exp) with the subsumed term.
        let union_terms = r#"
                (union (exp) x)
                "#;
        egraph.parse_and_run_program(None, union_terms).unwrap();

        // The value recorded for (cheap) has changed as a result.
        let cheap_after = get_value(&egraph, "cheap");
        assert_ne!(cheap_after, cheap_before);

        // Extraction must still respect the subsumption, even though the
        // term now has a different value.
        let extract_x = r#"
                (extract x)
                "#;
        let res = egraph.parse_and_run_program(None, extract_x).unwrap();
        assert_eq!(res[0].to_string(), "(exp)\n");
    }
}