egglog/
prelude.rs

1//! This module makes it easier to use `egglog` from Rust.
2//! It is intended to be imported fully.
3//! ```
4//! use egglog::prelude::*;
5//! ```
6//! See also [`rule`], [`rust_rule`], [`query`], [`BaseSort`],
7//! and [`ContainerSort`].
8
9use crate::*;
10use std::any::{Any, TypeId};
11
12// Re-exports in `prelude` for convenience.
13pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
14pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
15pub use egglog::{CommandMacro, CommandMacroRegistry};
16pub use egglog::{EGraph, span};
17pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
18
19/// Trait for types that can be converted to/from Literal for use in validated primitives.
20/// This enables automatic validator generation for literal primitives.
21pub trait LiteralConvertible: Sized {
22    fn to_literal(self) -> egglog_ast::generic_ast::Literal;
23    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self>;
24}
25
26impl LiteralConvertible for i64 {
27    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
28        egglog_ast::generic_ast::Literal::Int(self)
29    }
30    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
31        match lit {
32            egglog_ast::generic_ast::Literal::Int(i) => Some(*i),
33            _ => None,
34        }
35    }
36}
37
38impl LiteralConvertible for bool {
39    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
40        egglog_ast::generic_ast::Literal::Bool(self)
41    }
42    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
43        match lit {
44            egglog_ast::generic_ast::Literal::Bool(b) => Some(*b),
45            _ => None,
46        }
47    }
48}
49
50impl LiteralConvertible for ordered_float::OrderedFloat<f64> {
51    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
52        egglog_ast::generic_ast::Literal::Float(self)
53    }
54    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
55        match lit {
56            egglog_ast::generic_ast::Literal::Float(f) => Some(*f),
57            _ => None,
58        }
59    }
60}
61
62impl LiteralConvertible for egglog::sort::F {
63    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
64        egglog_ast::generic_ast::Literal::Float(self.0)
65    }
66    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
67        match lit {
68            egglog_ast::generic_ast::Literal::Float(f) => Some(egglog::sort::F::from(*f)),
69            _ => None,
70        }
71    }
72}
73
74impl LiteralConvertible for egglog::sort::S {
75    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
76        egglog_ast::generic_ast::Literal::String(self.0)
77    }
78    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
79        match lit {
80            egglog_ast::generic_ast::Literal::String(s) => Some(egglog::sort::S::new(s.clone())),
81            _ => None,
82        }
83    }
84}
85
86impl LiteralConvertible for () {
87    fn to_literal(self) -> egglog_ast::generic_ast::Literal {
88        egglog_ast::generic_ast::Literal::Unit
89    }
90    fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
91        match lit {
92            egglog_ast::generic_ast::Literal::Unit => Some(()),
93            _ => None,
94        }
95    }
96}
97
98pub mod exprs {
99    use super::*;
100
101    /// Creates a variable expression.
102    pub fn var(name: &str) -> Expr {
103        Expr::Var(span!(), name.to_owned())
104    }
105
106    /// Creates an integer literal expression.
107    pub fn int(value: i64) -> Expr {
108        Expr::Lit(span!(), Literal::Int(value))
109    }
110
111    /// Creates a float literal expression.
112    pub fn float(value: f64) -> Expr {
113        Expr::Lit(span!(), Literal::Float(value.into()))
114    }
115
116    /// Creates a string literal expression.
117    pub fn string(value: &str) -> Expr {
118        Expr::Lit(span!(), Literal::String(value.to_owned()))
119    }
120
121    /// Creates a unit literal expression.
122    pub fn unit() -> Expr {
123        Expr::Lit(span!(), Literal::Unit)
124    }
125
126    /// Creates a boolean literal expression.
127    pub fn bool(value: bool) -> Expr {
128        Expr::Lit(span!(), Literal::Bool(value))
129    }
130
131    /// Creates a function call expression.
132    pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
133        Expr::Call(span!(), f.to_owned(), xs)
134    }
135}
136
137/// Create a new ruleset.
138pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
139    egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
140}
141
142/// Run one iteration of a ruleset.
143pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
144    egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
145        span!(),
146        RunConfig {
147            ruleset: ruleset.to_owned(),
148            until: None,
149        },
150    ))])
151}
152
/// Turns a sort name into a sort value, for use in [`vars!`].
///
/// The built-in names `BigInt`, `BigRat`, `bool`, `f64`, `i64`, `String` and
/// `Unit` expand to `<Name>Sort.to_arcsort()`. Any other input is emitted
/// unchanged, so it must already be an expression of the expected sort type
/// (presumably an `ArcSort` — confirm at the call site).
///
/// The expansion names the sort types and `to_arcsort` without a path, so they
/// must be in scope where the macro is invoked (e.g. via
/// `use egglog::prelude::*`).
#[macro_export]
macro_rules! sort {
    // Arms are tried in order: the built-in names must stay ahead of the
    // catch-all `$t:expr` arm at the end.
    (BigInt) => {
        BigIntSort.to_arcsort()
    };
    (BigRat) => {
        BigRatSort.to_arcsort()
    };
    (bool) => {
        BoolSort.to_arcsort()
    };
    (f64) => {
        F64Sort.to_arcsort()
    };
    (i64) => {
        I64Sort.to_arcsort()
    };
    (String) => {
        StringSort.to_arcsort()
    };
    (Unit) => {
        UnitSort.to_arcsort()
    };
    // Fallback: pass a user-provided sort expression through unchanged.
    ($t:expr) => {
        $t
    };
}

/// Builds a borrowed slice of `(name, sort)` pairs describing typed
/// variables, e.g. `vars![x: i64, y: i64]`.
///
/// Each type is resolved with [`sort!`], so a type written as a single token
/// may be either a built-in sort name or an expression evaluating to a sort.
/// The result is the `vars` argument expected by `query` and `rust_rule`.
#[macro_export]
macro_rules! vars {
    [$($x:ident : $t:tt),* $(,)?] => {
        &[$((stringify!($x), sort!($t))),*]
    };
}
187
/// Builds an [`Expr`] from s-expression syntax.
///
/// - `(unquote e)` splices in the Rust expression `e` as-is.
/// - `(f a b ...)` becomes `exprs::call("f", ...)`, with each argument
///   converted recursively.
/// - A Rust literal becomes `exprs::int(literal)`, so only integer literals
///   type-check here.
/// - Any other single token becomes `exprs::var` named after that token.
///
/// NOTE(review): inside a call, a negative number such as `-1` is two token
/// trees (`-` and `1`), so it does not hit the literal arm as one value; use
/// `(unquote exprs::int(-1))` instead — confirm.
///
/// The expansion uses `exprs` without a path, so it must be in scope.
#[macro_export]
macro_rules! expr {
    // Must stay before the general call arm, which would also match `(unquote ...)`.
    ((unquote $unquoted:expr)) => { $unquoted };
    (($func:tt $($arg:tt)*)) => { exprs::call(stringify!($func), vec![$(expr!($arg)),*]) };
    ($value:literal) => { exprs::int($value) };
    ($quoted:tt) => { exprs::var(stringify!($quoted)) };
}

/// Builds a single [`Fact`].
///
/// `(= ...)` becomes `Fact::Eq(span!(), ...)` with every operand converted by
/// [`expr!`], so the number of operands must match what `Fact::Eq` accepts.
/// Anything else becomes `Fact::Fact` of the converted expression.
#[macro_export]
macro_rules! fact {
    // The equality arm must precede the catch-all, which also matches `(= ...)`.
    ((= $($arg:tt)*)) => { Fact::Eq(span!(), $(expr!($arg)),*) };
    ($a:tt) => { Fact::Fact(expr!($a)) };
}

/// Builds [`Facts`] from a whitespace-separated sequence of facts, each
/// converted with [`fact!`].
#[macro_export]
macro_rules! facts {
    ($($tree:tt)*) => { Facts(vec![$(fact!($tree)),*]) };
}
206
/// Builds a single [`Action`] from s-expression syntax.
///
/// Supported forms (expressions are converted with [`expr!`]):
/// - `(let name e)` → `Action::Let`
/// - `(set (f args...) v)` → `Action::Set`
/// - `(delete (f args...))` → `Action::Change` with `Change::Delete`
/// - `(subsume (f args...))` → `Action::Change` with `Change::Subsume`
/// - `(union a b)` → `Action::Union`
/// - `(panic "message")` → `Action::Panic`
/// - anything else → `Action::Expr` of the converted expression
///
/// The expansion names `Action`, `Change` and `span!` without a path, so
/// they must be in scope at the call site.
#[macro_export]
macro_rules! action {
    // Keyword arms come first; the final catch-all would otherwise treat
    // e.g. `(union a b)` as a call to a function named `union`.
    ((let $name:ident $value:tt)) => {
        Action::Let(span!(), String::from(stringify!($name)), expr!($value))
    };
    ((set ($f:ident $($x:tt)*) $value:tt)) => {
        Action::Set(span!(), String::from(stringify!($f)), vec![$(expr!($x)),*], expr!($value))
    };
    ((delete ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Delete, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((subsume ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Subsume, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((union $x:tt $y:tt)) => {
        Action::Union(span!(), expr!($x), expr!($y))
    };
    ((panic $message:literal)) => {
        Action::Panic(span!(), $message.to_owned())
    };
    ($x:tt) => {
        Action::Expr(span!(), expr!($x))
    };
}

/// Builds `GenericActions` from a whitespace-separated sequence of actions,
/// each converted with [`action!`].
#[macro_export]
macro_rules! actions {
    ($($tree:tt)*) => { GenericActions(vec![$(action!($tree)),*]) };
}
236
237/// Add a rule to the e-graph whose right-hand side is made up of actions.
238/// ```
239/// use egglog::prelude::*;
240///
241/// let mut egraph = EGraph::default();
242/// egraph.parse_and_run_program(
243///     None,
244///     "
245/// (function fib (i64) i64 :no-merge)
246/// (set (fib 0) 0)
247/// (set (fib 1) 1)
248/// (rule (
249///     (= f0 (fib x))
250///     (= f1 (fib (+ x 1)))
251/// ) (
252///     (set (fib (+ x 2)) (+ f0 f1))
253/// ))
254/// (run 10)
255///     ",
256/// )?;
257///
258/// let big_number = 20;
259///
260/// // check that `(fib 20)` is not in the e-graph
261/// let results = query(
262///     &mut egraph,
263///     vars![f: i64],
264///     facts![(= (fib (unquote exprs::int(big_number))) f)],
265/// )?;
266///
267/// assert!(results.iter().next().is_none());
268///
269/// let ruleset = "custom_ruleset";
270/// add_ruleset(&mut egraph, ruleset)?;
271///
272/// // add the rule from `build_test_database` to the egraph
273/// rule(
274///     &mut egraph,
275///     ruleset,
276///     facts![
277///         (= f0 (fib x))
278///         (= f1 (fib (+ x 1)))
279///     ],
280///     actions![
281///         (set (fib (+ x 2)) (+ f0 f1))
282///     ],
283/// )?;
284///
285/// // run that rule 10 times
286/// for _ in 0..10 {
287///     run_ruleset(&mut egraph, ruleset)?;
288/// }
289///
290/// // check that `(fib 20)` is now in the e-graph
291/// let results = query(
292///     &mut egraph,
293///     vars![f: i64],
294///     facts![(= (fib (unquote exprs::int(big_number))) f)],
295/// )?;
296///
297/// let y = egraph.base_to_value::<i64>(6765);
298/// let results: Vec<_> = results.iter().collect();
299/// assert_eq!(results, [[y]]);
300///
301/// # Ok::<(), egglog::Error>(())
302/// ```
303pub fn rule(
304    egraph: &mut EGraph,
305    ruleset: &str,
306    facts: Facts<String, String>,
307    actions: Actions,
308) -> Result<Vec<CommandOutput>, Error> {
309    let mut rule = Rule {
310        span: span!(),
311        head: actions,
312        body: facts.0,
313        name: "".into(),
314        ruleset: ruleset.into(),
315    };
316
317    rule.name = format!("{rule:?}");
318
319    egraph.run_program(vec![Command::Rule { rule }])
320}
321
322/// A wrapper around an `ExecutionState` for rules that are written in Rust.
323/// See the [`rust_rule`] documentation for an example of how to use this.
324pub struct RustRuleContext<'a, 'b> {
325    exec_state: &'a mut ExecutionState<'b>,
326    union_action: egglog_bridge::UnionAction,
327    table_actions: HashMap<String, egglog_bridge::TableAction>,
328    panic_id: ExternalFunctionId,
329}
330
331impl RustRuleContext<'_, '_> {
332    /// Convert from an egglog value to a Rust type.
333    pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
334        self.exec_state.base_values().unwrap::<T>(x)
335    }
336
337    /// Convert from an egglog value to reference of Rust container type.
338    ///
339    /// See [`EGraph::value_to_container`].
340    pub fn value_to_container<T: ContainerValue>(
341        &mut self,
342        x: Value,
343    ) -> Option<impl Deref<Target = T>> {
344        self.exec_state.container_values().get_val::<T>(x)
345    }
346
347    /// Convert from a Rust type to an egglog value.
348    pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
349        self.exec_state.base_values().get::<T>(x)
350    }
351
352    /// Convert from a Rust container type to an egglog value.
353    pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
354        self.exec_state
355            .container_values()
356            .register_val::<T>(x, self.exec_state)
357    }
358
359    fn get_table_action(&self, table: &str) -> egglog_bridge::TableAction {
360        self.table_actions[table].clone()
361    }
362
363    /// Do a table lookup. This is potentially a mutable operation!
364    /// For more information, see `egglog_bridge::TableAction::lookup`.
365    pub fn lookup(&mut self, table: &str, key: &[Value]) -> Option<Value> {
366        self.get_table_action(table).lookup(self.exec_state, key)
367    }
368
369    /// Union two values in the e-graph.
370    /// For more information, see `egglog_bridge::UnionAction::union`.
371    pub fn union(&mut self, x: Value, y: Value) {
372        self.union_action.union(self.exec_state, x, y)
373    }
374
375    /// Insert a row into a table.
376    /// For more information, see `egglog_bridge::TableAction::insert`.
377    pub fn insert(&mut self, table: &str, row: impl Iterator<Item = Value>) {
378        self.get_table_action(table).insert(self.exec_state, row)
379    }
380
381    /// Remove a row from a table.
382    /// For more information, see `egglog_bridge::TableAction::remove`.
383    pub fn remove(&mut self, table: &str, key: &[Value]) {
384        self.get_table_action(table).remove(self.exec_state, key)
385    }
386
387    /// Subsume a row in a table.
388    /// For more information, see `egglog_bridge::TableAction::subsume`.
389    pub fn subsume(&mut self, table: &str, key: &[Value]) {
390        self.get_table_action(table)
391            .subsume(self.exec_state, key.iter().copied())
392    }
393
394    /// Panic.
395    /// You should also return `None` from your callback if you call
396    /// this function, which this function hopefully makes easier by
397    /// always returning `None` so that you can use `?`.
398    pub fn panic(&mut self) -> Option<()> {
399        self.exec_state.call_external_func(self.panic_id, &[]);
400        None
401    }
402}
403
404#[derive(Clone)]
405struct RustRuleRhs<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> {
406    name: String,
407    inputs: Vec<ArcSort>,
408    union_action: egglog_bridge::UnionAction,
409    table_actions: HashMap<String, egglog_bridge::TableAction>,
410    panic_id: ExternalFunctionId,
411    func: F,
412}
413
414impl<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> Primitive for RustRuleRhs<F> {
415    fn name(&self) -> &str {
416        &self.name
417    }
418
419    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
420        let sorts: Vec<_> = self
421            .inputs
422            .iter()
423            .chain(once(&UnitSort.to_arcsort()))
424            .cloned()
425            .collect();
426        SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
427    }
428
429    fn apply(&self, exec_state: &mut ExecutionState, values: &[Value]) -> Option<Value> {
430        let mut context = RustRuleContext {
431            exec_state,
432            union_action: self.union_action,
433            table_actions: self.table_actions.clone(),
434            panic_id: self.panic_id,
435        };
436        (self.func)(&mut context, values)?;
437        Some(exec_state.base_values().get(()))
438    }
439}
440
441/// Add a rule to the e-graph whose right-hand side is a Rust callback.
442/// ```
443/// use egglog::prelude::*;
444///
445/// let mut egraph = EGraph::default();
446/// egraph.parse_and_run_program(
447///     None,
448///     "
449/// (function fib (i64) i64 :no-merge)
450/// (set (fib 0) 0)
451/// (set (fib 1) 1)
452/// (rule (
453///     (= f0 (fib x))
454///     (= f1 (fib (+ x 1)))
455/// ) (
456///     (set (fib (+ x 2)) (+ f0 f1))
457/// ))
458/// (run 10)
459///     ",
460/// )?;
461///
462/// let big_number = 20;
463///
464/// // check that `(fib 20)` is not in the e-graph
465/// let results = query(
466///     &mut egraph,
467///     vars![f: i64],
468///     facts![(= (fib (unquote exprs::int(big_number))) f)],
469/// )?;
470///
471/// assert!(results.iter().next().is_none());
472///
473/// let ruleset = "custom_ruleset";
474/// add_ruleset(&mut egraph, ruleset)?;
475///
476/// // add the rule from `build_test_database` to the egraph
477/// rust_rule(
478///     &mut egraph,
479///     "fib_rule",
480///     ruleset,
481///     vars![x: i64, f0: i64, f1: i64],
482///     facts![
483///         (= f0 (fib x))
484///         (= f1 (fib (+ x 1)))
485///     ],
486///     move |ctx, values| {
487///         let [x, f0, f1] = values else { unreachable!() };
488///         let x = ctx.value_to_base::<i64>(*x);
489///         let f0 = ctx.value_to_base::<i64>(*f0);
490///         let f1 = ctx.value_to_base::<i64>(*f1);
491///
492///         let y = ctx.base_to_value::<i64>(x + 2);
493///         let f2 = ctx.base_to_value::<i64>(f0 + f1);
494///         ctx.insert("fib", [y, f2].into_iter());
495///
496///         Some(())
497///     },
498/// )?;
499///
500/// // run that rule 10 times
501/// for _ in 0..10 {
502///     run_ruleset(&mut egraph, ruleset)?;
503/// }
504///
505/// // check that `(fib 20)` is now in the e-graph
506/// let results = query(
507///     &mut egraph,
508///     vars![f: i64],
509///     facts![(= (fib (unquote exprs::int(big_number))) f)],
510/// )?;
511///
512/// let y = egraph.base_to_value::<i64>(6765);
513/// let results: Vec<_> = results.iter().collect();
514/// assert_eq!(results, [[y]]);
515///
516/// # Ok::<(), egglog::Error>(())
517/// ```
518pub fn rust_rule(
519    egraph: &mut EGraph,
520    rule_name: &str,
521    ruleset: &str,
522    vars: &[(&str, ArcSort)],
523    facts: Facts<String, String>,
524    func: impl Fn(&mut RustRuleContext, &[Value]) -> Option<()> + Clone + Send + Sync + 'static,
525) -> Result<Vec<CommandOutput>, Error> {
526    let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
527    let panic_id = egraph.backend.new_panic(format!("{prim_name}_panic"));
528    egraph.add_primitive(RustRuleRhs {
529        name: prim_name.clone(),
530        inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
531        union_action: egglog_bridge::UnionAction::new(&egraph.backend),
532        table_actions: egraph
533            .functions
534            .iter()
535            .map(|(k, v)| {
536                (
537                    k.clone(),
538                    egglog_bridge::TableAction::new(&egraph.backend, v.backend_id),
539                )
540            })
541            .collect(),
542        panic_id,
543        func,
544    });
545
546    let rule = Rule {
547        span: span!(),
548        head: GenericActions(vec![GenericAction::Expr(
549            span!(),
550            exprs::call(
551                &prim_name,
552                vars.iter().map(|(v, _)| exprs::var(v)).collect(),
553            ),
554        )]),
555        body: facts.0,
556        name: egraph.parser.symbol_gen.fresh(rule_name),
557        ruleset: ruleset.into(),
558    };
559
560    egraph.run_program(vec![Command::Rule { rule }])
561}
562
563/// The result of a query.
564pub struct QueryResult {
565    rows: usize,
566    cols: usize,
567    data: Vec<Value>,
568}
569
570impl QueryResult {
571    /// Get an iterator over the query results,
572    /// where each match is a `&[Value]` in the same order
573    /// as the `vars` that were passed to `query`.
574    pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
575        assert!(self.cols > 0, "no vars; use `any_matches` instead");
576        assert!(self.data.len() % self.cols == 0);
577        self.data.chunks_exact(self.cols)
578    }
579
580    /// Check if any matches were returned at all.
581    pub fn any_matches(&self) -> bool {
582        self.rows > 0
583    }
584}
585
586/// Run a query over the database.
587/// ```
588/// use egglog::prelude::*;
589///
590/// let mut egraph = EGraph::default();
591/// egraph.parse_and_run_program(
592///     None,
593///     "
594/// (function fib (i64) i64 :no-merge)
595/// (set (fib 0) 0)
596/// (set (fib 1) 1)
597/// (rule (
598///     (= f0 (fib x))
599///     (= f1 (fib (+ x 1)))
600/// ) (
601///     (set (fib (+ x 2)) (+ f0 f1))
602/// ))
603/// (run 10)
604///     ",
605/// )?;
606///
607/// let results = query(
608///     &mut egraph,
609///     vars![x: i64, y: i64],
610///     facts![
611///         (= (fib x) y)
612///         (= y 13)
613///     ],
614/// )?;
615///
616/// let x = egraph.base_to_value::<i64>(7);
617/// let y = egraph.base_to_value::<i64>(13);
618/// let results: Vec<_> = results.iter().collect();
619/// assert_eq!(results, [[x, y]]);
620///
621/// # Ok::<(), egglog::Error>(())
622/// ```
623pub fn query(
624    egraph: &mut EGraph,
625    vars: &[(&str, ArcSort)],
626    facts: Facts<String, String>,
627) -> Result<QueryResult, Error> {
628    use std::sync::{Arc, Mutex};
629
630    let results = Arc::new(Mutex::new(QueryResult {
631        rows: 0,
632        cols: vars.len(),
633        data: Vec::new(),
634    }));
635    let results_weak = Arc::downgrade(&results);
636
637    let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
638    add_ruleset(egraph, &ruleset)?;
639
640    rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
641        let arc = results_weak.upgrade().unwrap();
642        let mut results = arc.lock().unwrap();
643        results.rows += 1;
644        results.data.extend(values);
645        Some(())
646    })?;
647
648    run_ruleset(egraph, &ruleset)?;
649
650    let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
651
652    let Ruleset::Rules(rules) = ruleset else {
653        unreachable!()
654    };
655    assert_eq!(rules.len(), 1);
656    let rule = rules.into_iter().next().unwrap().1;
657    egraph.backend.free_rule(rule.1);
658
659    let Some(mutex) = Arc::into_inner(results) else {
660        panic!("results_weak.upgrade() was not dropped");
661    };
662    Ok(mutex.into_inner().unwrap())
663}
664
665/// Declare a new sort.
666pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
667    egraph.run_program(vec![Command::Sort(span!(), name.to_owned(), None)])
668}
669
670/// Declare a new function table.
671pub fn add_function(
672    egraph: &mut EGraph,
673    name: &str,
674    schema: Schema,
675    merge: Option<GenericExpr<String, String>>,
676) -> Result<Vec<CommandOutput>, Error> {
677    egraph.run_program(vec![Command::Function {
678        span: span!(),
679        name: name.to_owned(),
680        schema,
681        merge,
682    }])
683}
684
685/// Declare a new constructor table.
686pub fn add_constructor(
687    egraph: &mut EGraph,
688    name: &str,
689    schema: Schema,
690    cost: Option<DefaultCost>,
691    unextractable: bool,
692) -> Result<Vec<CommandOutput>, Error> {
693    egraph.run_program(vec![Command::Constructor {
694        span: span!(),
695        name: name.to_owned(),
696        schema,
697        cost,
698        unextractable,
699    }])
700}
701
702/// Declare a new relation table.
703pub fn add_relation(
704    egraph: &mut EGraph,
705    name: &str,
706    inputs: Vec<String>,
707) -> Result<Vec<CommandOutput>, Error> {
708    egraph.run_program(vec![Command::Relation {
709        span: span!(),
710        name: name.to_owned(),
711        inputs,
712    }])
713}
714
/// Adds a sort and its constructor tables to the database.
///
/// Syntax: `datatype!(egraph, (datatype Sort (Ctor ArgSort ... [:cost c]) ...))`.
/// This expands to one `add_sort` call for `Sort` followed by one
/// `add_constructor` call per variant. Each constructor's inputs are the
/// named argument sorts and its output is `Sort`. Constructors are added with
/// `unextractable = false`.
///
/// The expansion uses `?` after each call, so the macro can only be invoked
/// inside a function returning a compatible `Result`. It also refers to
/// `add_sort`, `add_constructor` and `Schema` without a path, so they must be
/// in scope. `$egraph` is re-evaluated for every generated call.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {
        add_sort($egraph, stringify!($sort))?;
        $(add_constructor(
            $egraph,
            stringify!($name),
            Schema {
                input: vec![$(stringify!($args).to_owned()),*],
                output: stringify!($sort).to_owned(),
            },
            // Turns the optional `:cost c` into `Some(c)`, or `None` when absent.
            [$($cost)*].first().copied(),
            false,
        )?;)*
    };
}
732
733/// A "default" implementation of [`Sort`] for simple types
734/// which just want to put some data in the e-graph. If you
735/// implement this trait, do not implement `Sort` or
736/// `ContainerSort. Use `add_base_sort` to register base
737/// sorts with the `EGraph`. See `Sort` for documentation
738/// of the methods. Do not override `to_arcsort`.
739pub trait BaseSort: Any + Send + Sync + Debug {
740    type Base: BaseValue;
741    fn name(&self) -> &str;
742    fn register_primitives(&self, _eg: &mut EGraph) {}
743    fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> TermId;
744
745    fn to_arcsort(self) -> ArcSort
746    where
747        Self: Sized,
748    {
749        Arc::new(BaseSortImpl(self))
750    }
751}
752
753#[derive(Debug)]
754struct BaseSortImpl<T: BaseSort>(T);
755
756impl<T: BaseSort> Sort for BaseSortImpl<T> {
757    fn name(&self) -> &str {
758        self.0.name()
759    }
760
761    fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
762        ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
763    }
764
765    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
766        backend.base_values_mut().register_type::<T::Base>();
767    }
768
769    fn value_type(&self) -> Option<TypeId> {
770        Some(TypeId::of::<T::Base>())
771    }
772
773    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
774        self
775    }
776
777    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
778        self.0.register_primitives(eg)
779    }
780
781    /// Reconstruct a leaf base value in a TermDag
782    fn reconstruct_termdag_base(
783        &self,
784        base_values: &BaseValues,
785        value: Value,
786        termdag: &mut TermDag,
787    ) -> TermId {
788        self.0.reconstruct_termdag(base_values, value, termdag)
789    }
790}
791
792/// A "default" implementation of [`Sort`] for types which
793/// just want to store a pure data structure in the e-graph.
794/// If you implement this trait, do not implement `Sort` or
795/// `BaseSort`. Use `add_container_sort` to register container
796/// sorts with the `EGraph`. See `Sort` for documentation
797/// of the methods. Do not override `to_arcsort`.
798pub trait ContainerSort: Any + Send + Sync + Debug {
799    type Container: ContainerValue;
800    fn name(&self) -> &str;
801    fn is_eq_container_sort(&self) -> bool;
802    fn inner_sorts(&self) -> Vec<ArcSort>;
803    fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
804    fn register_primitives(&self, _eg: &mut EGraph) {}
805    fn reconstruct_termdag(
806        &self,
807        _: &ContainerValues,
808        _: Value,
809        _: &mut TermDag,
810        _: Vec<TermId>,
811    ) -> TermId;
812    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
813
814    fn to_arcsort(self) -> ArcSort
815    where
816        Self: Sized,
817    {
818        Arc::new(ContainerSortImpl(self))
819    }
820}
821
822#[derive(Debug)]
823struct ContainerSortImpl<T: ContainerSort>(T);
824
825impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
826    fn name(&self) -> &str {
827        self.0.name()
828    }
829
830    fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
831        ColumnTy::Id
832    }
833
834    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
835        backend.register_container_ty::<T::Container>();
836    }
837
838    fn value_type(&self) -> Option<TypeId> {
839        Some(TypeId::of::<T::Container>())
840    }
841
842    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
843        self
844    }
845
846    fn inner_sorts(&self) -> Vec<ArcSort> {
847        self.0.inner_sorts()
848    }
849
850    fn inner_values(
851        &self,
852        container_values: &ContainerValues,
853        value: Value,
854    ) -> Vec<(ArcSort, Value)> {
855        self.0.inner_values(container_values, value)
856    }
857
858    fn is_container_sort(&self) -> bool {
859        true
860    }
861
862    fn is_eq_container_sort(&self) -> bool {
863        self.0.is_eq_container_sort()
864    }
865
866    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
867        self.0.serialized_name(container_values, value)
868    }
869
870    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
871        self.0.register_primitives(eg);
872    }
873
874    fn reconstruct_termdag_container(
875        &self,
876        container_values: &ContainerValues,
877        value: Value,
878        termdag: &mut TermDag,
879        element_terms: Vec<TermId>,
880    ) -> TermId {
881        self.0
882            .reconstruct_termdag(container_values, value, termdag, element_terms)
883    }
884}
885
886/// Add a [`BaseSort`] to the e-graph
887pub fn add_base_sort(
888    egraph: &mut EGraph,
889    base_sort: impl BaseSort,
890    span: Span,
891) -> Result<(), TypeError> {
892    egraph.add_sort(BaseSortImpl(base_sort), span)
893}
894
895pub fn add_container_sort(
896    egraph: &mut EGraph,
897    container_sort: impl ContainerSort,
898    span: Span,
899) -> Result<(), TypeError> {
900    egraph.add_sort(ContainerSortImpl(container_sort), span)
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    /// Defines `fib`, seeds `(fib 0)` and `(fib 1)`, and runs the Fibonacci
    /// rule for 10 iterations.
    const FIB_PROGRAM: &str = "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
    (= f0 (fib x))
    (= f1 (fib (+ x 1)))
) (
    (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
        ";

    fn build_test_database() -> Result<EGraph, Error> {
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(None, FIB_PROGRAM)?;
        Ok(egraph)
    }

    /// Queries for every `f` such that `(= (fib n) f)` holds.
    fn query_fib(egraph: &mut EGraph, n: i64) -> Result<QueryResult, Error> {
        query(
            egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(n))) f)],
        )
    }

    /// Encodes `x` as an egglog `i64` value.
    fn i64_value(egraph: &EGraph, x: i64) -> Value {
        egraph.backend.base_values().get::<i64>(x)
    }

    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let results = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        let expected = [i64_value(&egraph, 7), i64_value(&egraph, 13)];
        assert_eq!(results.data, expected);
        Ok(())
    }

    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;
        let big_number = 20;

        // `(fib 20)` must not exist yet.
        assert!(query_fib(&mut egraph, big_number)?.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Re-add the Fibonacci rule from `FIB_PROGRAM`, this time via the Rust API.
        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // Now `(fib 20)` = 6765 must be present.
        let results = query_fib(&mut egraph, big_number)?;
        assert_eq!(results.data, [i64_value(&egraph, 6765)]);
        Ok(())
    }

    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Only checks that every macro form expands and the rule is accepted;
        // the ruleset is never run.
        rule(
            &mut egraph,
            ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;
        let big_number = 20;

        // `(fib 20)` must not exist yet.
        assert!(query_fib(&mut egraph, big_number)?.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Re-add the Fibonacci rule from `FIB_PROGRAM` as a Rust callback.
        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |ctx, values| {
                let [x, f0, f1] = values else { unreachable!() };
                let x = ctx.value_to_base::<i64>(*x);
                let f0 = ctx.value_to_base::<i64>(*f0);
                let f1 = ctx.value_to_base::<i64>(*f1);

                let y = ctx.base_to_value::<i64>(x + 2);
                let f2 = ctx.base_to_value::<i64>(f0 + f1);
                ctx.insert("fib", [y, f2].into_iter());

                Some(())
            },
        )?;

        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // Now `(fib 20)` = 6765 must be present.
        let results = query_fib(&mut egraph, big_number)?;
        assert_eq!(results.data, [i64_value(&egraph, 6765)]);
        Ok(())
    }
}