egglog/
prelude.rs

1//! This module makes it easier to use `egglog` from Rust.
2//! ```
3//! use egglog::prelude::*;
4//! ```
5//!
6//! Common entry points:
7//!
8//! - [`rule`] / [`rust_rule`] — add rules whose RHS is egglog code or a
9//!   Rust closure (the closure receives an
10//!   [`crate::WriteState`]).
11//! - [`query`] — run a one-shot query and read out matches.
12//! - [`BaseSort`] / [`ContainerSort`] — declare custom sort types.
13//! - [`crate::PurePrim`] / [`crate::WritePrim`] / [`crate::ReadPrim`] /
14//!   [`crate::FullPrim`] — register custom primitives via the matching
15//!   `EGraph::add_*_primitive` method.
16//! - The [`add_primitive!`] / [`add_primitive_with_validator!`] /
17//!   [`add_literal_prim!`] macros (re-exported via `egglog::*`) cover
18//!   the simple "pure native function" case.
19
20use crate::*;
21use std::any::{Any, TypeId};
22
23// Re-exports in `prelude` for convenience.
24pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
25pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
26pub use egglog::{CommandMacro, CommandMacroRegistry};
27pub use egglog::{Core, FullState, PureState, Read, ReadState, Write, WriteState};
28pub use egglog::{EGraph, span};
29pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
30
31/// Trait for types that can be converted to/from Literal for use in validated primitives.
32/// This enables automatic validator generation for literal primitives.
33pub trait LiteralConvertible: Sized {
34    fn to_literal(self) -> egglog_ast::generic_ast::Literal;
35    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self>;
36}
37
38impl LiteralConvertible for i64 {
39    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
40        egglog_ast::generic_ast::Literal::Int(self)
41    }
42    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
43        match lit {
44            egglog_ast::generic_ast::Literal::Int(i) => Some(*i),
45            _ => None,
46        }
47    }
48}
49
50impl LiteralConvertible for bool {
51    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
52        egglog_ast::generic_ast::Literal::Bool(self)
53    }
54    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
55        match lit {
56            egglog_ast::generic_ast::Literal::Bool(b) => Some(*b),
57            _ => None,
58        }
59    }
60}
61
62impl LiteralConvertible for ordered_float::OrderedFloat<f64> {
63    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
64        egglog_ast::generic_ast::Literal::Float(self)
65    }
66    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
67        match lit {
68            egglog_ast::generic_ast::Literal::Float(f) => Some(*f),
69            _ => None,
70        }
71    }
72}
73
74impl LiteralConvertible for egglog::sort::F {
75    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
76        egglog_ast::generic_ast::Literal::Float(self.0)
77    }
78    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
79        match lit {
80            egglog_ast::generic_ast::Literal::Float(f) => Some(egglog::sort::F::from(*f)),
81            _ => None,
82        }
83    }
84}
85
86impl LiteralConvertible for egglog::sort::S {
87    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
88        egglog_ast::generic_ast::Literal::String(self.0)
89    }
90    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
91        match lit {
92            egglog_ast::generic_ast::Literal::String(s) => Some(egglog::sort::S::new(s.clone())),
93            _ => None,
94        }
95    }
96}
97
98impl LiteralConvertible for () {
99    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
100        egglog_ast::generic_ast::Literal::Unit
101    }
102    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
103        match lit {
104            egglog_ast::generic_ast::Literal::Unit => Some(()),
105            _ => None,
106        }
107    }
108}
109
110pub mod exprs {
111    use super::*;
112
113    /// Creates a variable expression.
114    pub fn var(name: &str) -> Expr {
115        Expr::Var(span!(), name.to_owned())
116    }
117
118    /// Creates an integer literal expression.
119    pub fn int(value: i64) -> Expr {
120        Expr::Lit(span!(), Literal::Int(value))
121    }
122
123    /// Creates a float literal expression.
124    pub fn float(value: f64) -> Expr {
125        Expr::Lit(span!(), Literal::Float(value.into()))
126    }
127
128    /// Creates a string literal expression.
129    pub fn string(value: &str) -> Expr {
130        Expr::Lit(span!(), Literal::String(value.to_owned()))
131    }
132
133    /// Creates a unit literal expression.
134    pub fn unit() -> Expr {
135        Expr::Lit(span!(), Literal::Unit)
136    }
137
138    /// Creates a boolean literal expression.
139    pub fn bool(value: bool) -> Expr {
140        Expr::Lit(span!(), Literal::Bool(value))
141    }
142
143    /// Creates a function call expression.
144    pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
145        Expr::Call(span!(), f.to_owned(), xs)
146    }
147}
148
149/// Create a new ruleset.
150pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
151    egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
152}
153
154/// Run one iteration of a ruleset.
155pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
156    egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
157        span!(),
158        RunConfig {
159            ruleset: ruleset.to_owned(),
160            until: None,
161        },
162    ))])
163}
164
/// Turn a sort name into an `ArcSort`.
///
/// The built-in names `BigInt`, `BigRat`, `bool`, `f64`, `i64`, `String` and
/// `Unit` expand to the matching base sort's `to_arcsort()`. Any other
/// expression is passed through unchanged, so it must already be an
/// `ArcSort`. The expansion names the sort types and `to_arcsort`
/// unqualified, so they must be in scope at the call site (for example via
/// `use egglog::prelude::*`).
#[macro_export]
macro_rules! sort {
    (i64) => {
        I64Sort.to_arcsort()
    };
    (f64) => {
        F64Sort.to_arcsort()
    };
    (bool) => {
        BoolSort.to_arcsort()
    };
    (String) => {
        StringSort.to_arcsort()
    };
    (Unit) => {
        UnitSort.to_arcsort()
    };
    (BigInt) => {
        BigIntSort.to_arcsort()
    };
    (BigRat) => {
        BigRatSort.to_arcsort()
    };
    // Must stay last: an `expr` fragment would also match every name above.
    ($other:expr) => {
        $other
    };
}
192
/// Build a `&[(&str, ArcSort)]` list of typed query/rule variables.
///
/// Each entry is written `name: Sort`, where `Sort` is anything accepted by
/// [`sort!`], e.g. `vars![x: i64, y: i64]`. A trailing comma is allowed.
#[macro_export]
macro_rules! vars {
    [$($var:ident : $sort:tt),* $(,)?] => {
        &[$((stringify!($var), sort!($sort))),*]
    };
}
199
/// Build an egglog [`Expr`] from s-expression syntax.
///
/// - `(unquote e)` splices in the Rust expression `e`, which must already be an `Expr`.
/// - `()` is the unit literal; `true` and `false` are boolean literals.
/// - `(f a b ...)` is a call of `f`; each argument is converted recursively.
/// - Any other literal is passed to `exprs::int`. Only integer literals type-check, so
///   write floats and strings with `unquote`, e.g. `(unquote exprs::string("s"))`.
/// - Any other single token becomes a variable named by its text.
///
/// Call arguments are matched one token tree at a time. A negative literal such as
/// `-1` inside a call is therefore split into `-` and `1`; write
/// `(unquote exprs::int(-1))` instead.
#[macro_export]
macro_rules! expr {
    ((unquote $unquoted:expr)) => { $unquoted };
    // These literal arms must precede `$value:literal`. That arm would send
    // `true`/`false` to `exprs::int`, and the `tt` fallback would turn `()`
    // into a variable named "()".
    (()) => { exprs::unit() };
    (true) => { exprs::bool(true) };
    (false) => { exprs::bool(false) };
    (($func:tt $($arg:tt)*)) => { exprs::call(stringify!($func), vec![$(expr!($arg)),*]) };
    ($value:literal) => { exprs::int($value) };
    ($quoted:tt) => { exprs::var(stringify!($quoted)) };
}
207
/// Build a single egglog [`Fact`] from s-expression syntax.
///
/// `(= a b)` becomes `Fact::Eq` over the converted operands. Anything else
/// becomes a plain expression fact. Operands are converted with [`expr!`].
#[macro_export]
macro_rules! fact {
    ((= $($side:tt)*)) => {
        Fact::Eq(span!(), $(expr!($side)),*)
    };
    ($other:tt) => {
        Fact::Fact(expr!($other))
    };
}
213
/// Build a [`Facts`] list. Each top-level token tree is converted with [`fact!`].
#[macro_export]
macro_rules! facts {
    ($($fact:tt)*) => {
        Facts(vec![$(fact!($fact)),*])
    };
}
218
/// Build a single egglog [`Action`] from s-expression syntax.
///
/// Supported forms (sub-expressions are converted with [`expr!`]):
/// - `(let name e)` — `Action::Let` binding `name`;
/// - `(set (f args...) v)` — `Action::Set` on table `f`;
/// - `(delete (f args...))` / `(subsume (f args...))` — `Action::Change`;
/// - `(union a b)` — `Action::Union`;
/// - `(panic "msg")` — `Action::Panic` with a string-literal message;
/// - anything else — `Action::Expr`.
///
/// `Change` is referenced through `$crate` because the prelude does not
/// re-export it. `delete`/`subsume` therefore also work for callers that only
/// glob-import `egglog::prelude`.
#[macro_export]
macro_rules! action {
    ((let $name:ident $value:tt)) => {
        Action::Let(span!(), String::from(stringify!($name)), expr!($value))
    };
    ((set ($f:ident $($x:tt)*) $value:tt)) => {
        Action::Set(span!(), String::from(stringify!($f)), vec![$(expr!($x)),*], expr!($value))
    };
    ((delete ($f:ident $($x:tt)*))) => {
        Action::Change(
            span!(),
            $crate::ast::Change::Delete,
            String::from(stringify!($f)),
            vec![$(expr!($x)),*],
        )
    };
    ((subsume ($f:ident $($x:tt)*))) => {
        Action::Change(
            span!(),
            $crate::ast::Change::Subsume,
            String::from(stringify!($f)),
            vec![$(expr!($x)),*],
        )
    };
    ((union $x:tt $y:tt)) => {
        Action::Union(span!(), expr!($x), expr!($y))
    };
    ((panic $message:literal)) => {
        Action::Panic(span!(), $message.to_owned())
    };
    ($x:tt) => {
        Action::Expr(span!(), expr!($x))
    };
}
243
/// Build a `GenericActions` list. Each top-level token tree is converted with
/// [`action!`].
#[macro_export]
macro_rules! actions {
    ($($action:tt)*) => {
        GenericActions(vec![$(action!($action)),*])
    };
}
248
249/// Add a rule to the e-graph whose right-hand side is made up of actions.
250/// ```
251/// use egglog::prelude::*;
252///
253/// let mut egraph = EGraph::default();
254/// egraph.parse_and_run_program(
255///     None,
256///     "
257/// (function fib (i64) i64 :no-merge)
258/// (set (fib 0) 0)
259/// (set (fib 1) 1)
260/// (rule (
261///     (= f0 (fib x))
262///     (= f1 (fib (+ x 1)))
263/// ) (
264///     (set (fib (+ x 2)) (+ f0 f1))
265/// ))
266/// (run 10)
267///     ",
268/// )?;
269///
270/// let big_number = 20;
271///
272/// // check that `(fib 20)` is not in the e-graph
273/// let results = query(
274///     &mut egraph,
275///     vars![f: i64],
276///     facts![(= (fib (unquote exprs::int(big_number))) f)],
277/// )?;
278///
279/// assert!(results.iter().next().is_none());
280///
281/// let ruleset = "custom_ruleset";
282/// add_ruleset(&mut egraph, ruleset)?;
283///
284/// // add the rule from `build_test_database` to the egraph
285/// rule(
286///     &mut egraph,
287///     ruleset,
288///     facts![
289///         (= f0 (fib x))
290///         (= f1 (fib (+ x 1)))
291///     ],
292///     actions![
293///         (set (fib (+ x 2)) (+ f0 f1))
294///     ],
295/// )?;
296///
297/// // run that rule 10 times
298/// for _ in 0..10 {
299///     run_ruleset(&mut egraph, ruleset)?;
300/// }
301///
302/// // check that `(fib 20)` is now in the e-graph
303/// let results = query(
304///     &mut egraph,
305///     vars![f: i64],
306///     facts![(= (fib (unquote exprs::int(big_number))) f)],
307/// )?;
308///
309/// let y = egraph.base_to_value::<i64>(6765);
310/// let results: Vec<_> = results.iter().collect();
311/// assert_eq!(results, [[y]]);
312///
313/// # Ok::<(), egglog::Error>(())
314/// ```
315pub fn rule(
316    egraph: &mut EGraph,
317    ruleset: &str,
318    facts: Facts<String, String>,
319    actions: Actions,
320) -> Result<Vec<CommandOutput>, Error> {
321    let rule = Rule {
322        span: span!(),
323        head: actions,
324        body: facts.0,
325        name: "".into(),
326        ruleset: ruleset.into(),
327        naive: false,
328        no_decomp: false,
329    };
330
331    egraph.run_program(vec![Command::Rule { rule }])
332}
333
334#[derive(Clone)]
335struct RustRuleRhs<F>
336where
337    F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
338        + Clone
339        + Send
340        + Sync
341        + 'static,
342{
343    name: String,
344    inputs: Vec<ArcSort>,
345    func: F,
346}
347
348impl<F> Primitive for RustRuleRhs<F>
349where
350    F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
351        + Clone
352        + Send
353        + Sync
354        + 'static,
355{
356    fn name(&self) -> &str {
357        &self.name
358    }
359
360    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
361        let sorts: Vec<_> = self
362            .inputs
363            .iter()
364            .chain(once(&UnitSort.to_arcsort()))
365            .cloned()
366            .collect();
367        SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
368    }
369}
370
371impl<F> WritePrim for RustRuleRhs<F>
372where
373    F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
374        + Clone
375        + Send
376        + Sync
377        + 'static,
378{
379    fn apply<'a, 'db>(&self, state: crate::WriteState<'a, 'db>, values: &[Value]) -> Option<Value> {
380        let unit = state.base_values().get(());
381        (self.func)(state, values)?;
382        Some(unit)
383    }
384}
385
386/// Add a rule to the e-graph whose right-hand side is a Rust callback.
387/// ```
388/// use egglog::prelude::*;
389///
390/// let mut egraph = EGraph::default();
391/// egraph.parse_and_run_program(
392///     None,
393///     "
394/// (function fib (i64) i64 :no-merge)
395/// (set (fib 0) 0)
396/// (set (fib 1) 1)
397/// (rule (
398///     (= f0 (fib x))
399///     (= f1 (fib (+ x 1)))
400/// ) (
401///     (set (fib (+ x 2)) (+ f0 f1))
402/// ))
403/// (run 10)
404///     ",
405/// )?;
406///
407/// let big_number = 20;
408///
409/// // check that `(fib 20)` is not in the e-graph
410/// let results = query(
411///     &mut egraph,
412///     vars![f: i64],
413///     facts![(= (fib (unquote exprs::int(big_number))) f)],
414/// )?;
415///
416/// assert!(results.iter().next().is_none());
417///
418/// let ruleset = "custom_ruleset";
419/// add_ruleset(&mut egraph, ruleset)?;
420///
421/// // add the rule from `build_test_database` to the egraph
422/// rust_rule(
423///     &mut egraph,
424///     "fib_rule",
425///     ruleset,
426///     vars![x: i64, f0: i64, f1: i64],
427///     facts![
428///         (= f0 (fib x))
429///         (= f1 (fib (+ x 1)))
430///     ],
431///     move |mut ctx, values| {
432///         let [x, f0, f1] = values else { unreachable!() };
433///         let x = ctx.value_to_base::<i64>(*x);
434///         let f0 = ctx.value_to_base::<i64>(*f0);
435///         let f1 = ctx.value_to_base::<i64>(*f1);
436///
437///         let y = ctx.base_to_value::<i64>(x + 2);
438///         let f2 = ctx.base_to_value::<i64>(f0 + f1);
439///         ctx.insert("fib", [y, f2].into_iter());
440///
441///         Some(())
442///     },
443/// )?;
444///
445/// // run that rule 10 times
446/// for _ in 0..10 {
447///     run_ruleset(&mut egraph, ruleset)?;
448/// }
449///
450/// // check that `(fib 20)` is now in the e-graph
451/// let results = query(
452///     &mut egraph,
453///     vars![f: i64],
454///     facts![(= (fib (unquote exprs::int(big_number))) f)],
455/// )?;
456///
457/// let y = egraph.base_to_value::<i64>(6765);
458/// let results: Vec<_> = results.iter().collect();
459/// assert_eq!(results, [[y]]);
460///
461/// # Ok::<(), egglog::Error>(())
462/// ```
463pub fn rust_rule(
464    egraph: &mut EGraph,
465    rule_name: &str,
466    ruleset: &str,
467    vars: &[(&str, ArcSort)],
468    facts: Facts<String, String>,
469    func: impl for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
470    + Clone
471    + Send
472    + Sync
473    + 'static,
474) -> Result<Vec<CommandOutput>, Error> {
475    let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
476    egraph.add_write_primitive(
477        RustRuleRhs {
478            name: prim_name.clone(),
479            inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
480            func,
481        },
482        None,
483    );
484
485    let rule = Rule {
486        span: span!(),
487        head: GenericActions(vec![GenericAction::Expr(
488            span!(),
489            exprs::call(
490                &prim_name,
491                vars.iter().map(|(v, _)| exprs::var(v)).collect(),
492            ),
493        )]),
494        body: facts.0,
495        name: egraph.parser.symbol_gen.fresh(rule_name),
496        ruleset: ruleset.into(),
497        naive: false,
498        no_decomp: false,
499    };
500
501    egraph.run_program(vec![Command::Rule { rule }])
502}
503
504#[derive(Clone)]
505struct RustRuleFullRhs<F>
506where
507    F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
508        + Clone
509        + Send
510        + Sync
511        + 'static,
512{
513    name: String,
514    inputs: Vec<ArcSort>,
515    func: F,
516}
517
518impl<F> Primitive for RustRuleFullRhs<F>
519where
520    F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
521        + Clone
522        + Send
523        + Sync
524        + 'static,
525{
526    fn name(&self) -> &str {
527        &self.name
528    }
529
530    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
531        let sorts: Vec<_> = self
532            .inputs
533            .iter()
534            .chain(once(&UnitSort.to_arcsort()))
535            .cloned()
536            .collect();
537        SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
538    }
539}
540
541impl<F> crate::FullPrim for RustRuleFullRhs<F>
542where
543    F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
544        + Clone
545        + Send
546        + Sync
547        + 'static,
548{
549    fn apply<'a, 'db>(&self, state: crate::FullState<'a, 'db>, values: &[Value]) -> Option<Value> {
550        let unit = state.base_values().get(());
551        (self.func)(state, values)?;
552        Some(unit)
553    }
554}
555
556/// Like [`rust_rule`], but the action callback receives a [`FullState`]
557/// — it can read tables (`ctx.lookup`) in addition to writing them.
558/// Action callbacks of `FullPrim` are only valid in `Context::Full`,
559/// so this helper marks the generated rule `:naive`: the body matches
560/// against the entire database every iteration instead of using the
561/// seminaive delta. Use this when the action genuinely needs to look
562/// up rows; prefer [`rust_rule`] when the data can be bound via the
563/// matcher in the rule body.
564pub fn rust_rule_full(
565    egraph: &mut EGraph,
566    rule_name: &str,
567    ruleset: &str,
568    vars: &[(&str, ArcSort)],
569    facts: Facts<String, String>,
570    func: impl for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
571    + Clone
572    + Send
573    + Sync
574    + 'static,
575) -> Result<Vec<CommandOutput>, Error> {
576    let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_full_prim");
577    egraph.add_full_primitive(
578        RustRuleFullRhs {
579            name: prim_name.clone(),
580            inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
581            func,
582        },
583        None,
584    );
585
586    let rule = Rule {
587        span: span!(),
588        head: GenericActions(vec![GenericAction::Expr(
589            span!(),
590            exprs::call(
591                &prim_name,
592                vars.iter().map(|(v, _)| exprs::var(v)).collect(),
593            ),
594        )]),
595        body: facts.0,
596        name: egraph.parser.symbol_gen.fresh(rule_name),
597        ruleset: ruleset.into(),
598        // FullPrim action requires `Context::Full`, which is only
599        // available in `:naive` rules.
600        naive: true,
601        no_decomp: false,
602    };
603
604    egraph.run_program(vec![Command::Rule { rule }])
605}
606
607/// The result of a query.
608pub struct QueryResult {
609    rows: usize,
610    cols: usize,
611    data: Vec<Value>,
612}
613
614impl QueryResult {
615    /// Get an iterator over the query results,
616    /// where each match is a `&[Value]` in the same order
617    /// as the `vars` that were passed to `query`.
618    pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
619        assert!(self.cols > 0, "no vars; use `any_matches` instead");
620        assert!(self.data.len().is_multiple_of(self.cols));
621        self.data.chunks_exact(self.cols)
622    }
623
624    /// Check if any matches were returned at all.
625    pub fn any_matches(&self) -> bool {
626        self.rows > 0
627    }
628}
629
630/// Run a query over the database.
631/// ```
632/// use egglog::prelude::*;
633///
634/// let mut egraph = EGraph::default();
635/// egraph.parse_and_run_program(
636///     None,
637///     "
638/// (function fib (i64) i64 :no-merge)
639/// (set (fib 0) 0)
640/// (set (fib 1) 1)
641/// (rule (
642///     (= f0 (fib x))
643///     (= f1 (fib (+ x 1)))
644/// ) (
645///     (set (fib (+ x 2)) (+ f0 f1))
646/// ))
647/// (run 10)
648///     ",
649/// )?;
650///
651/// let results = query(
652///     &mut egraph,
653///     vars![x: i64, y: i64],
654///     facts![
655///         (= (fib x) y)
656///         (= y 13)
657///     ],
658/// )?;
659///
660/// let x = egraph.base_to_value::<i64>(7);
661/// let y = egraph.base_to_value::<i64>(13);
662/// let results: Vec<_> = results.iter().collect();
663/// assert_eq!(results, [[x, y]]);
664///
665/// # Ok::<(), egglog::Error>(())
666/// ```
667pub fn query(
668    egraph: &mut EGraph,
669    vars: &[(&str, ArcSort)],
670    facts: Facts<String, String>,
671) -> Result<QueryResult, Error> {
672    use std::sync::{Arc, Mutex};
673
674    let results = Arc::new(Mutex::new(QueryResult {
675        rows: 0,
676        cols: vars.len(),
677        data: Vec::new(),
678    }));
679    let results_weak = Arc::downgrade(&results);
680
681    let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
682    add_ruleset(egraph, &ruleset)?;
683
684    rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
685        let arc = results_weak.upgrade().unwrap();
686        let mut results = arc.lock().unwrap();
687        results.rows += 1;
688        results.data.extend(values);
689        Some(())
690    })?;
691
692    run_ruleset(egraph, &ruleset)?;
693
694    let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
695
696    let Ruleset::Rules(rules) = ruleset else {
697        unreachable!()
698    };
699    assert_eq!(rules.len(), 1);
700    let rule = rules.into_iter().next().unwrap().1;
701    egraph.backend.free_rule(rule.1);
702
703    let Some(mutex) = Arc::into_inner(results) else {
704        panic!("results_weak.upgrade() was not dropped");
705    };
706    Ok(mutex.into_inner().unwrap())
707}
708
709/// Declare a new sort.
710pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
711    egraph.run_program(vec![Command::Sort {
712        span: span!(),
713        name: name.to_owned(),
714        presort_and_args: None,
715        uf: None,
716        proof_func: None,
717        unionable: true,
718    }])
719}
720
721/// Declare a new function table.
722pub fn add_function(
723    egraph: &mut EGraph,
724    name: &str,
725    schema: Schema,
726    merge: Option<GenericExpr<String, String>>,
727) -> Result<Vec<CommandOutput>, Error> {
728    egraph.run_program(vec![Command::Function {
729        span: span!(),
730        name: name.to_owned(),
731        schema,
732        merge,
733        hidden: false,
734        let_binding: false,
735        term_constructor: None,
736        unextractable: false,
737    }])
738}
739
740/// Declare a new constructor table.
741pub fn add_constructor(
742    egraph: &mut EGraph,
743    name: &str,
744    schema: Schema,
745    cost: Option<DefaultCost>,
746    unextractable: bool,
747) -> Result<Vec<CommandOutput>, Error> {
748    egraph.run_program(vec![Command::Constructor {
749        span: span!(),
750        name: name.to_owned(),
751        schema,
752        cost,
753        unextractable,
754        hidden: false,
755        let_binding: false,
756        term_constructor: None,
757    }])
758}
759
760/// Declare a new relation table.
761pub fn add_relation(
762    egraph: &mut EGraph,
763    name: &str,
764    inputs: Vec<String>,
765) -> Result<Vec<CommandOutput>, Error> {
766    egraph.run_program(vec![Command::Relation {
767        span: span!(),
768        name: name.to_owned(),
769        inputs,
770    }])
771}
772
/// Adds sorts and constructor tables to the database.
///
/// `datatype!(egraph, (datatype S (C1 A B) (C2 :cost 10)))` declares sort `S`
/// with [`add_sort`]. It then declares one constructor per variant with
/// [`add_constructor`]; each takes the listed argument sorts and produces `S`,
/// with the optional `:cost`. Each step is followed by `?`, so the macro must
/// be used inside a function returning a compatible `Result`.
///
/// All helper paths go through `$crate`. The expansion therefore does not
/// depend on the caller having `Schema` (which the prelude does not re-export)
/// or the helper functions in scope.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {
        $crate::prelude::add_sort($egraph, stringify!($sort))?;
        $(
            $crate::prelude::add_constructor(
                $egraph,
                stringify!($name),
                $crate::ast::Schema {
                    input: vec![$(stringify!($args).to_owned()),*],
                    output: stringify!($sort).to_owned(),
                },
                [$($cost)*].first().copied(),
                false,
            )?;
        )*
    };
}
790
791/// A "default" implementation of [`Sort`] for simple types
792/// which just want to put some data in the e-graph. If you
793/// implement this trait, do not implement `Sort` or
794/// `ContainerSort. Use `add_base_sort` to register base
795/// sorts with the `EGraph`. See `Sort` for documentation
796/// of the methods. Do not override `to_arcsort`.
797pub trait BaseSort: Any + Send + Sync + Debug {
798    type Base: BaseValue;
799    fn name(&self) -> &str;
800    fn register_primitives(&self, _eg: &mut EGraph) {}
801    fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> TermId;
802
803    fn to_arcsort(self) -> ArcSort
804    where
805        Self: Sized,
806    {
807        Arc::new(BaseSortImpl(self))
808    }
809}
810
811#[derive(Debug)]
812struct BaseSortImpl<T: BaseSort>(T);
813
814impl<T: BaseSort> Sort for BaseSortImpl<T> {
815    fn name(&self) -> &str {
816        self.0.name()
817    }
818
819    fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
820        ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
821    }
822
823    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
824        backend.base_values_mut().register_type::<T::Base>();
825    }
826
827    fn value_type(&self) -> Option<TypeId> {
828        Some(TypeId::of::<T::Base>())
829    }
830
831    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
832        self
833    }
834
835    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
836        self.0.register_primitives(eg)
837    }
838
839    /// Reconstruct a leaf base value in a TermDag
840    fn reconstruct_termdag_base(
841        &self,
842        base_values: &BaseValues,
843        value: Value,
844        termdag: &mut TermDag,
845    ) -> TermId {
846        self.0.reconstruct_termdag(base_values, value, termdag)
847    }
848}
849
850/// A "default" implementation of [`Sort`] for types which
851/// just want to store a pure data structure in the e-graph.
852/// If you implement this trait, do not implement `Sort` or
853/// `BaseSort`. Use `add_container_sort` to register container
854/// sorts with the `EGraph`. See `Sort` for documentation
855/// of the methods. Do not override `to_arcsort`.
856pub trait ContainerSort: Any + Send + Sync + Debug {
857    type Container: ContainerValue;
858    fn name(&self) -> &str;
859    fn is_eq_container_sort(&self) -> bool;
860    fn inner_sorts(&self) -> Vec<ArcSort>;
861    fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
862    fn register_primitives(&self, _eg: &mut EGraph) {}
863    fn reconstruct_termdag(
864        &self,
865        _: &ContainerValues,
866        _: Value,
867        _: &mut TermDag,
868        _: Vec<TermId>,
869    ) -> TermId;
870    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
871
872    fn to_arcsort(self) -> ArcSort
873    where
874        Self: Sized,
875    {
876        Arc::new(ContainerSortImpl(self))
877    }
878}
879
880#[derive(Debug)]
881struct ContainerSortImpl<T: ContainerSort>(T);
882
883impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
884    fn name(&self) -> &str {
885        self.0.name()
886    }
887
888    fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
889        ColumnTy::Id
890    }
891
892    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
893        backend.register_container_ty::<T::Container>();
894    }
895
896    fn value_type(&self) -> Option<TypeId> {
897        Some(TypeId::of::<T::Container>())
898    }
899
900    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
901        self
902    }
903
904    fn inner_sorts(&self) -> Vec<ArcSort> {
905        self.0.inner_sorts()
906    }
907
908    fn inner_values(
909        &self,
910        container_values: &ContainerValues,
911        value: Value,
912    ) -> Vec<(ArcSort, Value)> {
913        self.0.inner_values(container_values, value)
914    }
915
916    fn is_container_sort(&self) -> bool {
917        true
918    }
919
920    fn is_eq_container_sort(&self) -> bool {
921        self.0.is_eq_container_sort()
922    }
923
924    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
925        self.0.serialized_name(container_values, value)
926    }
927
928    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
929        self.0.register_primitives(eg);
930    }
931
932    fn reconstruct_termdag_container(
933        &self,
934        container_values: &ContainerValues,
935        value: Value,
936        termdag: &mut TermDag,
937        element_terms: Vec<TermId>,
938    ) -> TermId {
939        self.0
940            .reconstruct_termdag(container_values, value, termdag, element_terms)
941    }
942}
943
944/// Add a [`BaseSort`] to the e-graph
945pub fn add_base_sort(
946    egraph: &mut EGraph,
947    base_sort: impl BaseSort,
948    span: Span,
949) -> Result<(), TypeError> {
950    egraph.add_sort(BaseSortImpl(base_sort), span)
951}
952
953pub fn add_container_sort(
954    egraph: &mut EGraph,
955    container_sort: impl ContainerSort,
956    span: Span,
957) -> Result<(), TypeError> {
958    egraph.add_sort(ContainerSortImpl(container_sort), span)
959}
960
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an e-graph whose `fib` table is filled by running the fib rule
    /// for 10 iterations.
    fn build_test_database() -> Result<EGraph, Error> {
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(
            None,
            "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
    (= f0 (fib x))
    (= f1 (fib (+ x 1)))
) (
    (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
        ",
        )?;
        Ok(egraph)
    }

    /// Queries `(fib n)`. The result has a single column, `f`.
    fn query_fib(egraph: &mut EGraph, n: i64) -> Result<QueryResult, Error> {
        query(
            egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(n))) f)],
        )
    }

    /// The e-graph value for the integer `n`.
    fn int_value(egraph: &EGraph, n: i64) -> Value {
        egraph.backend.base_values().get::<i64>(n)
    }

    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let matches = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        let expected = [int_value(&egraph, 7), int_value(&egraph, 13)];
        assert_eq!(matches.data, expected);

        Ok(())
    }

    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        const TARGET: i64 = 20;
        let mut egraph = build_test_database()?;

        // `(fib 20)` must not be in the e-graph yet.
        assert!(query_fib(&mut egraph, TARGET)?.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Re-add the fib rule from `build_test_database`, into the custom ruleset.
        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // Now `(fib 20)` is reachable.
        let after = query_fib(&mut egraph, TARGET)?;
        assert_eq!(after.data, [int_value(&egraph, 6765)]);

        Ok(())
    }

    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Exercises every `fact!`/`action!` form; the rule is only added, never run.
        rule(
            &mut egraph,
            ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        const TARGET: i64 = 20;
        let mut egraph = build_test_database()?;

        // `(fib 20)` must not be in the e-graph yet.
        assert!(query_fib(&mut egraph, TARGET)?.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // The fib rule again, this time with a Rust right-hand side.
        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |mut ctx, values| {
                let [x, f0, f1] = values else { unreachable!() };
                let x = ctx.value_to_base::<i64>(*x);
                let f0 = ctx.value_to_base::<i64>(*f0);
                let f1 = ctx.value_to_base::<i64>(*f1);

                let y = ctx.base_to_value::<i64>(x + 2);
                let f2 = ctx.base_to_value::<i64>(f0 + f1);
                ctx.insert("fib", [y, f2].into_iter());

                Some(())
            },
        )?;

        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // Now `(fib 20)` is reachable.
        let after = query_fib(&mut egraph, TARGET)?;
        assert_eq!(after.data, [int_value(&egraph, 6765)]);

        Ok(())
    }
}