1use crate::*;
10use std::any::{Any, TypeId};
11
12pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span};
14pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort};
15pub use egglog::{EGraph, span};
16pub use egglog::{action, actions, datatype, expr, fact, facts, sort, vars};
17
18pub mod exprs {
19 use super::*;
20
21 pub fn var(name: &str) -> Expr {
23 Expr::Var(span!(), name.to_owned())
24 }
25
26 pub fn int(value: i64) -> Expr {
28 Expr::Lit(span!(), Literal::Int(value))
29 }
30
31 pub fn float(value: f64) -> Expr {
33 Expr::Lit(span!(), Literal::Float(value.into()))
34 }
35
36 pub fn string(value: &str) -> Expr {
38 Expr::Lit(span!(), Literal::String(value.to_owned()))
39 }
40
41 pub fn unit() -> Expr {
43 Expr::Lit(span!(), Literal::Unit)
44 }
45
46 pub fn bool(value: bool) -> Expr {
48 Expr::Lit(span!(), Literal::Bool(value))
49 }
50
51 pub fn call(f: &str, xs: Vec<Expr>) -> Expr {
53 Expr::Call(span!(), f.to_owned(), xs)
54 }
55}
56
57pub fn add_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
59 egraph.run_program(vec![Command::AddRuleset(span!(), ruleset.to_owned())])
60}
61
62pub fn run_ruleset(egraph: &mut EGraph, ruleset: &str) -> Result<Vec<CommandOutput>, Error> {
64 egraph.run_program(vec![Command::RunSchedule(Schedule::Run(
65 span!(),
66 RunConfig {
67 ruleset: ruleset.to_owned(),
68 until: None,
69 },
70 ))])
71}
72
/// Turns a sort written in shorthand into an expression.
///
/// The bare tokens `BigInt`, `BigRat`, `bool`, `f64`, `i64`, `String` and `Unit`
/// expand to `.to_arcsort()` called on the matching built-in sort value.
/// Any other input is parsed as an expression and passed through unchanged.
#[macro_export]
macro_rules! sort {
    (BigInt) => (BigIntSort.to_arcsort());
    (BigRat) => (BigRatSort.to_arcsort());
    (bool) => (BoolSort.to_arcsort());
    (f64) => (F64Sort.to_arcsort());
    (i64) => (I64Sort.to_arcsort());
    (String) => (StringSort.to_arcsort());
    (Unit) => (UnitSort.to_arcsort());
    // Fallback: the caller supplied the sort expression directly.
    ($custom:expr) => ($custom);
}
100
/// Builds a borrowed array of `(name, sort)` pairs from `name: sort` entries.
///
/// Each name is stringified and each sort token is resolved through `sort!`.
/// Entries are comma-separated and a trailing comma is accepted.
#[macro_export]
macro_rules! vars {
    [$($name:ident : $sort:tt),* $(,)?] => {
        &[$((stringify!($name), sort!($sort))),*]
    };
}
107
/// Builds an expression from s-expression-like tokens.
///
/// Arms are tried in order:
/// - `(unquote e)` splices the Rust expression `e` in as-is;
/// - `(f a b ...)` becomes `exprs::call` on the stringified head, with each
///   argument expanded recursively;
/// - a literal token is handed to `exprs::int`;
/// - any other single token becomes `exprs::var` of its stringified form.
#[macro_export]
macro_rules! expr {
    ((unquote $rust:expr)) => { $rust };
    (($head:tt $($operand:tt)*)) => {
        exprs::call(stringify!($head), vec![$(expr!($operand)),*])
    };
    ($number:literal) => { exprs::int($number) };
    ($symbol:tt) => { exprs::var(stringify!($symbol)) };
}
115
/// Builds a single fact.
///
/// `(= a b ...)` expands to `Fact::Eq` over the `expr!` expansion of each
/// operand; any other token tree expands to `Fact::Fact` wrapping its `expr!`
/// expansion.
#[macro_export]
macro_rules! fact {
    ((= $($side:tt)*)) => { Fact::Eq(span!(), $(expr!($side)),*) };
    ($atom:tt) => { Fact::Fact(expr!($atom)) };
}
121
/// Builds a `Facts` value from a sequence of token trees, expanding each one
/// through `fact!`.
#[macro_export]
macro_rules! facts {
    ($($item:tt)*) => {
        Facts(vec![$(fact!($item)),*])
    };
}
126
/// Builds a single action from s-expression-like tokens.
///
/// Recognised forms, tried in order:
/// - `(let name value)` -> `Action::Let`
/// - `(set (f args...) value)` -> `Action::Set`
/// - `(delete (f args...))` -> `Action::Change` with `Change::Delete`
/// - `(subsume (f args...))` -> `Action::Change` with `Change::Subsume`
/// - `(union a b)` -> `Action::Union`
/// - `(panic "message")` -> `Action::Panic`
/// - anything else -> `Action::Expr` of its `expr!` expansion
#[macro_export]
macro_rules! action {
    ((let $binding:ident $bound:tt)) => {
        Action::Let(span!(), String::from(stringify!($binding)), expr!($bound))
    };
    ((set ($table:ident $($key:tt)*) $output:tt)) => {
        Action::Set(span!(), String::from(stringify!($table)), vec![$(expr!($key)),*], expr!($output))
    };
    ((delete ($table:ident $($key:tt)*))) => {
        Action::Change(span!(), Change::Delete, String::from(stringify!($table)), vec![$(expr!($key)),*])
    };
    ((subsume ($table:ident $($key:tt)*))) => {
        Action::Change(span!(), Change::Subsume, String::from(stringify!($table)), vec![$(expr!($key)),*])
    };
    ((union $lhs:tt $rhs:tt)) => {
        Action::Union(span!(), expr!($lhs), expr!($rhs))
    };
    ((panic $text:literal)) => {
        Action::Panic(span!(), $text.to_owned())
    };
    ($other:tt) => {
        Action::Expr(span!(), expr!($other))
    };
}
151
/// Builds a `GenericActions` value from a sequence of token trees, expanding
/// each one through `action!`.
#[macro_export]
macro_rules! actions {
    ($($item:tt)*) => {
        GenericActions(vec![$(action!($item)),*])
    };
}
156
157pub fn rule(
224 egraph: &mut EGraph,
225 ruleset: &str,
226 facts: Facts<String, String>,
227 actions: Actions,
228) -> Result<Vec<CommandOutput>, Error> {
229 let mut rule = Rule {
230 span: span!(),
231 head: actions,
232 body: facts.0,
233 name: "".into(),
234 ruleset: ruleset.into(),
235 };
236
237 rule.name = format!("{rule:?}");
238
239 egraph.run_program(vec![Command::Rule { rule }])
240}
241
242pub struct RustRuleContext<'a, 'b> {
245 exec_state: &'a mut ExecutionState<'b>,
246 union_action: egglog_bridge::UnionAction,
247 table_actions: HashMap<String, egglog_bridge::TableAction>,
248 panic_id: ExternalFunctionId,
249}
250
251impl RustRuleContext<'_, '_> {
252 pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
254 self.exec_state.base_values().unwrap::<T>(x)
255 }
256
257 pub fn value_to_container<T: ContainerValue>(
261 &mut self,
262 x: Value,
263 ) -> Option<impl Deref<Target = T>> {
264 self.exec_state.container_values().get_val::<T>(x)
265 }
266
267 pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
269 self.exec_state.base_values().get::<T>(x)
270 }
271
272 pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
274 self.exec_state
275 .container_values()
276 .register_val::<T>(x, self.exec_state)
277 }
278
279 fn get_table_action(&self, table: &str) -> egglog_bridge::TableAction {
280 self.table_actions[table].clone()
281 }
282
283 pub fn lookup(&mut self, table: &str, key: &[Value]) -> Option<Value> {
286 self.get_table_action(table).lookup(self.exec_state, key)
287 }
288
289 pub fn union(&mut self, x: Value, y: Value) {
292 self.union_action.union(self.exec_state, x, y)
293 }
294
295 pub fn insert(&mut self, table: &str, row: impl Iterator<Item = Value>) {
298 self.get_table_action(table).insert(self.exec_state, row)
299 }
300
301 pub fn remove(&mut self, table: &str, key: &[Value]) {
304 self.get_table_action(table).remove(self.exec_state, key)
305 }
306
307 pub fn subsume(&mut self, table: &str, key: &[Value]) {
310 self.get_table_action(table)
311 .subsume(self.exec_state, key.iter().copied())
312 }
313
314 pub fn panic(&mut self) -> Option<()> {
319 self.exec_state.call_external_func(self.panic_id, &[]);
320 None
321 }
322}
323
324#[derive(Clone)]
325struct RustRuleRhs<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> {
326 name: String,
327 inputs: Vec<ArcSort>,
328 union_action: egglog_bridge::UnionAction,
329 table_actions: HashMap<String, egglog_bridge::TableAction>,
330 panic_id: ExternalFunctionId,
331 func: F,
332}
333
334impl<F: Fn(&mut RustRuleContext, &[Value]) -> Option<()>> Primitive for RustRuleRhs<F> {
335 fn name(&self) -> &str {
336 &self.name
337 }
338
339 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
340 let sorts: Vec<_> = self
341 .inputs
342 .iter()
343 .chain(once(&UnitSort.to_arcsort()))
344 .cloned()
345 .collect();
346 SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
347 }
348
349 fn apply(&self, exec_state: &mut ExecutionState, values: &[Value]) -> Option<Value> {
350 let mut context = RustRuleContext {
351 exec_state,
352 union_action: self.union_action,
353 table_actions: self.table_actions.clone(),
354 panic_id: self.panic_id,
355 };
356 (self.func)(&mut context, values)?;
357 Some(exec_state.base_values().get(()))
358 }
359}
360
361pub fn rust_rule(
439 egraph: &mut EGraph,
440 rule_name: &str,
441 ruleset: &str,
442 vars: &[(&str, ArcSort)],
443 facts: Facts<String, String>,
444 func: impl Fn(&mut RustRuleContext, &[Value]) -> Option<()> + Clone + Send + Sync + 'static,
445) -> Result<Vec<CommandOutput>, Error> {
446 let prim_name = egraph.parser.symbol_gen.fresh("rust_rule_prim");
447 let panic_id = egraph.backend.new_panic(format!("{prim_name}_panic"));
448 egraph.add_primitive(RustRuleRhs {
449 name: prim_name.clone(),
450 inputs: vars.iter().map(|(_, s)| s.clone()).collect(),
451 union_action: egglog_bridge::UnionAction::new(&egraph.backend),
452 table_actions: egraph
453 .functions
454 .iter()
455 .map(|(k, v)| {
456 (
457 k.clone(),
458 egglog_bridge::TableAction::new(&egraph.backend, v.backend_id),
459 )
460 })
461 .collect(),
462 panic_id,
463 func,
464 });
465
466 let rule = Rule {
467 span: span!(),
468 head: GenericActions(vec![GenericAction::Expr(
469 span!(),
470 exprs::call(
471 &prim_name,
472 vars.iter().map(|(v, _)| exprs::var(v)).collect(),
473 ),
474 )]),
475 body: facts.0,
476 name: egraph.parser.symbol_gen.fresh(rule_name),
477 ruleset: ruleset.into(),
478 };
479
480 egraph.run_program(vec![Command::Rule { rule }])
481}
482
483pub struct QueryResult {
485 rows: usize,
486 cols: usize,
487 data: Vec<Value>,
488}
489
490impl QueryResult {
491 pub fn iter(&self) -> impl Iterator<Item = &[Value]> {
495 assert!(self.cols > 0, "no vars; use `any_matches` instead");
496 assert!(self.data.len() % self.cols == 0);
497 self.data.chunks_exact(self.cols)
498 }
499
500 pub fn any_matches(&self) -> bool {
502 self.rows > 0
503 }
504}
505
506pub fn query(
544 egraph: &mut EGraph,
545 vars: &[(&str, ArcSort)],
546 facts: Facts<String, String>,
547) -> Result<QueryResult, Error> {
548 use std::sync::{Arc, Mutex};
549
550 let results = Arc::new(Mutex::new(QueryResult {
551 rows: 0,
552 cols: vars.len(),
553 data: Vec::new(),
554 }));
555 let results_weak = Arc::downgrade(&results);
556
557 let ruleset = egraph.parser.symbol_gen.fresh("query_ruleset");
558 add_ruleset(egraph, &ruleset)?;
559
560 rust_rule(egraph, "query", &ruleset, vars, facts, move |_, values| {
561 let arc = results_weak.upgrade().unwrap();
562 let mut results = arc.lock().unwrap();
563 results.rows += 1;
564 results.data.extend(values);
565 Some(())
566 })?;
567
568 run_ruleset(egraph, &ruleset)?;
569
570 let ruleset = egraph.rulesets.swap_remove(&ruleset).unwrap();
571
572 let Ruleset::Rules(rules) = ruleset else {
573 unreachable!()
574 };
575 assert_eq!(rules.len(), 1);
576 let rule = rules.into_iter().next().unwrap().1;
577 egraph.backend.free_rule(rule.1);
578
579 let Some(mutex) = Arc::into_inner(results) else {
580 panic!("results_weak.upgrade() was not dropped");
581 };
582 Ok(mutex.into_inner().unwrap())
583}
584
585pub fn add_sort(egraph: &mut EGraph, name: &str) -> Result<Vec<CommandOutput>, Error> {
587 egraph.run_program(vec![Command::Sort(span!(), name.to_owned(), None)])
588}
589
590pub fn add_function(
592 egraph: &mut EGraph,
593 name: &str,
594 schema: Schema,
595 merge: Option<GenericExpr<String, String>>,
596) -> Result<Vec<CommandOutput>, Error> {
597 egraph.run_program(vec![Command::Function {
598 span: span!(),
599 name: name.to_owned(),
600 schema,
601 merge,
602 }])
603}
604
605pub fn add_constructor(
607 egraph: &mut EGraph,
608 name: &str,
609 schema: Schema,
610 cost: Option<DefaultCost>,
611 unextractable: bool,
612) -> Result<Vec<CommandOutput>, Error> {
613 egraph.run_program(vec![Command::Constructor {
614 span: span!(),
615 name: name.to_owned(),
616 schema,
617 cost,
618 unextractable,
619 }])
620}
621
622pub fn add_relation(
624 egraph: &mut EGraph,
625 name: &str,
626 inputs: Vec<String>,
627) -> Result<Vec<CommandOutput>, Error> {
628 egraph.run_program(vec![Command::Relation {
629 span: span!(),
630 name: name.to_owned(),
631 inputs,
632 }])
633}
634
/// Declares a sort together with its constructors.
///
/// `datatype!(egraph, (datatype Sort (Ctor Arg ... :cost c) ...))` calls
/// `add_sort` for `Sort` and then `add_constructor` once per constructor,
/// using the stringified argument sorts as the schema input and `Sort` as
/// the output. `:cost` is optional; constructors are never marked
/// unextractable.
///
/// The expansion uses `?`, so it must be invoked inside a function whose
/// return type can carry the error of `add_sort` / `add_constructor`.
///
/// The `egraph` expression is evaluated exactly once.
#[macro_export]
macro_rules! datatype {
    ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => {{
        // Bind the caller's expression once instead of re-expanding it for
        // every generated call. The explicit reborrow keeps a `&mut EGraph`
        // variable that was passed by name usable after the macro.
        let target = &mut *$egraph;
        add_sort(target, stringify!($sort))?;
        $(add_constructor(
            target,
            stringify!($name),
            Schema {
                input: vec![$(stringify!($args).to_owned()),*],
                output: stringify!($sort).to_owned(),
            },
            // Zero or one cost expression: `first()` yields `None` when absent.
            [$($cost)*].first().copied(),
            false,
        )?;)*
    }};
}
652
653pub trait BaseSort: Any + Send + Sync + Debug {
660 type Base: BaseValue;
661 fn name(&self) -> &str;
662 fn register_primitives(&self, _eg: &mut EGraph) {}
663 fn reconstruct_termdag(&self, _: &BaseValues, _: Value, _: &mut TermDag) -> Term;
664
665 fn to_arcsort(self) -> ArcSort
666 where
667 Self: Sized,
668 {
669 Arc::new(BaseSortImpl(self))
670 }
671}
672
673#[derive(Debug)]
674struct BaseSortImpl<T: BaseSort>(T);
675
676impl<T: BaseSort> Sort for BaseSortImpl<T> {
677 fn name(&self) -> &str {
678 self.0.name()
679 }
680
681 fn column_ty(&self, backend: &egglog_bridge::EGraph) -> ColumnTy {
682 ColumnTy::Base(backend.base_values().get_ty::<T::Base>())
683 }
684
685 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
686 backend.base_values_mut().register_type::<T::Base>();
687 }
688
689 fn value_type(&self) -> Option<TypeId> {
690 Some(TypeId::of::<T::Base>())
691 }
692
693 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
694 self
695 }
696
697 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
698 self.0.register_primitives(eg)
699 }
700
701 fn reconstruct_termdag_base(
703 &self,
704 base_values: &BaseValues,
705 value: Value,
706 termdag: &mut TermDag,
707 ) -> Term {
708 self.0.reconstruct_termdag(base_values, value, termdag)
709 }
710}
711
712pub trait ContainerSort: Any + Send + Sync + Debug {
719 type Container: ContainerValue;
720 fn name(&self) -> &str;
721 fn is_eq_container_sort(&self) -> bool;
722 fn inner_sorts(&self) -> Vec<ArcSort>;
723 fn inner_values(&self, _: &ContainerValues, _: Value) -> Vec<(ArcSort, Value)>;
724 fn register_primitives(&self, _eg: &mut EGraph) {}
725 fn reconstruct_termdag(
726 &self,
727 _: &ContainerValues,
728 _: Value,
729 _: &mut TermDag,
730 _: Vec<Term>,
731 ) -> Term;
732 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String;
733
734 fn to_arcsort(self) -> ArcSort
735 where
736 Self: Sized,
737 {
738 Arc::new(ContainerSortImpl(self))
739 }
740}
741
742#[derive(Debug)]
743struct ContainerSortImpl<T: ContainerSort>(T);
744
745impl<T: ContainerSort> Sort for ContainerSortImpl<T> {
746 fn name(&self) -> &str {
747 self.0.name()
748 }
749
750 fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
751 ColumnTy::Id
752 }
753
754 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
755 backend.register_container_ty::<T::Container>();
756 }
757
758 fn value_type(&self) -> Option<TypeId> {
759 Some(TypeId::of::<T::Container>())
760 }
761
762 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
763 self
764 }
765
766 fn inner_sorts(&self) -> Vec<ArcSort> {
767 self.0.inner_sorts()
768 }
769
770 fn inner_values(
771 &self,
772 container_values: &ContainerValues,
773 value: Value,
774 ) -> Vec<(ArcSort, Value)> {
775 self.0.inner_values(container_values, value)
776 }
777
778 fn is_container_sort(&self) -> bool {
779 true
780 }
781
782 fn is_eq_container_sort(&self) -> bool {
783 self.0.is_eq_container_sort()
784 }
785
786 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
787 self.0.serialized_name(container_values, value)
788 }
789
790 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
791 self.0.register_primitives(eg);
792 }
793
794 fn reconstruct_termdag_container(
795 &self,
796 container_values: &ContainerValues,
797 value: Value,
798 termdag: &mut TermDag,
799 element_terms: Vec<Term>,
800 ) -> Term {
801 self.0
802 .reconstruct_termdag(container_values, value, termdag, element_terms)
803 }
804}
805
806pub fn add_base_sort(
808 egraph: &mut EGraph,
809 base_sort: impl BaseSort,
810 span: Span,
811) -> Result<(), TypeError> {
812 egraph.add_sort(BaseSortImpl(base_sort), span)
813}
814
815pub fn add_container_sort(
816 egraph: &mut EGraph,
817 container_sort: impl ContainerSort,
818 span: Span,
819) -> Result<(), TypeError> {
820 egraph.add_sort(ContainerSortImpl(container_sort), span)
821}
822
#[cfg(test)]
mod tests {
    use super::*;

    /// Program that declares `fib`, seeds its first two entries, adds the
    /// recurrence rule, and ends with `(run 10)`.
    const FIB_PROGRAM: &str = "
(function fib (i64) i64 :no-merge)
(set (fib 0) 0)
(set (fib 1) 1)
(rule (
 (= f0 (fib x))
 (= f1 (fib (+ x 1)))
) (
 (set (fib (+ x 2)) (+ f0 f1))
))
(run 10)
 ";

    /// Creates a default e-graph and runs `FIB_PROGRAM` on it.
    fn build_test_database() -> Result<EGraph, Error> {
        let mut database = EGraph::default();
        database.parse_and_run_program(None, FIB_PROGRAM)?;
        Ok(database)
    }

    /// Queries the database for the value `f` with `(= (fib n) f)`.
    fn query_fib(egraph: &mut EGraph, n: i64) -> Result<QueryResult, Error> {
        query(
            egraph,
            vars![f: i64],
            facts![(= (fib (unquote exprs::int(n))) f)],
        )
    }

    #[test]
    fn rust_api_query() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let results = query(
            &mut egraph,
            vars![x: i64, y: i64],
            facts![
                (= (fib x) y)
                (= y 13)
            ],
        )?;

        // The only match binds x = 7 and y = 13.
        let expected = [7_i64, 13].map(|n| egraph.backend.base_values().get::<i64>(n));
        assert_eq!(results.data, expected);

        Ok(())
    }

    #[test]
    fn rust_api_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // Nothing is stored for `fib(big_number)` before the custom rule runs.
        let before = query_fib(&mut egraph, big_number)?;
        assert!(before.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rule(
            &mut egraph,
            ruleset,
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            actions![
                (set (fib (+ x 2)) (+ f0 f1))
            ],
        )?;

        (0..10).try_for_each(|_| run_ruleset(&mut egraph, ruleset).map(drop))?;

        let after = query_fib(&mut egraph, big_number)?;
        let expected = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(after.data, [expected]);

        Ok(())
    }

    #[test]
    fn rust_api_macros() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        datatype!(&mut egraph, (datatype Expr (One) (Two Expr Expr :cost 10)));

        let target_ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, target_ruleset)?;

        // Exercises every form accepted by `facts!` and `actions!`.
        rule(
            &mut egraph,
            target_ruleset,
            facts![
                (fib 5)
                (fib x)
                (= f1 (fib (+ x 1)))
                (= 3 (unquote exprs::int(1 + 2)))
            ],
            actions![
                (let y (+ x 2))
                (set (fib (+ x 2)) (+ f1 f1))
                (delete (fib 0))
                (subsume (Two (One) (One)))
                (union (One) (Two (One) (One)))
                (panic "message")
                (+ 6 87)
            ],
        )?;

        Ok(())
    }

    #[test]
    fn rust_api_rust_rule() -> Result<(), Error> {
        let mut egraph = build_test_database()?;

        let big_number = 20;

        // Nothing is stored for `fib(big_number)` before the custom rule runs.
        let before = query_fib(&mut egraph, big_number)?;
        assert!(before.data.is_empty());

        let ruleset = "custom_ruleset";
        add_ruleset(&mut egraph, ruleset)?;

        rust_rule(
            &mut egraph,
            "demo_rule",
            ruleset,
            vars![x: i64, f0: i64, f1: i64],
            facts![
                (= f0 (fib x))
                (= f1 (fib (+ x 1)))
            ],
            move |ctx, values| {
                let &[x, f0, f1] = values else { unreachable!() };
                let index = ctx.value_to_base::<i64>(x);
                let older = ctx.value_to_base::<i64>(f0);
                let newer = ctx.value_to_base::<i64>(f1);

                let next_index = ctx.base_to_value::<i64>(index + 2);
                let next_value = ctx.base_to_value::<i64>(older + newer);
                ctx.insert("fib", [next_index, next_value].into_iter());

                Some(())
            },
        )?;

        (0..10).try_for_each(|_| run_ruleset(&mut egraph, ruleset).map(drop))?;

        let after = query_fib(&mut egraph, big_number)?;
        let expected = egraph.backend.base_values().get::<i64>(6765);
        assert_eq!(after.data, [expected]);

        Ok(())
    }
}