egglog/
prelude.rs

1//! This module makes it easier to use `egglog` from Rust.
2//! It is intended to be imported fully.
3//! ```
4//! use egglog::prelude::*;
5//! ```
6//! See also [`rule`], [`rust_rule`], [`query`], [`BaseSort`],
7//! and [`ContainerSort`].
8
9use crate::*;
10use std::any::{Any, TypeId};
11
12// Re-exports in `prelude` for convenience.
13pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
14pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
15pub use egglog::{EGraph, span};
16pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
17
18pub mod exprs {
19    use super::*;
20
21    /// Creates a variable expression.
22    pub fn var(name: &str) -> Expr {
23        Expr::Var(span!(), name.to_owned())
24    }
25
26    /// Creates an integer literal expression.
27    pub fn int(value: i64) -> Expr {
28        Expr::Lit(span!(), Literal::Int(value))
29    }
30
31    /// Creates a float literal expression.
32    pub fn float(value: f64) -> Expr {
33        Expr::Lit(span!(), Literal::Float(value.into()))
34    }
35
36    /// Creates a string literal expression.
37    pub fn string(value: &str) -> Expr {
38        Expr::Lit(span!(), Literal::String(value.to_owned()))
39    }
40
41    /// Creates a unit literal expression.
42    pub fn unit() -> Expr {
43        Expr::Lit(span!(), Literal::Unit)
44    }
45
46    /// Creates a boolean literal expression.
47    pub fn bool(value: bool) -> Expr {
48        Expr::Lit(span!(), Literal::Bool(value))
49    }
50
51    /// Creates a function call expression.
52    pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
53        Expr::Call(span!(), f.to_owned(), xs)
54    }
55}
56
57/// Create a new ruleset.
58pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
59    egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
60}
61
62/// Run one iteration of a ruleset.
63pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
64    egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
65        span!(),
66        RunConfig {
67            ruleset: ruleset.to_owned(),
68            until: None,
69        },
70    ))])
71}
72
/// Resolve a sort written in `vars!` syntax to an `ArcSort`.
///
/// The built-in base sorts can be named directly: `BigInt`, `BigRat`, `bool`,
/// `f64`, `i64`, `String` and `Unit`. Any other expression is passed through
/// unchanged, so it is expected to already evaluate to an `ArcSort`.
///
/// NOTE: the arms are tried in order. The `$t:expr` fallback must stay last,
/// because it also matches the built-in names and would otherwise shadow them.
#[macro_export]
macro_rules! sort {
    (BigInt) => {
        BigIntSort.to_arcsort()
    };
    (BigRat) => {
        BigRatSort.to_arcsort()
    };
    (bool) => {
        BoolSort.to_arcsort()
    };
    (f64) => {
        F64Sort.to_arcsort()
    };
    (i64) => {
        I64Sort.to_arcsort()
    };
    (String) => {
        StringSort.to_arcsort()
    };
    (Unit) => {
        UnitSort.to_arcsort()
    };
    // Fallback: an arbitrary expression that is already an `ArcSort`.
    ($t:expr) => {
        $t
    };
}
100
/// Build a `&[(&str, ArcSort)]` slice of typed variables, e.g.
/// `vars![x: i64, f: i64]`, as expected by `query` and `rust_rule`.
///
/// Each type token is resolved with the `sort!` macro. The inner macro is
/// referenced through `$crate`, so callers do not need `sort` in scope.
#[macro_export]
macro_rules! vars {
    [$($x:ident : $t:tt),* $(,)?] => {
        &[$((stringify!($x), $crate::sort!($t))),*]
    };
}
107
/// Build an `Expr` from s-expression syntax, e.g. `expr!((+ x 1))`.
///
/// * `(unquote e)` splices in the Rust expression `e`, which should evaluate
///   to an `Expr`.
/// * `(f args...)` becomes a call to `f`, with every argument converted
///   recursively.
/// * A literal becomes an integer literal. Only integer literals are
///   supported, because the expansion goes through `exprs::int(i64)`.
/// * Any other single token becomes a variable with that name.
///
/// All helpers and the recursive call are referenced through `$crate`, so the
/// expansion does not depend on what the caller has imported.
#[macro_export]
macro_rules! expr {
    ((unquote $unquoted:expr)) => { $unquoted };
    (($func:tt $($arg:tt)*)) => {
        $crate::prelude::exprs::call(stringify!($func), vec![$($crate::expr!($arg)),*])
    };
    ($value:literal) => { $crate::prelude::exprs::int($value) };
    ($quoted:tt) => { $crate::prelude::exprs::var(stringify!($quoted)) };
}
115
/// Build a single `Fact` from s-expression syntax.
///
/// `(= a b)` becomes an equality fact. Anything else becomes a fact asserting
/// that the expression (built with `expr!`) exists. Names are resolved through
/// `$crate`, so only this macro needs to be in scope.
#[macro_export]
macro_rules! fact {
    ((= $($arg:tt)*)) => {
        $crate::prelude::Fact::Eq($crate::span!(), $($crate::expr!($arg)),*)
    };
    ($a:tt) => { $crate::prelude::Fact::Fact($crate::expr!($a)) };
}
121
/// Build `Facts` from a sequence of s-expressions. Each token tree is turned
/// into one fact with `fact!`, resolved through `$crate`.
#[macro_export]
macro_rules! facts {
    ($($tree:tt)*) => { $crate::prelude::Facts(vec![$($crate::fact!($tree)),*]) };
}
126
/// Build a single `Action` from s-expression syntax.
///
/// Supported forms:
/// * `(let name e)`
/// * `(set (f args...) e)`
/// * `(delete (f args...))`
/// * `(subsume (f args...))`
/// * `(union a b)`
/// * `(panic "message")`
/// * otherwise, a bare expression action
///
/// Every name is resolved through `$crate`. In particular, `Change` is not
/// re-exported by the prelude, so the `delete`/`subsume` arms previously failed
/// to compile for callers that only imported `egglog::prelude::*`.
#[macro_export]
macro_rules! action {
    ((let $name:ident $value:tt)) => {
        $crate::prelude::Action::Let(
            $crate::span!(),
            String::from(stringify!($name)),
            $crate::expr!($value),
        )
    };
    ((set ($f:ident $($x:tt)*) $value:tt)) => {
        $crate::prelude::Action::Set(
            $crate::span!(),
            String::from(stringify!($f)),
            vec![$($crate::expr!($x)),*],
            $crate::expr!($value),
        )
    };
    ((delete ($f:ident $($x:tt)*))) => {
        $crate::prelude::Action::Change(
            $crate::span!(),
            $crate::ast::Change::Delete,
            String::from(stringify!($f)),
            vec![$($crate::expr!($x)),*],
        )
    };
    ((subsume ($f:ident $($x:tt)*))) => {
        $crate::prelude::Action::Change(
            $crate::span!(),
            $crate::ast::Change::Subsume,
            String::from(stringify!($f)),
            vec![$($crate::expr!($x)),*],
        )
    };
    ((union $x:tt $y:tt)) => {
        $crate::prelude::Action::Union($crate::span!(), $crate::expr!($x), $crate::expr!($y))
    };
    ((panic $message:literal)) => {
        $crate::prelude::Action::Panic($crate::span!(), $message.to_owned())
    };
    ($x:tt) => {
        $crate::prelude::Action::Expr($crate::span!(), $crate::expr!($x))
    };
}
151
/// Build `GenericActions` from a sequence of s-expressions. Each token tree is
/// turned into one action with `action!`, resolved through `$crate`.
#[macro_export]
macro_rules! actions {
    ($($tree:tt)*) => { $crate::prelude::GenericActions(vec![$($crate::action!($tree)),*]) };
}
156
157/// Add a rule to the e-graph whose right-hand side is made up of actions.
158/// ```
159/// use egglog::prelude::*;
160///
161/// let mut egraph = EGraph::default();
162/// egraph.parse_and_run_program(
163///     None,
164///     "
165/// (function fib (i64) i64 :no-merge)
166/// (set (fib 0) 0)
167/// (set (fib 1) 1)
168/// (rule (
169///     (= f0 (fib x))
170///     (= f1 (fib (+ x 1)))
171/// ) (
172///     (set (fib (+ x 2)) (+ f0 f1))
173/// ))
174/// (run 10)
175///     ",
176/// )?;
177///
178/// let big_number = 20;
179///
180/// // check that `(fib 20)` is not in the e-graph
181/// let results = query(
182///     &mut egraph,
183///     vars![f: i64],
184///     facts![(= (fib (unquote exprs::int(big_number))) f)],
185/// )?;
186///
187/// assert!(results.iter().next().is_none());
188///
189/// let ruleset = "custom_ruleset";
190/// add_ruleset(&mut egraph, ruleset)?;
191///
192/// // add the rule from `build_test_database` to the egraph
193/// rule(
194///     &mut egraph,
195///     ruleset,
196///     facts![
197///         (= f0 (fib x))
198///         (= f1 (fib (+ x 1)))
199///     ],
200///     actions![
201///         (set (fib (+ x 2)) (+ f0 f1))
202///     ],
203/// )?;
204///
205/// // run that rule 10 times
206/// for _ in 0..10 {
207///     run_ruleset(&mut egraph, ruleset)?;
208/// }
209///
210/// // check that `(fib 20)` is now in the e-graph
211/// let results = query(
212///     &mut egraph,
213///     vars![f: i64],
214///     facts![(= (fib (unquote exprs::int(big_number))) f)],
215/// )?;
216///
217/// let y = egraph.base_to_value::<i64>(6765);
218/// let results: Vec<_> = results.iter().collect();
219/// assert_eq!(results, [[y]]);
220///
221/// # Ok::<(), egglog::Error>(())
222/// ```
223pub fn rule(
224    egraph: &mut EGraph,
225    ruleset: &str,
226    facts: Facts<String, String>,
227    actions: Actions,
228) -> Result<Vec<CommandOutput>, Error> {
229    let mut rule = Rule {
230        span: span!(),
231        head: actions,
232        body: facts.0,
233        name: "".into(),
234        ruleset: ruleset.into(),
235    };
236
237    rule.name = format!("{rule:?}");
238
239    egraph.run_program(vec![Command::Rule { rule }])
240}
241
242/// A wrapper around an `ExecutionState` for rules that are written in Rust.
243/// See the [`rust_rule`] documentation for an example of how to use this.
244pub struct RustRuleContext<'a, 'b> {
245    exec_state: &'a mut ExecutionState<'b>,
246    union_action: egglog_bridge::UnionAction,
247    table_actions: HashMap<String, egglog_bridge::TableAction>,
248    panic_id: ExternalFunctionId,
249}
250
251impl RustRuleContext<'_, '_> {
252    /// Convert from an egglog value to a Rust type.
253    pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
254        self.exec_state.base_values().unwrap::<T>(x)
255    }
256
257    /// Convert from an egglog value to reference of Rust container type.
258    ///
259    /// See [`EGraph::value_to_container`].
260    pub fn value_to_container<T: ContainerValue>(
261        &mut self,
262        x: Value,
263    ) -> Option<impl Deref<Target = T>> {
264        self.exec_state.container_values().get_val::<T>(x)
265    }
266
267    /// Convert from a Rust type to an egglog value.
268    pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
269        self.exec_state.base_values().get::<T>(x)
270    }
271
272    /// Convert from a Rust container type to an egglog value.
273    pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
274        self.exec_state
275            .container_values()
276            .register_val::<T>(x, self.exec_state)
277    }
278
279    fn get_table_action(&self, table: &str) -> egglog_bridge::TableAction {
280        self.table_actions[table].clone()
281    }
282
283    /// Do a table lookup. This is potentially a mutable operation!
284    /// For more information, see `egglog_bridge::TableAction::lookup`.
285    pub fn lookup(&mut self, table: &str, key: &[Value]) -> Option<Value> {
286        self.get_table_action(table).lookup(self.exec_state, key)
287    }
288
289    /// Union two values in the e-graph.
290    /// For more information, see `egglog_bridge::UnionAction::union`.
291    pub fn union(&mut self, x: Value, y: Value) {
292        self.union_action.union(self.exec_state, x, y)
293    }
294
295    /// Insert a row into a table.
296    /// For more information, see `egglog_bridge::TableAction::insert`.
297    pub fn insert(&mut self, table: &str, row: impl Iterator<Item = Value>) {
298        self.get_table_action(table).insert(self.exec_state, row)
299    }
300
301    /// Remove a row from a table.
302    /// For more information, see `egglog_bridge::TableAction::remove`.
303    pub fn remove(&mut self, table: &str, key: &[Value]) {
304        self.get_table_action(table).remove(self.exec_state, key)
305    }
306
307    /// Subsume a row in a table.
308    /// For more information, see `egglog_bridge::TableAction::subsume`.
309    pub fn subsume(&mut self, table: &str, key: &[Value]) {
310        self.get_table_action(table)
311            .subsume(self.exec_state, key.iter().copied())
312    }
313
314    /// Panic.
315    /// You should also return `None` from your callback if you call
316    /// this function, which this function hopefully makes easier by
317    /// always returning `None` so that you can use `?`.
318    pub fn panic(&mut self) -> Option<()> {
319        self.exec_state.call_external_func(self.panic_id, &[]);
320        None
321    }
322}
323
324#[derive(Clone)]
325struct RustRuleRhs<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> {
326    name: String,
327    inputs: Vec<ArcSort>,
328    union_action: egglog_bridge::UnionAction,
329    table_actions: HashMap<String, egglog_bridge::TableAction>,
330    panic_id: ExternalFunctionId,
331    func: F,
332}
333
334impl<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> Primitive for RustRuleRhs<F> {
335    fn name(&self) -> &str {
336        &self.name
337    }
338
339    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
340        let sorts: Vec<_> = self
341            .inputs
342            .iter()
343            .chain(once(&UnitSort.to_arcsort()))
344            .cloned()
345            .collect();
346        SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
347    }
348
349    fn apply(&self, exec_state: &mut ExecutionState, values: &[Value]) -> Option<Value> {
350        let mut context = RustRuleContext {
351            exec_state,
352            union_action: self.union_action,
353            table_actions: self.table_actions.clone(),
354            panic_id: self.panic_id,
355        };
356        (self.func)(&mut context, values)?;
357        Some(exec_state.base_values().get(()))
358    }
359}
360
361/// Add a rule to the e-graph whose right-hand side is a Rust callback.
362/// ```
363/// use egglog::prelude::*;
364///
365/// let mut egraph = EGraph::default();
366/// egraph.parse_and_run_program(
367///     None,
368///     "
369/// (function fib (i64) i64 :no-merge)
370/// (set (fib 0) 0)
371/// (set (fib 1) 1)
372/// (rule (
373///     (= f0 (fib x))
374///     (= f1 (fib (+ x 1)))
375/// ) (
376///     (set (fib (+ x 2)) (+ f0 f1))
377/// ))
378/// (run 10)
379///     ",
380/// )?;
381///
382/// let big_number = 20;
383///
384/// // check that `(fib 20)` is not in the e-graph
385/// let results = query(
386///     &mut egraph,
387///     vars![f: i64],
388///     facts![(= (fib (unquote exprs::int(big_number))) f)],
389/// )?;
390///
391/// assert!(results.iter().next().is_none());
392///
393/// let ruleset = "custom_ruleset";
394/// add_ruleset(&mut egraph, ruleset)?;
395///
396/// // add the rule from `build_test_database` to the egraph
397/// rust_rule(
398///     &mut egraph,
399///     "fib_rule",
400///     ruleset,
401///     vars![x: i64, f0: i64, f1: i64],
402///     facts![
403///         (= f0 (fib x))
404///         (= f1 (fib (+ x 1)))
405///     ],
406///     move |ctx, values| {
407///         let [x, f0, f1] = values else { unreachable!() };
408///         let x = ctx.value_to_base::<i64>(*x);
409///         let f0 = ctx.value_to_base::<i64>(*f0);
410///         let f1 = ctx.value_to_base::<i64>(*f1);
411///
412///         let y = ctx.base_to_value::<i64>(x + 2);
413///         let f2 = ctx.base_to_value::<i64>(f0 + f1);
414///         ctx.insert("fib", [y, f2].into_iter());
415///
416///         Some(())
417///     },
418/// )?;
419///
420/// // run that rule 10 times
421/// for _ in 0..10 {
422///     run_ruleset(&mut egraph, ruleset)?;
423/// }
424///
425/// // check that `(fib 20)` is now in the e-graph
426/// let results = query(
427///     &mut egraph,
428///     vars![f: i64],
429///     facts![(= (fib (unquote exprs::int(big_number))) f)],
430/// )?;
431///
432/// let y = egraph.base_to_value::<i64>(6765);
433/// let results: Vec<_> = results.iter().collect();
434/// assert_eq!(results, [[y]]);
435///
436/// # Ok::<(), egglog::Error>(())
437/// ```
438pub fn rust_rule(
439    egraph: &mut EGraph,
440    rule_name: &str,
441    ruleset: &str,
442    vars: &[(&str, ArcSort)],
443    facts: Facts<String, String>,
444    func: impl Fn(&mut RustRuleContext, &[Value]) -> Option<()> + Clone + Send + Sync + 'static,
445) -> Result<Vec<CommandOutput>, Error> {
446    let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
447    let panic_id = egraph.backend.new_panic(format!("{prim_name}_panic"));
448    egraph.add_primitive(RustRuleRhs {
449        name: prim_name.clone(),
450        inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
451        union_action: egglog_bridge::UnionAction::new(&egraph.backend),
452        table_actions: egraph
453            .functions
454            .iter()
455            .map(|(k, v)| {
456                (
457                    k.clone(),
458                    egglog_bridge::TableAction::new(&egraph.backend, v.backend_id),
459                )
460            })
461            .collect(),
462        panic_id,
463        func,
464    });
465
466    let rule = Rule {
467        span: span!(),
468        head: GenericActions(vec![GenericAction::Expr(
469            span!(),
470            exprs::call(
471                &prim_name,
472                vars.iter().map(|(v, _)| exprs::var(v)).collect(),
473            ),
474        )]),
475        body: facts.0,
476        name: egraph.parser.symbol_gen.fresh(rule_name),
477        ruleset: ruleset.into(),
478    };
479
480    egraph.run_program(vec![Command::Rule { rule }])
481}
482
483/// The result of a query.
484pub struct QueryResult {
485    rows: usize,
486    cols: usize,
487    data: Vec<Value>,
488}
489
490impl QueryResult {
491    /// Get an iterator over the query results,
492    /// where each match is a `&[Value]` in the same order
493    /// as the `vars` that were passed to `query`.
494    pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
495        assert!(self.cols > 0, "no vars; use `any_matches` instead");
496        assert!(self.data.len() % self.cols == 0);
497        self.data.chunks_exact(self.cols)
498    }
499
500    /// Check if any matches were returned at all.
501    pub fn any_matches(&self) -> bool {
502        self.rows > 0
503    }
504}
505
506/// Run a query over the database.
507/// ```
508/// use egglog::prelude::*;
509///
510/// let mut egraph = EGraph::default();
511/// egraph.parse_and_run_program(
512///     None,
513///     "
514/// (function fib (i64) i64 :no-merge)
515/// (set (fib 0) 0)
516/// (set (fib 1) 1)
517/// (rule (
518///     (= f0 (fib x))
519///     (= f1 (fib (+ x 1)))
520/// ) (
521///     (set (fib (+ x 2)) (+ f0 f1))
522/// ))
523/// (run 10)
524///     ",
525/// )?;
526///
527/// let results = query(
528///     &mut egraph,
529///     vars![x: i64, y: i64],
530///     facts![
531///         (= (fib x) y)
532///         (= y 13)
533///     ],
534/// )?;
535///
536/// let x = egraph.base_to_value::<i64>(7);
537/// let y = egraph.base_to_value::<i64>(13);
538/// let results: Vec<_> = results.iter().collect();
539/// assert_eq!(results, [[x, y]]);
540///
541/// # Ok::<(), egglog::Error>(())
542/// ```
543pub fn query(
544    egraph: &mut EGraph,
545    vars: &[(&str, ArcSort)],
546    facts: Facts<String, String>,
547) -> Result<QueryResult, Error> {
548    use std::sync::{Arc, Mutex};
549
550    let results = Arc::new(Mutex::new(QueryResult {
551        rows: 0,
552        cols: vars.len(),
553        data: Vec::new(),
554    }));
555    let results_weak = Arc::downgrade(&results);
556
557    let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
558    add_ruleset(egraph, &ruleset)?;
559
560    rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
561        let arc = results_weak.upgrade().unwrap();
562        let mut results = arc.lock().unwrap();
563        results.rows += 1;
564        results.data.extend(values);
565        Some(())
566    })?;
567
568    run_ruleset(egraph, &ruleset)?;
569
570    let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
571
572    let Ruleset::Rules(rules) = ruleset else {
573        unreachable!()
574    };
575    assert_eq!(rules.len(), 1);
576    let rule = rules.into_iter().next().unwrap().1;
577    egraph.backend.free_rule(rule.1);
578
579    let Some(mutex) = Arc::into_inner(results) else {
580        panic!("results_weak.upgrade() was not dropped");
581    };
582    Ok(mutex.into_inner().unwrap())
583}
584
585/// Declare a new sort.
586pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
587    egraph.run_program(vec![Command::Sort(span!(), name.to_owned(), None)])
588}
589
590/// Declare a new function table.
591pub fn add_function(
592    egraph: &mut EGraph,
593    name: &str,
594    schema: Schema,
595    merge: Option<GenericExpr<String, String>>,
596) -> Result<Vec<CommandOutput>, Error> {
597    egraph.run_program(vec![Command::Function {
598        span: span!(),
599        name: name.to_owned(),
600        schema,
601        merge,
602    }])
603}
604
605/// Declare a new constructor table.
606pub fn add_constructor(
607    egraph: &mut EGraph,
608    name: &str,
609    schema: Schema,
610    cost: Option<DefaultCost>,
611    unextractable: bool,
612) -> Result<Vec<CommandOutput>, Error> {
613    egraph.run_program(vec![Command::Constructor {
614        span: span!(),
615        name: name.to_owned(),
616        schema,
617        cost,
618        unextractable,
619    }])
620}
621
622/// Declare a new relation table.
623pub fn add_relation(
624    egraph: &mut EGraph,
625    name: &str,
626    inputs: Vec<String>,
627) -> Result<Vec<CommandOutput>, Error> {
628    egraph.run_program(vec![Command::Relation {
629        span: span!(),
630        name: name.to_owned(),
631        inputs,
632    }])
633}
634
/// Adds sorts and constructor tables to the database.
///
/// `datatype!(egraph, (datatype S (A) (B S S :cost 10)))` declares the sort
/// `S` plus one constructor per `(Name args... [:cost c])` entry, whose output
/// sort is `S`.
///
/// The expansion uses `?`, so it must appear in a function returning a
/// compatible `Result`. The `$egraph` expression is evaluated once per
/// generated call. All helpers and `Schema` are referenced through `$crate`.
/// `Schema` in particular is not re-exported by the prelude, so callers no
/// longer need extra imports.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {
        $crate::prelude::add_sort($egraph, stringify!($sort))?;
        $(
            $crate::prelude::add_constructor(
                $egraph,
                stringify!($name),
                $crate::ast::Schema {
                    input: vec![$(stringify!($args).to_owned()),*],
                    output: stringify!($sort).to_owned(),
                },
                // `None` when no `:cost` was given, otherwise `Some(cost)`.
                [$($cost)*].first().copied(),
                false,
            )?;
        )*
    };
}
652
653/// A "default" implementation of [`Sort`] for simple types
654/// which just want to put some data in the e-graph. If you
655/// implement this trait, do not implement `Sort` or
656/// `ContainerSort. Use `add_base_sort` to register base
657/// sorts with the `EGraph`. See `Sort` for documentation
658/// of the methods. Do not override `to_arcsort`.
659pub trait BaseSort: Any + Send + Sync + Debug {
660    type Base: BaseValue;
661    fn name(&self) -> &str;
662    fn register_primitives(&self, _eg: &mut EGraph) {}
663    fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> Term;
664
665    fn to_arcsort(self) -> ArcSort
666    where
667        Self: Sized,
668    {
669        Arc::new(BaseSortImpl(self))
670    }
671}
672
673#[derive(Debug)]
674struct BaseSortImpl<T: BaseSort>(T);
675
676impl<T: BaseSort> Sort for BaseSortImpl<T> {
677    fn name(&self) -> &str {
678        self.0.name()
679    }
680
681    fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
682        ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
683    }
684
685    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
686        backend.base_values_mut().register_type::<T::Base>();
687    }
688
689    fn value_type(&self) -> Option<TypeId> {
690        Some(TypeId::of::<T::Base>())
691    }
692
693    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
694        self
695    }
696
697    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
698        self.0.register_primitives(eg)
699    }
700
701    /// Reconstruct a leaf base value in a TermDag
702    fn reconstruct_termdag_base(
703        &self,
704        base_values: &BaseValues,
705        value: Value,
706        termdag: &mut TermDag,
707    ) -> Term {
708        self.0.reconstruct_termdag(base_values, value, termdag)
709    }
710}
711
712/// A "default" implementation of [`Sort`] for types which
713/// just want to store a pure data structure in the e-graph.
714/// If you implement this trait, do not implement `Sort` or
715/// `BaseSort`. Use `add_container_sort` to register container
716/// sorts with the `EGraph`. See `Sort` for documentation
717/// of the methods. Do not override `to_arcsort`.
718pub trait ContainerSort: Any + Send + Sync + Debug {
719    type Container: ContainerValue;
720    fn name(&self) -> &str;
721    fn is_eq_container_sort(&self) -> bool;
722    fn inner_sorts(&self) -> Vec<ArcSort>;
723    fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
724    fn register_primitives(&self, _eg: &mut EGraph) {}
725    fn reconstruct_termdag(
726        &self,
727        _: &ContainerValues,
728        _: Value,
729        _: &mut TermDag,
730        _: Vec<Term>,
731    ) -> Term;
732    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
733
734    fn to_arcsort(self) -> ArcSort
735    where
736        Self: Sized,
737    {
738        Arc::new(ContainerSortImpl(self))
739    }
740}
741
742#[derive(Debug)]
743struct ContainerSortImpl<T: ContainerSort>(T);
744
745impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
746    fn name(&self) -> &str {
747        self.0.name()
748    }
749
750    fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
751        ColumnTy::Id
752    }
753
754    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
755        backend.register_container_ty::<T::Container>();
756    }
757
758    fn value_type(&self) -> Option<TypeId> {
759        Some(TypeId::of::<T::Container>())
760    }
761
762    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
763        self
764    }
765
766    fn inner_sorts(&self) -> Vec<ArcSort> {
767        self.0.inner_sorts()
768    }
769
770    fn inner_values(
771        &self,
772        container_values: &ContainerValues,
773        value: Value,
774    ) -> Vec<(ArcSort, Value)> {
775        self.0.inner_values(container_values, value)
776    }
777
778    fn is_container_sort(&self) -> bool {
779        true
780    }
781
782    fn is_eq_container_sort(&self) -> bool {
783        self.0.is_eq_container_sort()
784    }
785
786    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
787        self.0.serialized_name(container_values, value)
788    }
789
790    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
791        self.0.register_primitives(eg);
792    }
793
794    fn reconstruct_termdag_container(
795        &self,
796        container_values: &ContainerValues,
797        value: Value,
798        termdag: &mut TermDag,
799        element_terms: Vec<Term>,
800    ) -> Term {
801        self.0
802            .reconstruct_termdag(container_values, value, termdag, element_terms)
803    }
804}
805
806/// Add a [`BaseSort`] to the e-graph
807pub fn add_base_sort(
808    egraph: &mut EGraph,
809    base_sort: impl BaseSort,
810    span: Span,
811) -> Result<(), TypeError> {
812    egraph.add_sort(BaseSortImpl(base_sort), span)
813}
814
815pub fn add_container_sort(
816    egraph: &mut EGraph,
817    container_sort: impl ContainerSort,
818    span: Span,
819) -> Result<(), TypeError> {
820    egraph.add_sort(ContainerSortImpl(container_sort), span)
821}
822
// Tests for the Rust-facing prelude API: `query`, `rule`, `rust_rule` and
// the s-expression macros.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an e-graph with a `fib` table seeded with `(fib 0)` and
    /// `(fib 1)`, and runs the fib rule 10 times. That is not enough to
    /// derive `(fib 20)`, which the rule tests below rely on.
    fn build_test_database() -> Result<EGraph, Error> {
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(
            None,
            "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
    (= f0 (fib x))
    (= f1 (fib (+ x 1)))
) (
    (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
        ",
        )?;
        Ok(egraph)
    }

    /// A two-variable query returns exactly one match, `x = 7, y = 13`,
    /// flattened row-major into `data`.
    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let results = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        let x = egraph.backend.base_values().get::<i64>(7);
        let y = egraph.backend.base_values().get::<i64>(13);
        assert_eq!(results.data, [x, y]);

        Ok(())
    }

    /// A rule added through `rule` into a custom ruleset extends `fib` far
    /// enough that `(fib 20)` appears after 10 runs of that ruleset.
    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // check that `(fib 20)` is not in the e-graph
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // add the rule from `build_test_database` to the egraph
        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        // run that rule 10 times
        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // check that `(fib 20)` is now in the e-graph
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }

    /// Exercises every arm of `datatype!`, `facts!` and `actions!`. The rule
    /// is only added and never run, so its `panic` action never fires.
    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rule(
            &mut egraph,
            ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    /// Same scenario as `rust_api_rule`, but the right-hand side is a Rust
    /// callback registered with `rust_rule`.
    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // check that `(fib 20)` is not in the e-graph
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // add the rule from `build_test_database` to the egraph
        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |ctx, values| {
                // `values` arrive in the order of `vars!` above.
                let [x, f0, f1] = values else { unreachable!() };
                let x = ctx.value_to_base::<i64>(*x);
                let f0 = ctx.value_to_base::<i64>(*f0);
                let f1 = ctx.value_to_base::<i64>(*f1);

                let y = ctx.base_to_value::<i64>(x + 2);
                let f2 = ctx.base_to_value::<i64>(f0 + f1);
                ctx.insert("fib", [y, f2].into_iter());

                Some(())
            },
        )?;

        // run that rule 10 times
        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        // check that `(fib 20)` is now in the e-graph
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }
}