1use crate::*;
21use std::any::{Any, TypeId};
22
23pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
25pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
26pub use egglog::{CommandMacro, CommandMacroRegistry};
27pub use egglog::{Core, FullState, PureState, Read, ReadState, Write, WriteState};
28pub use egglog::{EGraph, span};
29pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
30
31pub trait LiteralConvertible: Sized {
34 fn to_literal(self) -> egglog_ast::generic_ast::Literal;
35 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self>;
36}
37
38impl LiteralConvertible for i64 {
39 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
40 egglog_ast::generic_ast::Literal::Int(self)
41 }
42 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
43 match lit {
44 egglog_ast::generic_ast::Literal::Int(i) => Some(*i),
45 _ => None,
46 }
47 }
48}
49
50impl LiteralConvertible for bool {
51 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
52 egglog_ast::generic_ast::Literal::Bool(self)
53 }
54 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
55 match lit {
56 egglog_ast::generic_ast::Literal::Bool(b) => Some(*b),
57 _ => None,
58 }
59 }
60}
61
62impl LiteralConvertible for ordered_float::OrderedFloat<f64> {
63 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
64 egglog_ast::generic_ast::Literal::Float(self)
65 }
66 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
67 match lit {
68 egglog_ast::generic_ast::Literal::Float(f) => Some(*f),
69 _ => None,
70 }
71 }
72}
73
74impl LiteralConvertible for egglog::sort::F {
75 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
76 egglog_ast::generic_ast::Literal::Float(self.0)
77 }
78 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
79 match lit {
80 egglog_ast::generic_ast::Literal::Float(f) => Some(egglog::sort::F::from(*f)),
81 _ => None,
82 }
83 }
84}
85
86impl LiteralConvertible for egglog::sort::S {
87 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
88 egglog_ast::generic_ast::Literal::String(self.0)
89 }
90 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
91 match lit {
92 egglog_ast::generic_ast::Literal::String(s) => Some(egglog::sort::S::new(s.clone())),
93 _ => None,
94 }
95 }
96}
97
98impl LiteralConvertible for () {
99 fn to_literal(self) -> egglog_ast::generic_ast::Literal {
100 egglog_ast::generic_ast::Literal::Unit
101 }
102 fn from_literal(lit: &egglog_ast::generic_ast::Literal) -> Option<Self> {
103 match lit {
104 egglog_ast::generic_ast::Literal::Unit => Some(()),
105 _ => None,
106 }
107 }
108}
109
110pub mod exprs {
111 use super::*;
112
113 pub fn var(name: &str) -> Expr {
115 Expr::Var(span!(), name.to_owned())
116 }
117
118 pub fn int(value: i64) -> Expr {
120 Expr::Lit(span!(), Literal::Int(value))
121 }
122
123 pub fn float(value: f64) -> Expr {
125 Expr::Lit(span!(), Literal::Float(value.into()))
126 }
127
128 pub fn string(value: &str) -> Expr {
130 Expr::Lit(span!(), Literal::String(value.to_owned()))
131 }
132
133 pub fn unit() -> Expr {
135 Expr::Lit(span!(), Literal::Unit)
136 }
137
138 pub fn bool(value: bool) -> Expr {
140 Expr::Lit(span!(), Literal::Bool(value))
141 }
142
143 pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
145 Expr::Call(span!(), f.to_owned(), xs)
146 }
147}
148
149pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
151 egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
152}
153
154pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
156 egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
157 span!(),
158 RunConfig {
159 ruleset: ruleset.to_owned(),
160 until: None,
161 },
162 ))])
163}
164
/// Builds an `ArcSort` from a short sort name.
///
/// The names `BigInt`, `BigRat`, `bool`, `f64`, `i64`, `String` and `Unit`
/// expand to the matching built-in sort's `to_arcsort()`. Any other input is
/// treated as an expression and returned unchanged; presumably it must already
/// evaluate to an `ArcSort`.
///
/// The expansion names the sort types unqualified, so they must be in scope at
/// the call site.
#[macro_export]
macro_rules! sort {
    (BigInt) => {
        BigIntSort.to_arcsort()
    };
    (BigRat) => {
        BigRatSort.to_arcsort()
    };
    (bool) => {
        BoolSort.to_arcsort()
    };
    (f64) => {
        F64Sort.to_arcsort()
    };
    (i64) => {
        I64Sort.to_arcsort()
    };
    (String) => {
        StringSort.to_arcsort()
    };
    (Unit) => {
        UnitSort.to_arcsort()
    };
    // Fallback: any other expression is passed through as-is. This arm must
    // stay last, or it would shadow the built-in names above.
    ($t:expr) => {
        $t
    };
}
192
/// Builds a borrowed slice of `(name, sort)` pairs, e.g. `vars![x: i64, y: i64]`.
///
/// Each name is stringified and each sort token is passed through `sort!`.
/// A sort must be a single token tree, so wrap a non-trivial sort expression in
/// parentheses. A trailing comma is allowed.
#[macro_export]
macro_rules! vars {
    [$($x:ident : $t:tt),* $(,)?] => {
        &[$((stringify!($x), sort!($t))),*]
    };
}
199
/// Builds an egglog `Expr` from s-expression syntax, e.g. `expr!((+ x 1))`.
///
/// - `(unquote e)` splices in the Rust expression `e` unchanged.
/// - `(f args...)` becomes `exprs::call("f", ...)`, converting each argument
///   recursively.
/// - A Rust literal becomes `exprs::int(literal)`. Since that takes an `i64`,
///   non-integer literals do not compile here; splice them with `unquote`.
/// - Any other single token becomes a variable named by that token.
///
/// The expansion refers to `exprs` unqualified, so it must be in scope at the
/// call site.
#[macro_export]
macro_rules! expr {
    ((unquote $unquoted:expr)) => { $unquoted };
    (($func:tt $($arg:tt)*)) => { exprs::call(stringify!($func), vec![$(expr!($arg)),*]) };
    ($value:literal) => { exprs::int($value) };
    ($quoted:tt) => { exprs::var(stringify!($quoted)) };
}
207
/// Builds a single `Fact` from s-expression syntax.
///
/// `(= a b ...)` becomes `Fact::Eq`, with every operand converted by `expr!`.
/// Anything else becomes `Fact::Fact` around the `expr!` conversion.
#[macro_export]
macro_rules! fact {
    ((= $($arg:tt)*)) => { Fact::Eq(span!(), $(expr!($arg)),*) };
    ($a:tt) => { Fact::Fact(expr!($a)) };
}
213
/// Builds `Facts` from a whitespace-separated sequence of facts, each converted
/// with `fact!`, e.g. `facts![(= f0 (fib x)) (= f1 (fib (+ x 1)))]`.
#[macro_export]
macro_rules! facts {
    ($($tree:tt)*) => { Facts(vec![$(fact!($tree)),*]) };
}
218
/// Builds a single `Action` from s-expression syntax.
///
/// Recognized forms:
/// - `(let name e)`: `Action::Let`
/// - `(set (f args...) e)`: `Action::Set`
/// - `(delete (f args...))`: `Action::Change` with `Change::Delete`
/// - `(subsume (f args...))`: `Action::Change` with `Change::Subsume`
/// - `(union a b)`: `Action::Union`
/// - `(panic "msg")`: `Action::Panic` (the message must be a literal)
/// - anything else: `Action::Expr` around `expr!`
///
/// In `set`/`delete`/`subsume` the function name must be a Rust identifier.
#[macro_export]
macro_rules! action {
    ((let $name:ident $value:tt)) => {
        Action::Let(span!(), String::from(stringify!($name)), expr!($value))
    };
    ((set ($f:ident $($x:tt)*) $value:tt)) => {
        Action::Set(span!(), String::from(stringify!($f)), vec![$(expr!($x)),*], expr!($value))
    };
    ((delete ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Delete, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((subsume ($f:ident $($x:tt)*))) => {
        Action::Change(span!(), Change::Subsume, String::from(stringify!($f)), vec![$(expr!($x)),*])
    };
    ((union $x:tt $y:tt)) => {
        Action::Union(span!(), expr!($x), expr!($y))
    };
    ((panic $message:literal)) => {
        Action::Panic(span!(), $message.to_owned())
    };
    // Fallback: must stay last, otherwise the keyword forms above would be
    // parsed as ordinary function calls.
    ($x:tt) => {
        Action::Expr(span!(), expr!($x))
    };
}
243
/// Builds `GenericActions` from a whitespace-separated sequence of actions,
/// each converted with `action!`.
#[macro_export]
macro_rules! actions {
    ($($tree:tt)*) => { GenericActions(vec![$(action!($tree)),*]) };
}
248
249pub fn rule(
316 egraph: &mut EGraph,
317 ruleset: &str,
318 facts: Facts<String, String>,
319 actions: Actions,
320) -> Result<Vec<CommandOutput>, Error> {
321 let rule = Rule {
322 span: span!(),
323 head: actions,
324 body: facts.0,
325 name: "".into(),
326 ruleset: ruleset.into(),
327 naive: false,
328 no_decomp: false,
329 };
330
331 egraph.run_program(vec![Command::Rule { rule }])
332}
333
334#[derive(Clone)]
335struct RustRuleRhs<F>
336where
337 F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
338 + Clone
339 + Send
340 + Sync
341 + 'static,
342{
343 name: String,
344 inputs: Vec<ArcSort>,
345 func: F,
346}
347
348impl<F> Primitive for RustRuleRhs<F>
349where
350 F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
351 + Clone
352 + Send
353 + Sync
354 + 'static,
355{
356 fn name(&self) -> &str {
357 &self.name
358 }
359
360 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
361 let sorts: Vec<_> = self
362 .inputs
363 .iter()
364 .chain(once(&UnitSort.to_arcsort()))
365 .cloned()
366 .collect();
367 SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
368 }
369}
370
371impl<F> WritePrim for RustRuleRhs<F>
372where
373 F: for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
374 + Clone
375 + Send
376 + Sync
377 + 'static,
378{
379 fn apply<'a, 'db>(&self, state: crate::WriteState<'a, 'db>, values: &[Value]) -> Option<Value> {
380 let unit = state.base_values().get(());
381 (self.func)(state, values)?;
382 Some(unit)
383 }
384}
385
386pub fn rust_rule(
464 egraph: &mut EGraph,
465 rule_name: &str,
466 ruleset: &str,
467 vars: &[(&str, ArcSort)],
468 facts: Facts<String, String>,
469 func: impl for<'a, 'db> Fn(crate::WriteState<'a, 'db>, &[Value]) -> Option<()>
470 + Clone
471 + Send
472 + Sync
473 + 'static,
474) -> Result<Vec<CommandOutput>, Error> {
475 let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
476 egraph.add_write_primitive(
477 RustRuleRhs {
478 name: prim_name.clone(),
479 inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
480 func,
481 },
482 None,
483 );
484
485 let rule = Rule {
486 span: span!(),
487 head: GenericActions(vec![GenericAction::Expr(
488 span!(),
489 exprs::call(
490 &prim_name,
491 vars.iter().map(|(v, _)| exprs::var(v)).collect(),
492 ),
493 )]),
494 body: facts.0,
495 name: egraph.parser.symbol_gen.fresh(rule_name),
496 ruleset: ruleset.into(),
497 naive: false,
498 no_decomp: false,
499 };
500
501 egraph.run_program(vec![Command::Rule { rule }])
502}
503
504#[derive(Clone)]
505struct RustRuleFullRhs<F>
506where
507 F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
508 + Clone
509 + Send
510 + Sync
511 + 'static,
512{
513 name: String,
514 inputs: Vec<ArcSort>,
515 func: F,
516}
517
518impl<F> Primitive for RustRuleFullRhs<F>
519where
520 F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
521 + Clone
522 + Send
523 + Sync
524 + 'static,
525{
526 fn name(&self) -> &str {
527 &self.name
528 }
529
530 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
531 let sorts: Vec<_> = self
532 .inputs
533 .iter()
534 .chain(once(&UnitSort.to_arcsort()))
535 .cloned()
536 .collect();
537 SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
538 }
539}
540
541impl<F> crate::FullPrim for RustRuleFullRhs<F>
542where
543 F: for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
544 + Clone
545 + Send
546 + Sync
547 + 'static,
548{
549 fn apply<'a, 'db>(&self, state: crate::FullState<'a, 'db>, values: &[Value]) -> Option<Value> {
550 let unit = state.base_values().get(());
551 (self.func)(state, values)?;
552 Some(unit)
553 }
554}
555
556pub fn rust_rule_full(
565 egraph: &mut EGraph,
566 rule_name: &str,
567 ruleset: &str,
568 vars: &[(&str, ArcSort)],
569 facts: Facts<String, String>,
570 func: impl for<'a, 'db> Fn(crate::FullState<'a, 'db>, &[Value]) -> Option<()>
571 + Clone
572 + Send
573 + Sync
574 + 'static,
575) -> Result<Vec<CommandOutput>, Error> {
576 let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_full_prim");
577 egraph.add_full_primitive(
578 RustRuleFullRhs {
579 name: prim_name.clone(),
580 inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
581 func,
582 },
583 None,
584 );
585
586 let rule = Rule {
587 span: span!(),
588 head: GenericActions(vec![GenericAction::Expr(
589 span!(),
590 exprs::call(
591 &prim_name,
592 vars.iter().map(|(v, _)| exprs::var(v)).collect(),
593 ),
594 )]),
595 body: facts.0,
596 name: egraph.parser.symbol_gen.fresh(rule_name),
597 ruleset: ruleset.into(),
598 naive: true,
601 no_decomp: false,
602 };
603
604 egraph.run_program(vec![Command::Rule { rule }])
605}
606
607pub struct QueryResult {
609 rows: usize,
610 cols: usize,
611 data: Vec<Value>,
612}
613
614impl QueryResult {
615 pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
619 assert!(self.cols > 0, "no vars; use `any_matches` instead");
620 assert!(self.data.len().is_multiple_of(self.cols));
621 self.data.chunks_exact(self.cols)
622 }
623
624 pub fn any_matches(&self) -> bool {
626 self.rows > 0
627 }
628}
629
630pub fn query(
668 egraph: &mut EGraph,
669 vars: &[(&str, ArcSort)],
670 facts: Facts<String, String>,
671) -> Result<QueryResult, Error> {
672 use std::sync::{Arc, Mutex};
673
674 let results = Arc::new(Mutex::new(QueryResult {
675 rows: 0,
676 cols: vars.len(),
677 data: Vec::new(),
678 }));
679 let results_weak = Arc::downgrade(&results);
680
681 let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
682 add_ruleset(egraph, &ruleset)?;
683
684 rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
685 let arc = results_weak.upgrade().unwrap();
686 let mut results = arc.lock().unwrap();
687 results.rows += 1;
688 results.data.extend(values);
689 Some(())
690 })?;
691
692 run_ruleset(egraph, &ruleset)?;
693
694 let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
695
696 let Ruleset::Rules(rules) = ruleset else {
697 unreachable!()
698 };
699 assert_eq!(rules.len(), 1);
700 let rule = rules.into_iter().next().unwrap().1;
701 egraph.backend.free_rule(rule.1);
702
703 let Some(mutex) = Arc::into_inner(results) else {
704 panic!("results_weak.upgrade() was not dropped");
705 };
706 Ok(mutex.into_inner().unwrap())
707}
708
709pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
711 egraph.run_program(vec![Command::Sort {
712 span: span!(),
713 name: name.to_owned(),
714 presort_and_args: None,
715 uf: None,
716 proof_func: None,
717 unionable: true,
718 }])
719}
720
721pub fn add_function(
723 egraph: &mut EGraph,
724 name: &str,
725 schema: Schema,
726 merge: Option<GenericExpr<String, String>>,
727) -> Result<Vec<CommandOutput>, Error> {
728 egraph.run_program(vec![Command::Function {
729 span: span!(),
730 name: name.to_owned(),
731 schema,
732 merge,
733 hidden: false,
734 let_binding: false,
735 term_constructor: None,
736 unextractable: false,
737 }])
738}
739
740pub fn add_constructor(
742 egraph: &mut EGraph,
743 name: &str,
744 schema: Schema,
745 cost: Option<DefaultCost>,
746 unextractable: bool,
747) -> Result<Vec<CommandOutput>, Error> {
748 egraph.run_program(vec![Command::Constructor {
749 span: span!(),
750 name: name.to_owned(),
751 schema,
752 cost,
753 unextractable,
754 hidden: false,
755 let_binding: false,
756 term_constructor: None,
757 }])
758}
759
760pub fn add_relation(
762 egraph: &mut EGraph,
763 name: &str,
764 inputs: Vec<String>,
765) -> Result<Vec<CommandOutput>, Error> {
766 egraph.run_program(vec![Command::Relation {
767 span: span!(),
768 name: name.to_owned(),
769 inputs,
770 }])
771}
772
/// Declares a sort and its constructors in one go, e.g.
/// `datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)))`.
///
/// Expands to `add_sort` followed by one `add_constructor` per variant, each
/// with `unextractable = false`. Argument sorts must be plain identifiers.
///
/// Because the expansion uses `?`, it must be invoked inside a function that
/// returns a compatible `Result`. The `$egraph` expression is pasted into every
/// generated call, so it is evaluated once per call.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {
        add_sort($egraph, stringify!($sort))?;
        $(add_constructor(
            $egraph,
            stringify!($name),
            Schema {
                input: vec![$(stringify!($args).to_owned()),*],
                output: stringify!($sort).to_owned(),
            },
            // `Some(cost)` when `:cost` was given, otherwise `None`.
            [$($cost)*].first().copied(),
            false,
        )?;)*
    };
}
790
791pub trait BaseSort: Any + Send + Sync + Debug {
798 type Base: BaseValue;
799 fn name(&self) -> &str;
800 fn register_primitives(&self, _eg: &mut EGraph) {}
801 fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> TermId;
802
803 fn to_arcsort(self) -> ArcSort
804 where
805 Self: Sized,
806 {
807 Arc::new(BaseSortImpl(self))
808 }
809}
810
811#[derive(Debug)]
812struct BaseSortImpl<T: BaseSort>(T);
813
814impl<T: BaseSort> Sort for BaseSortImpl<T> {
815 fn name(&self) -> &str {
816 self.0.name()
817 }
818
819 fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
820 ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
821 }
822
823 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
824 backend.base_values_mut().register_type::<T::Base>();
825 }
826
827 fn value_type(&self) -> Option<TypeId> {
828 Some(TypeId::of::<T::Base>())
829 }
830
831 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
832 self
833 }
834
835 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
836 self.0.register_primitives(eg)
837 }
838
839 fn reconstruct_termdag_base(
841 &self,
842 base_values: &BaseValues,
843 value: Value,
844 termdag: &mut TermDag,
845 ) -> TermId {
846 self.0.reconstruct_termdag(base_values, value, termdag)
847 }
848}
849
850pub trait ContainerSort: Any + Send + Sync + Debug {
857 type Container: ContainerValue;
858 fn name(&self) -> &str;
859 fn is_eq_container_sort(&self) -> bool;
860 fn inner_sorts(&self) -> Vec<ArcSort>;
861 fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
862 fn register_primitives(&self, _eg: &mut EGraph) {}
863 fn reconstruct_termdag(
864 &self,
865 _: &ContainerValues,
866 _: Value,
867 _: &mut TermDag,
868 _: Vec<TermId>,
869 ) -> TermId;
870 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
871
872 fn to_arcsort(self) -> ArcSort
873 where
874 Self: Sized,
875 {
876 Arc::new(ContainerSortImpl(self))
877 }
878}
879
880#[derive(Debug)]
881struct ContainerSortImpl<T: ContainerSort>(T);
882
883impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
884 fn name(&self) -> &str {
885 self.0.name()
886 }
887
888 fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
889 ColumnTy::Id
890 }
891
892 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
893 backend.register_container_ty::<T::Container>();
894 }
895
896 fn value_type(&self) -> Option<TypeId> {
897 Some(TypeId::of::<T::Container>())
898 }
899
900 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
901 self
902 }
903
904 fn inner_sorts(&self) -> Vec<ArcSort> {
905 self.0.inner_sorts()
906 }
907
908 fn inner_values(
909 &self,
910 container_values: &ContainerValues,
911 value: Value,
912 ) -> Vec<(ArcSort, Value)> {
913 self.0.inner_values(container_values, value)
914 }
915
916 fn is_container_sort(&self) -> bool {
917 true
918 }
919
920 fn is_eq_container_sort(&self) -> bool {
921 self.0.is_eq_container_sort()
922 }
923
924 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
925 self.0.serialized_name(container_values, value)
926 }
927
928 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
929 self.0.register_primitives(eg);
930 }
931
932 fn reconstruct_termdag_container(
933 &self,
934 container_values: &ContainerValues,
935 value: Value,
936 termdag: &mut TermDag,
937 element_terms: Vec<TermId>,
938 ) -> TermId {
939 self.0
940 .reconstruct_termdag(container_values, value, termdag, element_terms)
941 }
942}
943
944pub fn add_base_sort(
946 egraph: &mut EGraph,
947 base_sort: impl BaseSort,
948 span: Span,
949) -> Result<(), TypeError> {
950 egraph.add_sort(BaseSortImpl(base_sort), span)
951}
952
953pub fn add_container_sort(
954 egraph: &mut EGraph,
955 container_sort: impl ContainerSort,
956 span: Span,
957) -> Result<(), TypeError> {
958 egraph.add_sort(ContainerSortImpl(container_sort), span)
959}
960
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an e-graph with a `fib` table seeded with `fib 0` and `fib 1`.
    /// It then runs an egglog-level Fibonacci rule 10 times, so only small
    /// indices are filled in (the tests below rely on `fib 20` being absent).
    fn build_test_database() -> Result<EGraph, Error> {
        let mut egraph = EGraph::default();
        egraph.parse_and_run_program(
            None,
            "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
    (= f0 (fib x))
    (= f1 (fib (+ x 1)))
) (
    (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
        ",
        )?;
        Ok(egraph)
    }

    /// `query` returns the single row (x, y) = (7, 13) for `fib x = 13`.
    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let results = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        let x = egraph.backend.base_values().get::<i64>(7);
        let y = egraph.backend.base_values().get::<i64>(13);
        assert_eq!(results.data, [x, y]);

        Ok(())
    }

    /// Adding the Fibonacci rule through `rule` + macros and running it 10
    /// more times extends the table so that `fib 20 = 6765`.
    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // `fib 20` must not be known before the new rule runs.
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // Same rule as in `build_test_database`, built with the Rust macros.
        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        // `run_ruleset` performs one run per call.
        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }

    /// Exercises every macro form. The rule is only added, never run, so the
    /// `panic` action never fires; this checks that expansion and rule
    /// registration succeed.
    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        // Declares an egglog sort named `Expr` (unrelated to the Rust type).
        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rule(
            &mut egraph,
            ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    /// Same as `rust_api_rule`, but the rule's right-hand side is a Rust
    /// closure registered via `rust_rule`.
    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // `fib 20` must not be known before the new rule runs.
        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        assert!(results.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        // `values` arrive in the order of the `vars!` list: x, f0, f1.
        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |mut ctx, values| {
                let [x, f0, f1] = values else { unreachable!() };
                let x = ctx.value_to_base::<i64>(*x);
                let f0 = ctx.value_to_base::<i64>(*f0);
                let f1 = ctx.value_to_base::<i64>(*f1);

                let y = ctx.base_to_value::<i64>(x + 2);
                let f2 = ctx.base_to_value::<i64>(f0 + f1);
                // Presumably a `fib` row as inputs then output: fib(y) = f2.
                ctx.insert("fib", [y, f2].into_iter());

                Some(())
            },
        )?;

        for _ in 0..10 {
            run_ruleset(&mut egraph, ruleset)?;
        }

        let results = query(
            &mut egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(big_number))) f)],
        )?;

        let y = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(results.data, [y]);

        Ok(())
    }
}