1use crate::*;
10use std::any::{Any, TypeId};
11
12pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
14pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
15pub use egglog::{CommandMacro, CommandMacroRegistry};
16pub use egglog::{EGraph, span};
17pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
18
19pub trait LiteralConvertible: Sized {
22 fn to_literal(self) -> egglog_ast::generic_ast::Literal;
23 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self>;
24}
25
26impl LiteralConvertible for i64 {
27 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
28 egglog_ast::generic_ast::Literal::Int(self)
29 }
30 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
31 match lit {
32 egglog_ast::generic_ast::Literal::Int(i) => Some(*i),
33 _ => None,
34 }
35 }
36}
37
38impl LiteralConvertible for bool {
39 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
40 egglog_ast::generic_ast::Literal::Bool(self)
41 }
42 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
43 match lit {
44 egglog_ast::generic_ast::Literal::Bool(b) => Some(*b),
45 _ => None,
46 }
47 }
48}
49
50impl LiteralConvertible for ordered_float::OrderedFloat<f64> {
51 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
52 egglog_ast::generic_ast::Literal::Float(self)
53 }
54 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
55 match lit {
56 egglog_ast::generic_ast::Literal::Float(f) => Some(*f),
57 _ => None,
58 }
59 }
60}
61
62impl LiteralConvertible for egglog::sort::F {
63 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
64 egglog_ast::generic_ast::Literal::Float(self.0)
65 }
66 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
67 match lit {
68 egglog_ast::generic_ast::Literal::Float(f) => Some(egglog::sort::F::from(*f)),
69 _ => None,
70 }
71 }
72}
73
74impl LiteralConvertible for egglog::sort::S {
75 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
76 egglog_ast::generic_ast::Literal::String(self.0)
77 }
78 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
79 match lit {
80 egglog_ast::generic_ast::Literal::String(s) => Some(egglog::sort::S::new(s.clone())),
81 _ => None,
82 }
83 }
84}
85
86impl LiteralConvertible for () {
87 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
88 egglog_ast::generic_ast::Literal::Unit
89 }
90 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
91 match lit {
92 egglog_ast::generic_ast::Literal::Unit => Some(()),
93 _ => None,
94 }
95 }
96}
97
98pub mod exprs {
99 use super::*;
100
101 pub fn var(name: &str) -> Expr {
103 Expr::Var(span!(), name.to_owned())
104 }
105
106 pub fn int(value: i64) -> Expr {
108 Expr::Lit(span!(), Literal::Int(value))
109 }
110
111 pub fn float(value: f64) -> Expr {
113 Expr::Lit(span!(), Literal::Float(value.into()))
114 }
115
116 pub fn string(value: &str) -> Expr {
118 Expr::Lit(span!(), Literal::String(value.to_owned()))
119 }
120
121 pub fn unit() -> Expr {
123 Expr::Lit(span!(), Literal::Unit)
124 }
125
126 pub fn bool(value: bool) -> Expr {
128 Expr::Lit(span!(), Literal::Bool(value))
129 }
130
131 pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
133 Expr::Call(span!(), f.to_owned(), xs)
134 }
135}
136
137pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
139 egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
140}
141
142pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
144 egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
145 span!(),
146 RunConfig {
147 ruleset: ruleset.to_owned(),
148 until: None,
149 },
150 ))])
151}
152
/// Turns a short sort name into an [`ArcSort`].
///
/// Each of the built-in names `BigInt`, `BigRat`, `bool`, `f64`, `i64`, `String` and `Unit`
/// expands to `<X>Sort.to_arcsort()`. Any other single expression is passed through
/// unchanged, so it should already evaluate to an `ArcSort`.
///
/// The expansion names the sort types and `to_arcsort` without paths. Those names, and the
/// trait that provides `to_arcsort`, must therefore be in scope where the macro is used.
#[macro_export]
macro_rules! sort {
    (BigInt) => {
        BigIntSort.to_arcsort()
    };
    (BigRat) => {
        BigRatSort.to_arcsort()
    };
    (bool) => {
        BoolSort.to_arcsort()
    };
    (f64) => {
        F64Sort.to_arcsort()
    };
    (i64) => {
        I64Sort.to_arcsort()
    };
    (String) => {
        StringSort.to_arcsort()
    };
    (Unit) => {
        UnitSort.to_arcsort()
    };
    // Fallback: a caller-supplied sort expression. It must come last so the names above win.
    ($t:expr) => {
        $t
    };
}
180
/// Builds the `&[(&str, ArcSort)]` variable list taken by [`query`] and [`rust_rule`].
///
/// Syntax: `vars![x: i64, y: MySortExpr]`. Each sort is resolved with `sort!`. Because it is
/// captured as a single token tree, a non-trivial sort expression must be wrapped in
/// parentheses.
#[macro_export]
macro_rules! vars {
    [$($x:ident : $t:tt),* $(,)?] => {
        &[$((stringify!($x), sort!($t))),*]
    };
}
187
/// Builds an egglog [`Expr`] from s-expression syntax.
///
/// - `(unquote e)` splices in the Rust expression `e`, which must evaluate to an `Expr`.
/// - `(f a b ...)` becomes a call of `f`; each argument is converted recursively.
/// - A Rust literal is passed to `exprs::int`, so only integer literals type-check here.
/// - Any other single token becomes a variable named by its text.
///
/// Each argument must be a single token tree. A negative number such as `-1` is two tokens,
/// so write `(unquote exprs::int(-1))` instead. The expansion refers to `exprs::...` by
/// relative path, so `exprs` must be in scope at the call site.
#[macro_export]
macro_rules! expr {
    ((unquote $unquoted:expr)) => { $unquoted };
    (($func:tt $($arg:tt)*)) => { exprs::call(stringify!($func), vec![$(expr!($arg)),*]) };
    ($value:literal) => { exprs::int($value) };
    ($quoted:tt) => { exprs::var(stringify!($quoted)) };
}
195
/// Builds a single [`Fact`].
///
/// `(= a b)` becomes a `Fact::Eq` over the `expr!`-converted arguments. Anything else becomes
/// a `Fact::Fact` built from `expr!`.
#[macro_export]
macro_rules! fact {
    ((= $($arg:tt)*)) => { Fact::Eq(span!(), $(expr!($arg)),*) };
    ($a:tt) => { Fact::Fact(expr!($a)) };
}
201
/// Builds [`Facts`] from a whitespace-separated sequence of `fact!` forms.
#[macro_export]
macro_rules! facts {
    ($($tree:tt)*) => { Facts(vec![$(fact!($tree)),*]) };
}
206
/// Builds a single [`Action`] from s-expression syntax.
///
/// Supported forms:
/// - `(let name e)`: binds `name` to `e`.
/// - `(set (f args...) e)`: sets `f(args...)` to `e`.
/// - `(delete (f args...))` and `(subsume (f args...))`: `Action::Change` on that entry.
/// - `(union a b)`: unions two expressions.
/// - `(panic "msg")`: panics with a string-literal message.
/// - Anything else: an expression action built with `expr!`.
///
/// Arm order matters. The keyword forms must precede the final catch-all.
#[macro_export]
macro_rules! action {
    ((let $name:ident $value:tt)) => {
        Action::Let(span!(), String::from(stringify!($name)), expr!($value))
    };
    ((set ($f:ident $($x:tt)*) $value:tt)) => {
        Action::Set(span!(), String::from(stringify!($f)), vec![$(expr!($x)),*], expr!($value))
    };
    ((delete ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Delete, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((subsume ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Subsume, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((union $x:tt $y:tt)) => {
        Action::Union(span!(), expr!($x), expr!($y))
    };
    ((panic $message:literal)) => {
        Action::Panic(span!(), $message.to_owned())
    };
    ($x:tt) => {
        Action::Expr(span!(), expr!($x))
    };
}
231
/// Builds [`GenericActions`] from a whitespace-separated sequence of `action!` forms.
#[macro_export]
macro_rules! actions {
    ($($tree:tt)*) => { GenericActions(vec![$(action!($tree)),*]) };
}
236
237pub fn rule(
304 egraph: &mut EGraph,
305 ruleset: &str,
306 facts: Facts<String, String>,
307 actions: Actions,
308) -> Result<Vec<CommandOutput>, Error> {
309 let mut rule = Rule {
310 span: span!(),
311 head: actions,
312 body: facts.0,
313 name: "".into(),
314 ruleset: ruleset.into(),
315 };
316
317 rule.name = format!("{rule:?}");
318
319 egraph.run_program(vec![Command::Rule { rule }])
320}
321
322pub struct RustRuleContext<'a, 'b> {
325 exec_state: &'a mut ExecutionState<'b>,
326 union_action: egglog_bridge::UnionAction,
327 table_actions: HashMap<String, egglog_bridge::TableAction>,
328 panic_id: ExternalFunctionId,
329}
330
331impl RustRuleContext<'_, '_> {
332 pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
334 self.exec_state.base_values().unwrap::<T>(x)
335 }
336
337 pub fn value_to_container<T: ContainerValue>(
341 &mut self,
342 x: Value,
343 ) -> Option<impl Deref<Target = T>> {
344 self.exec_state.container_values().get_val::<T>(x)
345 }
346
347 pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
349 self.exec_state.base_values().get::<T>(x)
350 }
351
352 pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
354 self.exec_state
355 .container_values()
356 .register_val::<T>(x, self.exec_state)
357 }
358
359 fn get_table_action(&self, table: &str) -> egglog_bridge::TableAction {
360 self.table_actions[table].clone()
361 }
362
363 pub fn lookup(&mut self, table: &str, key: &[Value]) -> Option<Value> {
366 self.get_table_action(table).lookup(self.exec_state, key)
367 }
368
369 pub fn union(&mut self, x: Value, y: Value) {
372 self.union_action.union(self.exec_state, x, y)
373 }
374
375 pub fn insert(&mut self, table: &str, row: impl Iterator<Item = Value>) {
378 self.get_table_action(table).insert(self.exec_state, row)
379 }
380
381 pub fn remove(&mut self, table: &str, key: &[Value]) {
384 self.get_table_action(table).remove(self.exec_state, key)
385 }
386
387 pub fn subsume(&mut self, table: &str, key: &[Value]) {
390 self.get_table_action(table)
391 .subsume(self.exec_state, key.iter().copied())
392 }
393
394 pub fn panic(&mut self) -> Option<()> {
399 self.exec_state.call_external_func(self.panic_id, &[]);
400 None
401 }
402}
403
404#[derive(Clone)]
405struct RustRuleRhs<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> {
406 name: String,
407 inputs: Vec<ArcSort>,
408 union_action: egglog_bridge::UnionAction,
409 table_actions: HashMap<String, egglog_bridge::TableAction>,
410 panic_id: ExternalFunctionId,
411 func: F,
412}
413
414impl<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> Primitive for RustRuleRhs<F> {
415 fn name(&self) -> &str {
416 &self.name
417 }
418
419 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
420 let sorts: Vec<_> = self
421 .inputs
422 .iter()
423 .chain(once(&UnitSort.to_arcsort()))
424 .cloned()
425 .collect();
426 SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
427 }
428
429 fn apply(&self, exec_state: &mut ExecutionState, values: &[Value]) -> Option<Value> {
430 let mut context = RustRuleContext {
431 exec_state,
432 union_action: self.union_action,
433 table_actions: self.table_actions.clone(),
434 panic_id: self.panic_id,
435 };
436 (self.func)(&mut context, values)?;
437 Some(exec_state.base_values().get(()))
438 }
439}
440
441pub fn rust_rule(
519 egraph: &mut EGraph,
520 rule_name: &str,
521 ruleset: &str,
522 vars: &[(&str, ArcSort)],
523 facts: Facts<String, String>,
524 func: impl Fn(&mut RustRuleContext, &[Value]) -> Option<()> + Clone + Send + Sync + 'static,
525) -> Result<Vec<CommandOutput>, Error> {
526 let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
527 let panic_id = egraph.backend.new_panic(format!("{prim_name}_panic"));
528 egraph.add_primitive(RustRuleRhs {
529 name: prim_name.clone(),
530 inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
531 union_action: egglog_bridge::UnionAction::new(&egraph.backend),
532 table_actions: egraph
533 .functions
534 .iter()
535 .map(|(k, v)| {
536 (
537 k.clone(),
538 egglog_bridge::TableAction::new(&egraph.backend, v.backend_id),
539 )
540 })
541 .collect(),
542 panic_id,
543 func,
544 });
545
546 let rule = Rule {
547 span: span!(),
548 head: GenericActions(vec![GenericAction::Expr(
549 span!(),
550 exprs::call(
551 &prim_name,
552 vars.iter().map(|(v, _)| exprs::var(v)).collect(),
553 ),
554 )]),
555 body: facts.0,
556 name: egraph.parser.symbol_gen.fresh(rule_name),
557 ruleset: ruleset.into(),
558 };
559
560 egraph.run_program(vec![Command::Rule { rule }])
561}
562
563pub struct QueryResult {
565 rows: usize,
566 cols: usize,
567 data: Vec<Value>,
568}
569
570impl QueryResult {
571 pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
575 assert!(self.cols > 0, "no vars; use `any_matches` instead");
576 assert!(self.data.len() % self.cols == 0);
577 self.data.chunks_exact(self.cols)
578 }
579
580 pub fn any_matches(&self) -> bool {
582 self.rows > 0
583 }
584}
585
586pub fn query(
624 egraph: &mut EGraph,
625 vars: &[(&str, ArcSort)],
626 facts: Facts<String, String>,
627) -> Result<QueryResult, Error> {
628 use std::sync::{Arc, Mutex};
629
630 let results = Arc::new(Mutex::new(QueryResult {
631 rows: 0,
632 cols: vars.len(),
633 data: Vec::new(),
634 }));
635 let results_weak = Arc::downgrade(&results);
636
637 let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
638 add_ruleset(egraph, &ruleset)?;
639
640 rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
641 let arc = results_weak.upgrade().unwrap();
642 let mut results = arc.lock().unwrap();
643 results.rows += 1;
644 results.data.extend(values);
645 Some(())
646 })?;
647
648 run_ruleset(egraph, &ruleset)?;
649
650 let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
651
652 let Ruleset::Rules(rules) = ruleset else {
653 unreachable!()
654 };
655 assert_eq!(rules.len(), 1);
656 let rule = rules.into_iter().next().unwrap().1;
657 egraph.backend.free_rule(rule.1);
658
659 let Some(mutex) = Arc::into_inner(results) else {
660 panic!("results_weak.upgrade() was not dropped");
661 };
662 Ok(mutex.into_inner().unwrap())
663}
664
665pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
667 egraph.run_program(vec![Command::Sort(span!(), name.to_owned(), None)])
668}
669
670pub fn add_function(
672 egraph: &mut EGraph,
673 name: &str,
674 schema: Schema,
675 merge: Option<GenericExpr<String, String>>,
676) -> Result<Vec<CommandOutput>, Error> {
677 egraph.run_program(vec![Command::Function {
678 span: span!(),
679 name: name.to_owned(),
680 schema,
681 merge,
682 }])
683}
684
685pub fn add_constructor(
687 egraph: &mut EGraph,
688 name: &str,
689 schema: Schema,
690 cost: Option<DefaultCost>,
691 unextractable: bool,
692) -> Result<Vec<CommandOutput>, Error> {
693 egraph.run_program(vec![Command::Constructor {
694 span: span!(),
695 name: name.to_owned(),
696 schema,
697 cost,
698 unextractable,
699 }])
700}
701
702pub fn add_relation(
704 egraph: &mut EGraph,
705 name: &str,
706 inputs: Vec<String>,
707) -> Result<Vec<CommandOutput>, Error> {
708 egraph.run_program(vec![Command::Relation {
709 span: span!(),
710 name: name.to_owned(),
711 inputs,
712 }])
713}
714
/// Declares an egglog datatype: a sort plus one constructor per variant.
///
/// `datatype!(egraph, (datatype Name (Variant Arg1 Arg2 :cost N) ...))` calls [`add_sort`]
/// for `Name`, then [`add_constructor`] for each variant. Each constructor takes the listed
/// argument sort names as inputs and outputs `Name`.
///
/// `:cost N` is optional; without it the constructor gets cost `None`. Every constructor is
/// declared extractable.
///
/// The expansion applies `?` to each call, so the macro must be used inside a function
/// whose error type can absorb the `Error` these helpers return.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {
        add_sort($egraph, stringify!($sort))?;
        $(add_constructor(
            $egraph,
            stringify!($name),
            Schema {
                input: vec![$(stringify!($args).to_owned()),*],
                output: stringify!($sort).to_owned(),
            },
            // `[N].first()` is `Some(N)`; `[].first()` is `None` when no `:cost` is given.
            [$($cost)*].first().copied(),
            false,
        )?;)*
    };
}
731}
732
733pub trait BaseSort: Any + Send + Sync + Debug {
740 type Base: BaseValue;
741 fn name(&self) -> &str;
742 fn register_primitives(&self, _eg: &mut EGraph) {}
743 fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> TermId;
744
745 fn to_arcsort(self) -> ArcSort
746 where
747 Self: Sized,
748 {
749 Arc::new(BaseSortImpl(self))
750 }
751}
752
/// Adapter that implements [`Sort`] for any [`BaseSort`].
///
/// The `T: BaseSort` bound is stated on the impls that need it rather than on the struct
/// definition, following the Rust API guidelines (C-STRUCT-BOUNDS).
#[derive(Debug)]
struct BaseSortImpl<T>(T);
755
756impl<T: BaseSort> Sort for BaseSortImpl<T> {
757 fn name(&self) -> &str {
758 self.0.name()
759 }
760
761 fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
762 ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
763 }
764
765 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
766 backend.base_values_mut().register_type::<T::Base>();
767 }
768
769 fn value_type(&self) -> Option<TypeId> {
770 Some(TypeId::of::<T::Base>())
771 }
772
773 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
774 self
775 }
776
777 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
778 self.0.register_primitives(eg)
779 }
780
781 fn reconstruct_termdag_base(
783 &self,
784 base_values: &BaseValues,
785 value: Value,
786 termdag: &mut TermDag,
787 ) -> TermId {
788 self.0.reconstruct_termdag(base_values, value, termdag)
789 }
790}
791
792pub trait ContainerSort: Any + Send + Sync + Debug {
799 type Container: ContainerValue;
800 fn name(&self) -> &str;
801 fn is_eq_container_sort(&self) -> bool;
802 fn inner_sorts(&self) -> Vec<ArcSort>;
803 fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
804 fn register_primitives(&self, _eg: &mut EGraph) {}
805 fn reconstruct_termdag(
806 &self,
807 _: &ContainerValues,
808 _: Value,
809 _: &mut TermDag,
810 _: Vec<TermId>,
811 ) -> TermId;
812 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
813
814 fn to_arcsort(self) -> ArcSort
815 where
816 Self: Sized,
817 {
818 Arc::new(ContainerSortImpl(self))
819 }
820}
821
/// Adapter that implements [`Sort`] for any [`ContainerSort`].
///
/// The `T: ContainerSort` bound is stated on the impls that need it rather than on the
/// struct definition, following the Rust API guidelines (C-STRUCT-BOUNDS).
#[derive(Debug)]
struct ContainerSortImpl<T>(T);
824
825impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
826 fn name(&self) -> &str {
827 self.0.name()
828 }
829
830 fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
831 ColumnTy::Id
832 }
833
834 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
835 backend.register_container_ty::<T::Container>();
836 }
837
838 fn value_type(&self) -> Option<TypeId> {
839 Some(TypeId::of::<T::Container>())
840 }
841
842 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
843 self
844 }
845
846 fn inner_sorts(&self) -> Vec<ArcSort> {
847 self.0.inner_sorts()
848 }
849
850 fn inner_values(
851 &self,
852 container_values: &ContainerValues,
853 value: Value,
854 ) -> Vec<(ArcSort, Value)> {
855 self.0.inner_values(container_values, value)
856 }
857
858 fn is_container_sort(&self) -> bool {
859 true
860 }
861
862 fn is_eq_container_sort(&self) -> bool {
863 self.0.is_eq_container_sort()
864 }
865
866 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
867 self.0.serialized_name(container_values, value)
868 }
869
870 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
871 self.0.register_primitives(eg);
872 }
873
874 fn reconstruct_termdag_container(
875 &self,
876 container_values: &ContainerValues,
877 value: Value,
878 termdag: &mut TermDag,
879 element_terms: Vec<TermId>,
880 ) -> TermId {
881 self.0
882 .reconstruct_termdag(container_values, value, termdag, element_terms)
883 }
884}
885
886pub fn add_base_sort(
888 egraph: &mut EGraph,
889 base_sort: impl BaseSort,
890 span: Span,
891) -> Result<(), TypeError> {
892 egraph.add_sort(BaseSortImpl(base_sort), span)
893}
894
895pub fn add_container_sort(
896 egraph: &mut EGraph,
897 container_sort: impl ContainerSort,
898 span: Span,
899) -> Result<(), TypeError> {
900 egraph.add_sort(ContainerSortImpl(container_sort), span)
901}
902
// Tests for the Rust-facing API: `query`, `rule`, `rust_rule`, and the s-expression macros.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an e-graph containing `fib` with `fib(0) = 0`, `fib(1) = 1`, and the rule
    /// `fib(x + 2) = fib(x) + fib(x + 1)`, run for 10 iterations.
    fn build_test_database() -> Result<EGraph, Error> {
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(
            None,
            "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
 (= f0 (fib x))
 (= f1 (fib (+ x 1)))
) (
 (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
 ",
        )?;
        Ok(egraph)
    }

    /// `query` over two variables returns the single match `(x, y) = (7, 13)`.
    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let results = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        // Rows are flattened into `data`, one value per variable in `vars!` order.
        let x = egraph.backend.base_values().get::<i64>(7);
        let y = egraph.backend.base_values().get::<i64>(13);
        assert_eq!(results.data, [x, y]);

        Ok(())
    }

    /// A rule added with `rule` (a copy of the database's fib rule) derives `fib(20)`
    /// after running its ruleset 10 more times.
    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // `unquote` splices a Rust-built expression into the fact.
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        // Not derived yet by the database's initial 10 iterations.
        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        // One `run_ruleset` call per iteration.
        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        // fib(20) = 6765.
        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }

    /// Exercises every `datatype!`, `facts!` and `actions!` form.
    ///
    /// It only checks that the macros expand and the rule is accepted; the ruleset is never
    /// run.
    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rule(
            &mut egraph,
            ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    /// Same scenario as `rust_api_rule`, but the rule's right-hand side is a Rust closure
    /// that computes the next Fibonacci number and inserts it through the context.
    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // `unquote` splices a Rust-built expression into the fact.
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        // Not derived yet by the database's initial 10 iterations.
        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // `values` arrives in the same order as the `vars!` list: x, f0, f1.
        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |ctx, values| {
                let [x, f0, f1] = values else { unreachable!() };
                let x = ctx.value_to_base::<i64>(*x);
                let f0 = ctx.value_to_base::<i64>(*f0);
                let f1 = ctx.value_to_base::<i64>(*f1);

                // Row layout: key column(s) followed by the output.
                let y = ctx.base_to_value::<i64>(x + 2);
                let f2 = ctx.base_to_value::<i64>(f0 + f1);
                ctx.insert("fib", [y, f2].into_iter());

                Some(())
            },
        )?;

        // One `run_ruleset` call per iteration.
        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        // fib(20) = 6765.
        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }
}