1use std::{
12 fmt::Debug,
13 hash::Hash,
14 iter, mem,
15 ops::{Index, IndexMut},
16 sync::{Arc, Mutex},
17};
18
19use crate::core_relations::{
20 BaseValue, BaseValueId, BaseValues, ColumnId, Constraint, ContainerValue, ContainerValues,
21 CounterId, Database, DisplacedTable, ExecutionState, ExternalFunction, ExternalFunctionId,
22 MergeVal, Offset, PlanStrategy, SortedWritesTable, TableId, TaggedRowBuffer, Value,
23 WrappedTable,
24};
25use crate::numeric_id::{DenseIdMap, DenseIdMapWithReuse, NumericId, define_id};
26use egglog_core_relations as core_relations;
27use egglog_numeric_id as numeric_id;
28use egglog_reports::{IterationReport, ReportLevel, RuleSetReport};
29use hashbrown::HashMap;
30use indexmap::IndexSet;
31use log::info;
32use once_cell::sync::Lazy;
33use smallvec::SmallVec;
34use web_time::{Duration, Instant};
35
36pub mod macros;
37pub(crate) mod rule;
38#[cfg(test)]
39mod tests;
40
41pub use rule::{Function, QueryEntry, RuleBuilder};
42use thiserror::Error;
43
/// Registry of the action handles that belong to an [`EGraph`]: one
/// [`TableAction`] per registered table name, the shared [`UnionAction`], and
/// the id of the default panic external function.
#[derive(Clone)]
pub struct ActionRegistry {
    /// Per-table actions, keyed by the name given to `register_table`.
    table_actions: hashbrown::HashMap<String, TableAction>,
    /// Handle used to stage unions.
    union_action: UnionAction,
    /// External function id returned by `default_panic_id`.
    default_panic_id: ExternalFunctionId,
}
58
59impl ActionRegistry {
60 pub(crate) fn new(union_action: UnionAction, default_panic_id: ExternalFunctionId) -> Self {
61 Self {
62 table_actions: hashbrown::HashMap::new(),
63 union_action,
64 default_panic_id,
65 }
66 }
67
68 pub(crate) fn register_table(&mut self, name: String, action: TableAction) {
69 self.table_actions.insert(name, action);
70 }
71
72 pub fn lookup_table(&self, name: &str) -> Option<&TableAction> {
75 self.table_actions.get(name)
76 }
77
78 pub fn table_sizes(&self, state: &ExecutionState) -> Vec<(&str, usize)> {
80 self.table_actions
81 .iter()
82 .map(|(name, action)| (name.as_str(), action.row_count(state)))
83 .collect()
84 }
85
86 pub fn union_action(&self) -> &UnionAction {
88 &self.union_action
89 }
90
91 pub fn default_panic_id(&self) -> ExternalFunctionId {
94 self.default_panic_id
95 }
96}
97
/// The type of a single column in a function's schema.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ColumnTy {
    /// An id column. `EGraph::get_canon_repr` canonicalizes these values
    /// through the union-find table, and `EGraph::add_table` marks these
    /// columns for rebuilding.
    Id,
    /// A base-value column of the given base type. `EGraph::get_canon_repr`
    /// returns these values unchanged.
    Base(BaseValueId),
}
103
// Numeric id types generated by `define_id!`. Each invocation passes the
// type's visibility and name, its `u32` representation, and a description string.
define_id!(pub RuleId, u32, "An egglog-style rule");
define_id!(pub FunctionId, u32, "An id representing an egglog function");
define_id!(pub(crate) Timestamp, u32, "An abstract timestamp used to track execution of egglog rules");
impl Timestamp {
    /// Wraps this timestamp's numeric representation in a [`Value`] so it can
    /// be stored in a table column.
    fn to_value(self) -> Value {
        Value::new(self.rep())
    }
}
112
/// An e-graph built on top of a [`Database`]: function tables, a union-find
/// table, rules, and the counters used to mint ids and timestamps.
#[derive(Clone)]
pub struct EGraph {
    /// Underlying database holding the tables, counters and external functions.
    db: Database,
    /// The `$uf` table. Rows are staged into it as `(id, id, timestamp)`.
    uf_table: TableId,
    /// Counter used to mint fresh ids (see `fresh_id`).
    id_counter: CounterId,
    /// Counter holding the current timestamp (see `next_ts` / `inc_ts`).
    timestamp_counter: CounterId,
    /// Per-rule state, indexed by [`RuleId`].
    rules: DenseIdMapWithReuse<RuleId, RuleInfo>,
    /// Per-function metadata, indexed by [`FunctionId`].
    funcs: DenseIdMap<FunctionId, FunctionInfo>,
    /// Channel shared with the panic external functions; `run_rules_inner`
    /// turns a message found here into a `PanicError`.
    panic_message: SideChannel<String>,
    /// Panic external functions keyed by their message.
    panic_funcs: HashMap<String, ExternalFunctionId>,
    /// Report level passed to rule execution in `run_rules_inner`.
    report_level: ReportLevel,
    /// Shared registry of table/union actions; `add_table` registers into it.
    action_registry: Arc<std::sync::RwLock<ActionRegistry>>,
}
137
/// Result type used throughout this module; errors are [`anyhow::Error`].
pub type Result<T> = std::result::Result<T, anyhow::Error>;
139
impl Default for EGraph {
    /// Creates an empty e-graph: a fresh database containing the `$uf` table,
    /// the id and timestamp counters, the default panic function, and an
    /// action registry with no tables.
    fn default() -> Self {
        let mut db = Database::new();
        // The union-find table is registered with no read or write dependencies.
        let uf_table = db.add_table_named(
            DisplacedTable::default(),
            "$uf".into(),
            iter::empty(),
            iter::empty(),
        );
        let id_counter = db.add_counter();
        let ts_counter = db.add_counter();
        // NOTE(review): the timestamp counter is bumped once up front,
        // presumably so the first timestamp in use is nonzero — confirm.
        db.inc_counter(ts_counter);

        let panic_message: SideChannel<String> = Default::default();
        let mut panic_funcs: HashMap<String, ExternalFunctionId> = Default::default();
        // Register the default panic function and remember it under its message.
        let default_panic_msg = "primitive panicked".to_string();
        let default_panic_id = db.add_external_function(Box::new(Panic(
            default_panic_msg.clone(),
            panic_message.clone(),
        )));
        panic_funcs.insert(default_panic_msg, default_panic_id);

        let union_action = UnionAction {
            table: uf_table,
            timestamp: ts_counter,
        };
        let action_registry = Arc::new(std::sync::RwLock::new(ActionRegistry::new(
            union_action,
            default_panic_id,
        )));

        Self {
            db,
            uf_table,
            id_counter,
            timestamp_counter: ts_counter,
            rules: Default::default(),
            funcs: Default::default(),
            panic_message,
            panic_funcs,
            report_level: Default::default(),
            action_registry,
        }
    }
}
190
/// Everything `EGraph::add_table` needs to create a function table.
pub struct FunctionConfig {
    /// Column types of the function. `add_table` requires at least one, and
    /// the last one is treated as the return type (see `FunctionInfo::ret_ty`).
    pub schema: Vec<ColumnTy>,
    /// What a lookup-or-insert does when the key is missing.
    pub default: DefaultVal,
    /// How two return values for the same key are reconciled.
    pub merge: MergeFn,
    /// Name of the table; also the key used in the action registry.
    pub name: String,
    /// Whether rows of this table carry a subsumption column.
    pub can_subsume: bool,
}
204
205impl EGraph {
/// Reads the timestamp counter and returns its current value as a [`Timestamp`].
fn next_ts(&self) -> Timestamp {
    Timestamp::from_usize(self.db.read_counter(self.timestamp_counter))
}
209
/// Increments the timestamp counter.
fn inc_ts(&mut self) {
    self.db.inc_counter(self.timestamp_counter);
}
213
/// Mutable access to the database's [`BaseValues`].
pub fn base_values_mut(&mut self) -> &mut BaseValues {
    self.db.base_values_mut()
}
219
/// Mutable access to the database's [`ContainerValues`].
pub fn container_values_mut(&mut self) -> &mut ContainerValues {
    self.db.container_values_mut()
}
225
/// Shared access to the database's [`ContainerValues`].
pub fn container_values(&self) -> &ContainerValues {
    self.db.container_values()
}
230
/// Registers the container type `C` (see `register_container_ty`) and then
/// registers `val` with the database's container values, returning the
/// [`Value`] produced by that registration.
pub fn get_container_value<C: ContainerValue>(&mut self, val: C) -> Value {
    self.register_container_ty::<C>();
    // NOTE(review): the state is cloned so that `container_values()` can be
    // called while `state` itself is handed to `register_val` — presumably
    // to satisfy the borrow checker; confirm.
    self.db
        .with_execution_state(|state| state.clone().container_values().register_val(val, state))
}
237
238 pub fn register_container_ty<C: ContainerValue>(&mut self) {
243 let uf_table = self.uf_table;
244 let ts_counter = self.timestamp_counter;
245 self.db.container_values_mut().register_type::<C>(
246 self.id_counter,
247 move |state, old, new| {
248 if old != new {
249 let next_ts = Value::from_usize(state.read_counter(ts_counter));
250 state.stage_insert(uf_table, &[old, new, next_ts]);
251 std::cmp::min(old, new)
252 } else {
253 old
254 }
255 },
256 );
257 }
258
/// Shared access to the database's [`BaseValues`].
pub fn base_values(&self) -> &BaseValues {
    self.db.base_values()
}
263
264 pub fn base_value_constant<T>(&self, x: T) -> QueryEntry
266 where
267 T: BaseValue,
268 {
269 QueryEntry::Const {
270 val: self.base_values().get(x),
271 ty: ColumnTy::Base(self.base_values().get_ty::<T>()),
272 }
273 }
274
/// Adds `func` to the database and returns the id assigned to it.
pub fn register_external_func(
    &mut self,
    func: Box<dyn ExternalFunction + 'static>,
) -> ExternalFunctionId {
    self.db.add_external_function(func)
}
292
/// Frees the external function registered under `func` in the database.
pub fn free_external_func(&mut self, func: ExternalFunctionId) {
    self.db.free_external_function(func)
}
296
/// Increments the id counter and returns the value reported by the
/// database, wrapped as a [`Value`].
pub fn fresh_id(&mut self) -> Value {
    Value::from_usize(self.db.inc_counter(self.id_counter))
}
301
302 fn get_canon_in_uf(&self, val: Value) -> Value {
306 let table = self.db.get_table(self.uf_table);
307 let row = table.get_row(&[val]);
308 row.map(|row| row.vals[1]).unwrap_or(val)
309 }
310
/// Returns the canonical representative of `val` for a column of type `ty`:
/// id values are resolved through the union-find table, base values are
/// returned as-is.
pub fn get_canon_repr(&self, val: Value, ty: ColumnTy) -> Value {
    match ty {
        ColumnTy::Id => self.get_canon_in_uf(val),
        ColumnTy::Base(_) => val,
    }
}
321
322 pub fn add_values(&mut self, values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>) {
331 let mut extended_row = Vec::<Value>::new();
332 let mut bufs = DenseIdMap::default();
333 for (func, row) in values.into_iter() {
334 let table_info = &self.funcs[func];
335 let schema_math = SchemaMath {
336 subsume: table_info.can_subsume,
337 func_cols: table_info.schema.len(),
338 };
339 let table_id = table_info.table;
340 extended_row.extend_from_slice(&row);
341 schema_math.write_table_row(
342 &mut extended_row,
343 RowVals {
344 timestamp: self.next_ts().to_value(),
345 subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
346 ret_val: None, },
348 );
349 let buf = bufs.get_or_insert(table_id, || self.db.new_buffer(table_id));
350 buf.stage_insert(&extended_row);
351 extended_row.clear();
352 }
353 mem::drop(bufs);
355 self.flush_updates();
356 }
357
358 pub fn add_term(&mut self, func: FunctionId, inputs: &[Value]) -> Value {
364 let info = &self.funcs[func];
365 let schema_math = SchemaMath {
366 subsume: info.can_subsume,
367 func_cols: info.schema.len(),
368 };
369 let mut extended_row = Vec::new();
370 extended_row.extend_from_slice(inputs);
371 let res = self.fresh_id();
372 schema_math.write_table_row(
373 &mut extended_row,
374 RowVals {
375 timestamp: self.next_ts().to_value(),
376 ret_val: Some(res),
377 subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
378 },
379 );
380 extended_row[schema_math.ret_val_col()] = res;
381 let table_id = self.funcs[func].table;
382 self.db.new_buffer(table_id).stage_insert(&extended_row);
383 self.flush_updates();
384 self.get_canon_in_uf(res)
385 }
386
387 pub fn lookup_id(&self, func: FunctionId, key: &[Value]) -> Option<Value> {
390 let info = &self.funcs[func];
391 let schema_math = SchemaMath {
392 subsume: info.can_subsume,
393 func_cols: info.schema.len(),
394 };
395 let table_id = info.table;
396 let table = self.db.get_table(table_id);
397 let row = table.get_row(key)?;
398 Some(row.vals[schema_math.ret_val_col()])
399 }
400
/// Returns the database's unconstrained size estimate for the table backing `table`.
pub fn approx_table_size(&self, table: FunctionId) -> usize {
    self.db.estimate_size(self.funcs[table].table, None)
}
404
/// Returns `len()` of the table backing `table`.
pub fn table_size(&self, table: FunctionId) -> usize {
    self.db.get_table(self.funcs[table].table).len()
}
408
/// Clears the database table backing `func`.
pub fn clear_table(&mut self, func: FunctionId) {
    let table_id = self.funcs[func].table;
    self.db.clear_table(table_id);
}
425
/// Calls `f` on every non-stale row of `table`. This is `for_each_while`
/// with a callback that never stops the scan.
pub fn for_each(&self, table: FunctionId, mut f: impl FnMut(ScanEntry<'_>)) {
    self.for_each_while(table, |row| {
        f(row);
        true
    });
}
435
/// Scans the non-stale rows of `table`, calling `f` on each until `f`
/// returns `false`.
///
/// Each [`ScanEntry`] exposes the first `schema.len()` columns of the row
/// and whether the row is marked `SUBSUMED` (always `false` for tables that
/// cannot subsume).
pub fn for_each_while(&self, table: FunctionId, mut f: impl FnMut(ScanEntry<'_>) -> bool) {
    let info = &self.funcs[table];
    let table = self.funcs[table].table;
    let schema_math = SchemaMath {
        subsume: info.can_subsume,
        func_cols: info.schema.len(),
    };
    let imp = self.db.get_table(table);
    let all = imp.all();
    let mut cur = Offset::new(0);
    let mut buf = TaggedRowBuffer::new(imp.spec().arity());
    // Feeds every buffered row to `f`, returning from the enclosing function
    // as soon as `f` asks to stop, then empties the buffer.
    macro_rules! drain_buf {
        ($buf:expr) => {
            for (_, row) in $buf.non_stale() {
                let subsumed =
                    schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
                if !f(ScanEntry {
                    vals: &row[0..schema_math.func_cols],
                    subsumed,
                }) {
                    return;
                }
            }
            $buf.clear();
        };
    }
    // Scan with a bound of 32 per call.
    while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
        drain_buf!(buf);
        cur = next;
    }
    // The final scan call returned `None` but may still have filled the buffer.
    drain_buf!(buf);
}
474
/// Logs, at `info` level, every non-stale row of every function table
/// together with the function's name, id and table id.
pub fn dump_debug_info(&self) {
    info!("=== View Tables ===");
    for (id, info) in self.funcs.iter() {
        let table = self.db.get_table(info.table);
        self.scan_table(table, |row| {
            info!(
                "View Table {name} / {id:?} / {table:?}: {row:?}",
                name = info.name,
                table = info.table
            )
        });
    }
}
491
492 fn scan_table(&self, table: &WrappedTable, mut f: impl FnMut(&[Value])) {
494 const BATCH_SIZE: usize = 128;
495 let all = table.all();
496 let mut cur = Offset::new(0);
497 let mut out = TaggedRowBuffer::new(table.spec().arity());
498 while let Some(next) = table.scan_bounded(all.as_ref(), cur, BATCH_SIZE, &mut out) {
499 out.non_stale().for_each(|(_, row)| f(row));
500 out.clear();
501 cur = next;
502 }
503 out.non_stale().for_each(|(_, row)| f(row));
504 }
505
/// Creates a new function table from `config` and returns its id.
///
/// Besides adding the table to the database, this builds the function's
/// incremental and non-incremental rebuild rules and registers a
/// [`TableAction`] for it under the table's name.
///
/// # Panics
///
/// Panics if `config.schema` is empty, or if the action registry's lock is
/// poisoned.
pub fn add_table(&mut self, config: FunctionConfig) -> FunctionId {
    let FunctionConfig {
        schema,
        default,
        merge,
        name,
        can_subsume,
    } = config;
    assert!(
        !schema.is_empty(),
        "must have at least one column in schema"
    );
    // Only id columns need to be rewritten when ids are merged.
    let to_rebuild: Vec<ColumnId> = schema
        .iter()
        .enumerate()
        .filter(|(_, ty)| matches!(ty, ColumnTy::Id))
        .map(|(i, _)| ColumnId::from_usize(i))
        .collect();
    let schema_math = SchemaMath {
        subsume: can_subsume,
        func_cols: schema.len(),
    };
    let n_args = schema_math.num_keys();
    let n_cols = schema_math.table_columns();
    let next_func_id = self.funcs.next_id();
    // Collect the tables the merge function reads and writes so the
    // database knows this table's dependencies.
    let mut read_deps = IndexSet::<TableId>::new();
    let mut write_deps = IndexSet::<TableId>::new();
    merge.fill_deps(self, &mut read_deps, &mut write_deps);
    let merge_fn = merge.to_callback(schema_math, &name, self);
    // NOTE(review): the column at index `schema.len()` is passed as the
    // table's optional third argument — presumably the timestamp column used
    // for sorting; confirm against `SortedWritesTable::new`.
    let table = SortedWritesTable::new(
        n_args,
        n_cols,
        Some(ColumnId::from_usize(schema.len())),
        to_rebuild,
        merge_fn,
    );
    let name: Arc<str> = name.into();
    let table_id = self.db.add_table_named(
        table,
        name.clone(),
        read_deps.iter().copied(),
        write_deps.iter().copied(),
    );

    // The rebuild rules need the function to exist already, so push the info
    // with placeholders first and fill the real rule ids in afterwards.
    let res = self.funcs.push(FunctionInfo {
        table: table_id,
        schema: schema.clone(),
        incremental_rebuild_rules: Default::default(),
        nonincremental_rebuild_rule: RuleId::new(!0),
        default_val: default,
        can_subsume,
        name,
    });
    debug_assert_eq!(res, next_func_id);
    let incremental_rebuild_rules = self.incremental_rebuild_rules(res, &schema);
    let nonincremental_rebuild_rule = self.nonincremental_rebuild(res, &schema);
    let info = &mut self.funcs[res];
    info.incremental_rebuild_rules = incremental_rebuild_rules;
    info.nonincremental_rebuild_rule = nonincremental_rebuild_rule;
    let action = TableAction::new(self, res);
    let table_name = self.funcs[res].name.to_string();
    self.action_registry
        .write()
        .unwrap()
        .register_table(table_name, action);
    res
}
574
/// Returns the shared, lock-protected [`ActionRegistry`] of this e-graph.
pub fn action_registry(&self) -> &Arc<std::sync::RwLock<ActionRegistry>> {
    &self.action_registry
}
583
/// Runs `rules` once and returns the report for that iteration.
///
/// # Errors
///
/// Returns whatever error `run_rules_inner` produces.
pub fn run_rules(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
    self.run_rules_inner(rules)
}
590
/// Runs `rules` at the current timestamp, then rebuilds if the union-find
/// table changed size.
///
/// # Errors
///
/// Returns an error if rule execution or the rebuild fails, or a
/// `PanicError` if a message was left in the panic side channel.
fn run_rules_inner(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
    let ts = self.next_ts();

    let uf_size_before = self.db.get_table(self.uf_table).len();
    let rule_set_report =
        run_rules_impl(&mut self.db, &mut self.rules, rules, ts, self.report_level)?;
    // Taking the message also clears the channel.
    if let Some(message) = self.panic_message.lock().unwrap().take() {
        return Err(PanicError(message).into());
    }

    let mut iteration_report = IterationReport {
        rule_set_report,
        rebuild_time: Duration::ZERO,
    };
    let uf_size_after = self.db.get_table(self.uf_table).len();
    if uf_size_before == uf_size_after {
        // No new union-find rows: skip the rebuild and only advance the
        // timestamp.
        self.inc_ts();
        return Ok(iteration_report);
    }

    let rebuild_timer = Instant::now();
    self.rebuild()?;
    iteration_report.rebuild_time = rebuild_timer.elapsed();

    // The rebuild can trigger panic functions as well.
    if let Some(message) = self.panic_message.lock().unwrap().take() {
        return Err(PanicError(message).into());
    }

    Ok(iteration_report)
}
625
626 fn rebuild(&mut self) -> Result<()> {
627 let do_parallel = rayon::current_num_threads() > 1;
628 if self.db.get_table(self.uf_table).rebuilder(&[]).is_some() {
629 let mut tables = Vec::with_capacity(self.funcs.next_id().index());
631 for (_, func) in self.funcs.iter() {
632 tables.push(func.table);
633 }
634 loop {
635 let container_rebuild = self.db.rebuild_containers(self.uf_table);
667 let next_ts = self.next_ts().to_value();
668 let table_rebuild = self.db.apply_rebuild(self.uf_table, &tables, next_ts);
669 let dirty_ids: Vec<Value> = container_rebuild.dirty_ids().iter().copied().collect();
673 let refreshed_rows = self
674 .db
675 .refresh_rows_for_values(&tables, &dirty_ids, next_ts);
676 self.inc_ts();
677 if !table_rebuild && !refreshed_rows && !container_rebuild.changed() {
678 break;
679 }
680 }
681 return Ok(());
682 }
683 if do_parallel {
684 return self.rebuild_parallel();
685 }
686 let start = Instant::now();
687
688 let mut changed = true;
690 while changed {
691 changed = false;
692 self.inc_ts();
695 let ts = self.next_ts();
696 for (_, info) in self.funcs.iter_mut() {
697 let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
698 let table_size = self.db.estimate_size(info.table, None);
699 let uf_size = self.db.estimate_size(
700 self.uf_table,
701 Some(Constraint::GeConst {
702 col: ColumnId::new(2),
703 val: last_rebuilt_at.to_value(),
704 }),
705 );
706 if incremental_rebuild(uf_size, table_size, false) {
707 marker_incremental_rebuild(|| -> Result<()> {
708 for rule in &info.incremental_rebuild_rules {
713 changed |= run_rules_impl(
714 &mut self.db,
715 &mut self.rules,
716 &[*rule],
717 ts,
718 ReportLevel::TimeOnly,
719 )?
720 .changed;
721 }
722 self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
724 Ok(())
725 })?;
726 } else {
727 marker_nonincremental_rebuild(|| -> Result<()> {
728 changed |= run_rules_impl(
729 &mut self.db,
730 &mut self.rules,
731 &[info.nonincremental_rebuild_rule],
732 ts,
733 ReportLevel::TimeOnly,
734 )?
735 .changed;
736 for rule in &info.incremental_rebuild_rules {
737 self.rules[*rule].last_run_at = ts;
738 }
739 Ok(())
740 })?;
741 }
742 }
743 }
744 log::info!("rebuild took {:?}", start.elapsed());
745 Ok(())
746 }
747
748 fn rebuild_parallel(&mut self) -> Result<()> {
753 let start = Instant::now();
754 #[derive(Default)]
755 struct RebuildState {
756 nonincremental: Vec<FunctionId>,
757 incremental: DenseIdMap<usize, SmallVec<[FunctionId; 2]>>,
758 }
759
760 impl RebuildState {
761 fn clear(&mut self) {
762 self.nonincremental.clear();
763 self.incremental.iter_mut().for_each(|(_, v)| v.clear());
764 }
765 }
766
767 let mut changed = true;
768 let mut state = RebuildState::default();
769 let mut scratch = Vec::new();
770 while changed {
771 changed = false;
772 state.clear();
773 self.inc_ts();
774 for (func, info) in self.funcs.iter_mut() {
777 let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
778 let table_size = self.db.estimate_size(info.table, None);
779 let uf_size = self.db.estimate_size(
780 self.uf_table,
781 Some(Constraint::GeConst {
782 col: ColumnId::new(2),
783 val: last_rebuilt_at.to_value(),
784 }),
785 );
786 if incremental_rebuild(uf_size, table_size, true) {
787 for (i, _) in info.incremental_rebuild_rules.iter().enumerate() {
788 state.incremental.get_or_default(i).push(func);
789 }
790 } else {
791 state.nonincremental.push(func);
792 }
793 }
794 let ts = self.next_ts();
795 for func in state.nonincremental.iter().copied() {
796 scratch.push(self.funcs[func].nonincremental_rebuild_rule);
797 for rule in &self.funcs[func].incremental_rebuild_rules {
798 self.rules[*rule].last_run_at = ts;
799 }
800 }
801 changed |= run_rules_impl(
802 &mut self.db,
803 &mut self.rules,
804 &scratch,
805 ts,
806 ReportLevel::TimeOnly,
807 )?
808 .changed;
809 scratch.clear();
810 let ts = self.next_ts();
811 for (i, funcs) in state.incremental.iter() {
812 for func in funcs.iter().copied() {
813 let info = &mut self.funcs[func];
814 scratch.push(info.incremental_rebuild_rules[i]);
815 self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
816 }
817 changed |= run_rules_impl(
818 &mut self.db,
819 &mut self.rules,
820 &scratch,
821 ts,
822 ReportLevel::TimeOnly,
823 )?
824 .changed;
825 scratch.clear();
826 }
827 }
828 log::info!("rebuild took {:?}", start.elapsed());
829 Ok(())
830 }
831
832 fn incremental_rebuild_rules(&mut self, table: FunctionId, schema: &[ColumnTy]) -> Vec<RuleId> {
833 schema
834 .iter()
835 .enumerate()
836 .filter_map(|(i, ty)| match ty {
837 ColumnTy::Id => {
838 Some(self.incremental_rebuild_rule(table, schema, ColumnId::from_usize(i)))
839 }
840 ColumnTy::Base(_) => None,
841 })
842 .collect()
843 }
844
/// Builds the incremental rebuild rule for column `col` of `table`.
///
/// The rule joins the function table with the union-find table on column
/// `col`, then rewrites each matching row with `col` replaced by the
/// union-find partner and every other id column passed through `lookup_uf`.
fn incremental_rebuild_rule(
    &mut self,
    table: FunctionId,
    schema: &[ColumnTy],
    col: ColumnId,
) -> RuleId {
    let subsume = self.funcs[table].can_subsume;
    let table_id = self.funcs[table].table;
    let uf_table = self.uf_table;
    let mut rb = self.new_rule(&format!("incremental rebuild {table:?}, {col:?}"), true);
    rb.set_plan_strategy(PlanStrategy::MinCover);
    // One query variable per schema column.
    let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
    for ty in schema {
        vars.push(rb.new_var(*ty).into());
    }
    let canon_val: QueryEntry = rb.new_var(ColumnTy::Id).into();
    let subsume_var = subsume.then(|| rb.new_var(ColumnTy::Id));
    rb.add_atom_with_timestamp_and_func(
        table_id,
        Some(table),
        subsume_var.clone().map(QueryEntry::from),
        &vars,
    );
    // Match union-find rows whose first column is this row's `col` value.
    rb.add_atom_with_timestamp_and_func(
        uf_table,
        None,
        None,
        &[vars[col.index()].clone(), canon_val.clone()],
    );
    // NOTE(review): focus is set to atom index 1, which is the union-find
    // atom added above — confirm what `set_focus` implies for planning.
    rb.set_focus(1);
    let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
    for (i, (var, ty)) in vars.iter().zip(schema.iter()).enumerate() {
        canon.push(if i == col.index() {
            canon_val.clone()
        } else if let ColumnTy::Id = ty {
            rb.lookup_uf(var.clone()).unwrap().into()
        } else {
            var.clone()
        })
    }

    rb.rebuild_row(table, &vars, &canon, subsume_var);
    rb.build()
}
893
/// Builds the non-incremental rebuild rule for `table`.
///
/// The rule matches rows of the function table, passes every id column
/// through `lookup_uf`, checks the original id columns against the
/// looked-up ones with `check_for_update`, and rewrites the row.
fn nonincremental_rebuild(&mut self, table: FunctionId, schema: &[ColumnTy]) -> RuleId {
    let can_subsume = self.funcs[table].can_subsume;
    let table_id = self.funcs[table].table;
    let mut rb = self.new_rule(&format!("nonincremental rebuild {table:?}"), false);
    rb.set_plan_strategy(PlanStrategy::MinCover);
    // One query variable per schema column.
    let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
    for ty in schema {
        vars.push(rb.new_var(*ty).into());
    }
    let subsume_var = can_subsume.then(|| rb.new_var(ColumnTy::Id));
    rb.add_atom_with_timestamp_and_func(
        table_id,
        Some(table),
        subsume_var.clone().map(QueryEntry::from),
        &vars,
    );
    // `lhs` holds the original id columns, `rhs` their looked-up counterparts.
    let mut lhs = SmallVec::<[QueryEntry; 4]>::new();
    let mut rhs = SmallVec::<[QueryEntry; 4]>::new();
    let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
    for (var, ty) in vars.iter().zip(schema.iter()) {
        canon.push(if let ColumnTy::Id = ty {
            lhs.push(var.clone());
            let canon_var = QueryEntry::from(rb.lookup_uf(var.clone()).unwrap());
            rhs.push(canon_var.clone());
            canon_var
        } else {
            var.clone()
        })
    }
    rb.check_for_update(&lhs, &rhs).unwrap();
    rb.rebuild_row(table, &vars, &canon, subsume_var);
    rb.build()
}
927
/// Runs `f` with an [`ExecutionState`] provided by the database and returns
/// its result.
pub fn with_execution_state<R>(&self, f: impl FnOnce(&mut ExecutionState<'_>) -> R) -> R {
    self.db.with_execution_state(f)
}
946
/// Runs `f` with an [`ExecutionState`] via the database's tracked variant,
/// returning `f`'s result together with the flag the database reports.
// NOTE(review): the flag presumably says whether `f` staged any changes — confirm.
pub fn with_execution_state_tracked<R>(
    &self,
    f: impl FnOnce(&mut ExecutionState<'_>) -> R,
) -> (R, bool) {
    self.db.with_execution_state_tracked(f)
}
957
/// Merges all pending updates into the database, advances the timestamp,
/// and rebuilds if the union-find table changed size.
///
/// Returns the value reported by the database's `merge_all`.
///
/// # Panics
///
/// Panics if the rebuild returns an error.
pub fn flush_updates(&mut self) -> bool {
    let uf_size_before = self.db.get_table(self.uf_table).len();
    let updated = self.db.merge_all();
    self.inc_ts();
    let uf_size_after = self.db.get_table(self.uf_table).len();
    // A size change means new union-find rows were merged in.
    if uf_size_before != uf_size_after {
        self.rebuild().unwrap();
    }
    updated
}
972
/// Sets the report level used by subsequent `run_rules` calls.
pub fn set_report_level(&mut self, level: ReportLevel) {
    self.report_level = level;
}
976}
977
/// State kept for each registered rule.
#[derive(Clone)]
struct RuleInfo {
    /// Timestamp recorded for the rule's most recent run; the rebuild code
    /// both reads and overwrites it.
    last_run_at: Timestamp,
    /// The rule's query.
    query: rule::Query,
    /// Plan built from `query`, created on first use by `run_rules_impl`.
    cached_plan: Option<CachedPlanInfo>,
    /// Description passed along when the cached plan is built.
    desc: Arc<str>,
}
985
/// A cached query plan for a rule.
#[derive(Clone)]
struct CachedPlanInfo {
    /// The shared cached plan.
    plan: Arc<core_relations::CachedPlan>,
    /// Atom ids associated with the plan.
    // NOTE(review): presumably maps the query's atoms to the plan's atoms — confirm.
    atom_mapping: Vec<core_relations::AtomId>,
}
993
/// Metadata kept for each function table.
#[derive(Clone)]
struct FunctionInfo {
    /// Database table backing the function.
    table: TableId,
    /// Column types; the last entry is the return type.
    schema: Vec<ColumnTy>,
    /// One rebuild rule per id column, in column order.
    incremental_rebuild_rules: Vec<RuleId>,
    /// Rule that rebuilds all id columns of a row at once.
    nonincremental_rebuild_rule: RuleId,
    /// Behavior of lookup-or-insert on a missing key.
    default_val: DefaultVal,
    /// Whether rows carry a subsumption column.
    can_subsume: bool,
    /// Name the table was registered under.
    name: Arc<str>,
}
1004
impl FunctionInfo {
    /// Returns the type of the last schema column, i.e. the return type.
    ///
    /// # Panics
    ///
    /// Panics if the schema is empty (`EGraph::add_table` rejects empty schemas).
    fn ret_ty(&self) -> ColumnTy {
        self.schema.last().copied().unwrap()
    }
}
1010
/// What a lookup-or-insert on a function table does when the key is absent.
#[derive(Copy, Clone)]
pub enum DefaultVal {
    /// Use a value drawn from the id counter. Tables with this default are
    /// classified as `TableKind::Constructor`.
    FreshId,
    /// Insert nothing; the operation falls back to a plain lookup.
    Fail,
    /// Use the given constant value.
    Const(Value),
}
1021
/// How a function table reconciles the current and the new return value
/// when the same key is written twice.
pub enum MergeFn {
    /// Keep the current value; if the two values differ, call a panic
    /// function.
    AssertEq,
    /// If the two values differ, stage a union of them in the union-find
    /// table and take the smaller one.
    UnionId,
    /// Evaluate the argument merge functions and call the external function
    /// on the results; if it returns `None`, call a panic function.
    Primitive(ExternalFunctionId, Vec<MergeFn>),
    /// Evaluate the argument merge functions and look up (or insert) the
    /// results in the given function's table. The argument count must equal
    /// that function's schema length minus one.
    Function(FunctionId, Vec<MergeFn>),
    /// Take the current value.
    Old,
    /// Take the new value.
    New,
    /// Take the given constant.
    Const(Value),
}
1043
1044impl MergeFn {
/// Adds to `read_deps` / `write_deps` the tables this merge function (and,
/// recursively, its arguments) may read or write.
fn fill_deps(
    &self,
    egraph: &EGraph,
    read_deps: &mut IndexSet<TableId>,
    write_deps: &mut IndexSet<TableId>,
) {
    use MergeFn::*;
    match self {
        Primitive(_, args) => {
            args.iter()
                .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
            // Primitive merges are recorded as writers of the union-find table.
            write_deps.insert(egraph.uf_table);
        }
        Function(func, args) => {
            // The target function's table is both read and written.
            read_deps.insert(egraph.funcs[*func].table);
            write_deps.insert(egraph.funcs[*func].table);
            args.iter()
                .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
        }
        UnionId => {
            write_deps.insert(egraph.uf_table);
        }
        AssertEq | Old | New | Const(..) => {}
    }
}
1070
/// Compiles this merge function into the callback the table invokes when a
/// key is written twice.
///
/// The callback merges the return-value column with the resolved merge
/// function and, for subsumable tables, the subsume column with
/// `combine_subsumed`. It writes a row to `out` and returns `true` only if
/// either merged value differs from the current row's.
fn to_callback(
    &self,
    schema_math: SchemaMath,
    function_name: &str,
    egraph: &mut EGraph,
) -> Box<core_relations::MergeFn> {
    let resolved = self.resolve(function_name, egraph);

    Box::new(move |state, cur, new, out| {
        // The merged row takes the timestamp of the incoming row.
        let timestamp = new[schema_math.ts_col()];

        let mut changed = false;

        let ret_val = {
            let cur = cur[schema_math.ret_val_col()];
            let new = new[schema_math.ret_val_col()];
            let out = resolved.run(state, cur, new, timestamp);
            changed |= cur != out;
            out
        };

        let subsume = schema_math.subsume.then(|| {
            let cur = cur[schema_math.subsume_col()];
            let new = new[schema_math.subsume_col()];
            let out = combine_subsumed(cur, new);
            changed |= cur != out;
            out
        });
        if changed {
            out.extend_from_slice(new);
            schema_math.write_table_row(
                out,
                RowVals {
                    timestamp,
                    subsume,
                    ret_val: Some(ret_val),
                },
            );
        }

        changed
    })
}
1114
/// Turns this merge function into a [`ResolvedMergeFn`], creating the panic
/// functions and table handles it needs at run time.
///
/// # Panics
///
/// Panics if a `MergeFn::Function` has an argument count different from the
/// target function's schema length minus one.
fn resolve(&self, function_name: &str, egraph: &mut EGraph) -> ResolvedMergeFn {
    match self {
        MergeFn::Const(v) => ResolvedMergeFn::Const(*v),
        MergeFn::Old => ResolvedMergeFn::Old,
        MergeFn::New => ResolvedMergeFn::New,
        MergeFn::AssertEq => ResolvedMergeFn::AssertEq {
            panic: egraph.new_panic(format!(
                "Illegal merge attempted for function {function_name}"
            )),
        },
        MergeFn::UnionId => ResolvedMergeFn::UnionId {
            uf_table: egraph.uf_table,
        },
        MergeFn::Primitive(prim, args) => ResolvedMergeFn::Primitive {
            prim: *prim,
            args: args
                .iter()
                .map(|arg| arg.resolve(function_name, egraph))
                .collect::<Vec<_>>(),
            panic: egraph.new_panic(format!(
                "Merge function for {function_name} primitive call failed"
            )),
        },
        MergeFn::Function(func, args) => {
            let func_info = &egraph.funcs[*func];
            assert_eq!(
                func_info.schema.len(),
                args.len() + 1,
                "Merge function for {function_name} must match function arity for {}",
                func_info.name
            );
            ResolvedMergeFn::Function {
                func: TableAction::new(egraph, *func),
                panic: egraph.new_panic(format!(
                    "Lookup on {} failed in the merge function for {function_name}",
                    func_info.name
                )),
                args: args
                    .iter()
                    .map(|arg| arg.resolve(function_name, egraph))
                    .collect::<Vec<_>>(),
            }
        }
    }
}
1164}
1165
/// A [`MergeFn`] with everything it needs at run time already looked up:
/// panic function ids, the union-find table id, and table handles.
enum ResolvedMergeFn {
    /// Always yields the constant.
    Const(Value),
    /// Yields the current value.
    Old,
    /// Yields the new value.
    New,
    /// Yields the current value; calls `panic` if the values differ.
    AssertEq {
        panic: ExternalFunctionId,
    },
    /// Stages a union into `uf_table` when the values differ.
    UnionId {
        uf_table: TableId,
    },
    /// Calls `prim` on the evaluated `args`; calls `panic` if it returns `None`.
    Primitive {
        prim: ExternalFunctionId,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
    /// Looks up or inserts the evaluated `args` in `func`; calls `panic` if
    /// that yields nothing.
    Function {
        func: TableAction,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
}
1191
impl ResolvedMergeFn {
    /// Computes the merged value for the current value `cur` and the incoming
    /// value `new`. `ts` is the timestamp used for any union staged along
    /// the way.
    ///
    /// On failure paths the relevant panic function is called and `cur` is
    /// returned.
    fn run(&self, state: &mut ExecutionState, cur: Value, new: Value, ts: Value) -> Value {
        match self {
            ResolvedMergeFn::Const(v) => *v,
            ResolvedMergeFn::Old => cur,
            ResolvedMergeFn::New => new,
            ResolvedMergeFn::AssertEq { panic } => {
                if cur != new {
                    // The panic function is expected to return `None`.
                    let res = state.call_external_func(*panic, &[]);
                    assert_eq!(res, None);
                }
                cur
            }
            ResolvedMergeFn::UnionId { uf_table } => {
                if cur != new {
                    state.stage_insert(*uf_table, &[cur, new, ts]);
                    std::cmp::min(cur, new)
                } else {
                    cur
                }
            }
            ResolvedMergeFn::Primitive { prim, args, panic } => {
                // Arguments are themselves merge functions over (cur, new).
                let args = args
                    .iter()
                    .map(|arg| arg.run(state, cur, new, ts))
                    .collect::<Vec<_>>();

                match state.call_external_func(*prim, &args) {
                    Some(result) => result,
                    None => {
                        let res = state.call_external_func(*panic, &[]);
                        assert_eq!(res, None);
                        cur
                    }
                }
            }
            ResolvedMergeFn::Function { func, args, panic } => {
                // Equal values need no merge; skip the table lookup entirely.
                if cur == new {
                    return cur;
                }

                let args = args
                    .iter()
                    .map(|arg| arg.run(state, cur, new, ts))
                    .collect::<Vec<_>>();

                func.lookup_or_insert(state, &args).unwrap_or_else(|| {
                    let res = state.call_external_func(*panic, &[]);
                    assert_eq!(res, None);
                    cur
                })
            }
        }
    }
}
1259
/// Classification of a table, derived from its [`DefaultVal`] in
/// `TableAction::new`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum TableKind {
    /// The table's default is `DefaultVal::Fail` or `DefaultVal::Const`.
    Function,
    /// The table's default is `DefaultVal::FreshId`.
    Constructor,
}
1269
/// A handle for reading and staging writes to one function table through an
/// [`ExecutionState`].
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct TableAction {
    /// The database table this handle operates on.
    table: TableId,
    /// Column layout of the table.
    table_math: SchemaMath,
    /// Value used by `lookup_or_insert` for missing keys; `None` means the
    /// operation degrades to a plain lookup.
    default: Option<MergeVal>,
    /// Counter read to timestamp staged rows.
    timestamp: CounterId,
    /// Whether the table is a function or a constructor.
    kind: TableKind,
}
1281
1282impl TableAction {
1283 pub fn new(egraph: &EGraph, func: FunctionId) -> TableAction {
1286 let func_info = &egraph.funcs[func];
1287 let kind = match &func_info.default_val {
1288 DefaultVal::FreshId => TableKind::Constructor,
1289 DefaultVal::Fail | DefaultVal::Const(_) => TableKind::Function,
1290 };
1291 TableAction {
1292 table: func_info.table,
1293 table_math: SchemaMath {
1294 func_cols: func_info.schema.len(),
1295 subsume: func_info.can_subsume,
1296 },
1297 default: match &func_info.default_val {
1298 DefaultVal::FreshId => Some(MergeVal::Counter(egraph.id_counter)),
1299 DefaultVal::Fail => None,
1300 DefaultVal::Const(val) => Some(MergeVal::Constant(*val)),
1301 },
1302 timestamp: egraph.timestamp_counter,
1303 kind,
1304 }
1305 }
1306
/// Returns whether this table is a function or a constructor.
pub fn kind(&self) -> TableKind {
    self.kind
}
1312
/// Returns the number of input columns: the schema length minus the
/// return-value column.
pub fn input_arity(&self) -> usize {
    self.table_math.func_cols - 1
}
1318
1319 pub fn lookup(&self, state: &ExecutionState, key: &[Value]) -> Option<Value> {
1326 state
1327 .get_table(self.table)
1328 .get_row(key)
1329 .map(|row| row.vals[self.table_math.ret_val_col()])
1330 }
1331
/// Returns `len()` of the underlying table.
pub fn row_count(&self, state: &ExecutionState) -> usize {
    state.get_table(self.table).len()
}
1336
/// Calls `f` on every non-stale row of the table. This is `for_each_while`
/// with a callback that never stops the scan.
pub fn for_each(&self, state: &ExecutionState, mut f: impl FnMut(ScanEntry<'_>)) {
    self.for_each_while(state, |row| {
        f(row);
        true
    });
}
1347
/// Scans the non-stale rows of the table, calling `f` on each until `f`
/// returns `false`.
///
/// Each [`ScanEntry`] exposes the first `func_cols` columns of the row and
/// whether the row is marked `SUBSUMED` (always `false` for tables that
/// cannot subsume).
pub fn for_each_while(&self, state: &ExecutionState, mut f: impl FnMut(ScanEntry<'_>) -> bool) {
    let schema_math = self.table_math;
    let imp = state.get_table(self.table);
    let all = imp.all();
    let mut cur = Offset::new(0);
    let mut buf = TaggedRowBuffer::new(imp.spec().arity());
    // Feeds every buffered row to `f`, returning from the enclosing function
    // as soon as `f` asks to stop, then empties the buffer.
    macro_rules! drain_buf {
        ($buf:expr) => {
            for (_, row) in $buf.non_stale() {
                let subsumed =
                    schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
                if !f(ScanEntry {
                    vals: &row[0..schema_math.func_cols],
                    subsumed,
                }) {
                    return;
                }
            }
            $buf.clear();
        };
    }
    // Scan with a bound of 32 per call.
    while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
        drain_buf!(buf);
        cur = next;
    }
    // The final scan call returned `None` but may still have filled the buffer.
    drain_buf!(buf);
}
1377
/// Returns the value associated with `key`, using the table's default for
/// missing keys when it has one.
///
/// With a default, the non-key columns (default return value, current
/// timestamp and, if applicable, `NOT_SUBSUMED`) are passed to
/// `predict_val` and the return-value column of its result is returned.
/// Without a default this is a plain `lookup` and may return `None`.
pub fn lookup_or_insert(&self, state: &mut ExecutionState, key: &[Value]) -> Option<Value> {
    match self.default {
        Some(default) => {
            let timestamp =
                MergeVal::Constant(Value::from_usize(state.read_counter(self.timestamp)));
            let mut merge_vals = SmallVec::<[MergeVal; 3]>::new();
            // NOTE(review): `func_cols` is overridden to 1 here, presumably
            // so that only the non-key columns are written into
            // `merge_vals` — confirm against `SchemaMath::write_table_row`.
            SchemaMath {
                func_cols: 1,
                ..self.table_math
            }
            .write_table_row(
                &mut merge_vals,
                RowVals {
                    timestamp,
                    subsume: self
                        .table_math
                        .subsume
                        .then_some(MergeVal::Constant(NOT_SUBSUMED)),
                    ret_val: Some(default),
                },
            );
            Some(
                state.predict_val(self.table, key, merge_vals.iter().copied())
                    [self.table_math.ret_val_col()],
            )
        }
        None => self.lookup(state, key),
    }
}
1414
/// Stages an insert of `row`, extended with the current timestamp and, for
/// subsumable tables, `NOT_SUBSUMED`.
///
/// No return value is supplied separately (`ret_val: None`).
// NOTE(review): `row` is presumably expected to already contain the return value — confirm.
pub fn insert(&self, state: &mut ExecutionState, row: impl Iterator<Item = Value>) {
    let ts = Value::from_usize(state.read_counter(self.timestamp));
    let mut scratch = row.collect::<SmallVec<[_; 8]>>();
    self.table_math.write_table_row(
        &mut scratch,
        RowVals {
            timestamp: ts,
            subsume: self.table_math.subsume.then_some(NOT_SUBSUMED),
            ret_val: None,
        },
    );
    state.stage_insert(self.table, &scratch);
}
1429
/// Stages the removal of the row identified by `key` from this action's
/// table.
pub fn remove(&self, state: &mut ExecutionState, key: &[Value]) {
    state.stage_remove(self.table, key);
}
1434
1435 pub fn subsume(&self, state: &mut ExecutionState, key: impl Iterator<Item = Value>) {
1437 let ts = Value::from_usize(state.read_counter(self.timestamp));
1438 let mut scratch = key.collect::<SmallVec<[_; 8]>>();
1439
1440 let ret_val = self.lookup(state, &scratch).expect("subsume lookup failed");
1441
1442 self.table_math.write_table_row(
1443 &mut scratch,
1444 RowVals {
1445 timestamp: ts,
1446 subsume: Some(SUBSUMED),
1447 ret_val: Some(ret_val),
1448 },
1449 );
1450 state.stage_insert(self.table, &scratch);
1451 }
1452}
1453
/// A handle for staging unions: it records which table receives the union
/// rows and which counter supplies their timestamps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnionAction {
    /// Table that `union` stages rows into (the e-graph's `uf_table`).
    table: TableId,
    /// Counter read to timestamp each staged union.
    timestamp: CounterId,
}
1460
1461impl UnionAction {
1462 pub fn new(egraph: &EGraph) -> UnionAction {
1465 UnionAction {
1466 table: egraph.uf_table,
1467 timestamp: egraph.timestamp_counter,
1468 }
1469 }
1470
1471 pub fn union(&self, state: &mut ExecutionState, x: Value, y: Value) {
1473 let ts = Value::from_usize(state.read_counter(self.timestamp));
1474 state.stage_insert(self.table, &[x, y, ts]);
1475 }
1476}
1477
1478fn run_rules_impl(
1479 db: &mut Database,
1480 rule_info: &mut DenseIdMapWithReuse<RuleId, RuleInfo>,
1481 rules: &[RuleId],
1482 next_ts: Timestamp,
1483 report_level: ReportLevel,
1484) -> Result<RuleSetReport> {
1485 for rule in rules {
1486 let info = &mut rule_info[*rule];
1487 if info.cached_plan.is_none() {
1488 info.cached_plan = Some(info.query.build_cached_plan(db, &info.desc)?);
1489 }
1490 }
1491 let mut rsb = db.new_rule_set();
1492 for rule in rules {
1493 let info = &mut rule_info[*rule];
1494 let cached_plan = info.cached_plan.as_ref().unwrap();
1495 info.query
1496 .add_rules_from_cached(&mut rsb, info.last_run_at, cached_plan);
1497 info.last_run_at = next_ts;
1498 }
1499 let ruleset = rsb.build();
1500 Ok(db.run_rule_set(&ruleset, report_level))
1501}
1502
/// Runs `f` and returns its result.
///
/// The function adds no behavior of its own; `#[inline(never)]` keeps it as a
/// real call.
// NOTE(review): presumably this exists so incremental rebuilds show up as a
// distinct frame in profiles and backtraces — confirm.
#[inline(never)]
fn marker_incremental_rebuild<R>(f: impl FnOnce() -> R) -> R {
    f()
}
1510
/// Runs `f` and returns its result.
///
/// The function adds no behavior of its own; `#[inline(never)]` keeps it as a
/// real call.
// NOTE(review): presumably this exists so non-incremental rebuilds show up as
// a distinct frame in profiles and backtraces — confirm.
#[inline(never)]
fn marker_nonincremental_rebuild<R>(f: impl FnOnce() -> R) -> R {
    f()
}
1515
/// A shared, lockable slot that may hold a `T`.
///
/// The panic external functions in this file use a `SideChannel<String>` to
/// record the first panic message raised during execution.
pub type SideChannel<T> = Arc<Mutex<Option<T>>>;
1519
/// External function that stops execution early and records a lazily computed
/// message when invoked.
///
/// Field 0 is the shared, lazily evaluated message; field 1 is the side
/// channel the message is written to.
struct LazyPanic<F>(Arc<Lazy<String, F>>, SideChannel<String>);
1536
1537impl<F: FnOnce() -> String + Send> ExternalFunction for LazyPanic<F> {
1538 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1539 assert!(args.is_empty());
1540 state.trigger_early_stop();
1541 let mut guard = self.1.lock().unwrap();
1542 if guard.is_none() {
1543 *guard = Some(Lazy::force(&self.0).clone());
1544 }
1545 None
1546 }
1547}
1548
1549impl<F> Clone for LazyPanic<F> {
1550 fn clone(&self) -> Self {
1551 LazyPanic(self.0.clone(), self.1.clone())
1552 }
1553}
1554
/// External function that stops execution early and records a fixed message
/// when invoked.
///
/// Field 0 is the message; field 1 is the side channel it is written to.
#[derive(Clone)]
struct Panic(String, SideChannel<String>);
1561
1562impl EGraph {
1563 pub fn new_panic(&mut self, message: String) -> ExternalFunctionId {
1565 *self
1566 .panic_funcs
1567 .entry(message.to_string())
1568 .or_insert_with(|| {
1569 let panic = Panic(message, self.panic_message.clone());
1570 self.db.add_external_function(Box::new(panic))
1571 })
1572 }
1573
1574 pub fn new_panic_lazy(
1575 &mut self,
1576 message: impl FnOnce() -> String + Send + 'static,
1577 ) -> ExternalFunctionId {
1578 let lazy = Lazy::new(message);
1579 let panic = LazyPanic(Arc::new(lazy), self.panic_message.clone());
1580 self.db.add_external_function(Box::new(panic))
1581 }
1582}
1583
1584impl ExternalFunction for Panic {
1585 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1586 assert!(args.is_empty());
1588
1589 state.trigger_early_stop();
1590 let mut guard = self.1.lock().unwrap();
1591 if guard.is_none() {
1592 *guard = Some(self.0.clone());
1593 }
1594 None
1595 }
1596}
1597
/// Returns `true` when `uf_size` is at most a fixed fraction of
/// `table_size`: one sixteenth when `parallel` is set, one eighth otherwise
/// (integer division).
// NOTE(review): judging by the name, `true` selects the incremental rebuild
// strategy — confirm against the call site.
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    let divisor = if parallel { 16 } else { 8 };
    uf_size <= table_size / divisor
}
1607
/// Value stored in a row's subsumption column when the row is subsumed.
pub(crate) const SUBSUMED: Value = Value::new_const(1);
/// Value stored in a row's subsumption column when the row is not subsumed.
pub(crate) const NOT_SUBSUMED: Value = Value::new_const(0);
1610fn combine_subsumed(v1: Value, v2: Value) -> Value {
1611 std::cmp::max(v1, v2)
1612}
1613
/// Column-layout arithmetic for a function's backing table.
///
/// As computed by the methods on this type, a row consists of the
/// `func_cols` function columns (keys followed by the return value), then a
/// timestamp column, then — only when `subsume` is set — a subsumption-flag
/// column.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct SchemaMath {
    /// Whether rows carry a trailing subsumption-flag column.
    subsume: bool,
    /// Number of function columns: the keys plus the return value.
    func_cols: usize,
}
1631
/// The non-key values that `SchemaMath::write_table_row` writes into a row.
struct RowVals<T> {
    /// Value for the timestamp column; also used as filler when the row has
    /// to be grown.
    timestamp: T,
    /// Value for the subsumption column. Must be `Some` when the schema has
    /// such a column.
    subsume: Option<T>,
    /// Value for the return-value column. `None` leaves that column as it is
    /// after the row has been resized.
    ret_val: Option<T>,
}
1646
/// One table row as passed to the callbacks of `for_each` and
/// `for_each_while`.
#[derive(Clone, Debug)]
pub struct ScanEntry<'a> {
    /// The function columns of the row (its first `func_cols` values).
    pub vals: &'a [Value],
    /// Whether the row's subsumption column equals `SUBSUMED`; always `false`
    /// for tables that do not track subsumption.
    pub subsumed: bool,
}
1658
1659impl SchemaMath {
1660 fn write_table_row<T: Clone>(
1661 &self,
1662 row: &mut impl HasResizeWith<T>,
1663 RowVals {
1664 timestamp,
1665 subsume,
1666 ret_val,
1667 }: RowVals<T>,
1668 ) {
1669 row.resize_with(self.table_columns(), || timestamp.clone());
1670 row[self.ts_col()] = timestamp;
1671 if let Some(ret_val) = ret_val {
1672 row[self.ret_val_col()] = ret_val;
1673 }
1674 if let Some(subsume) = subsume {
1675 row[self.subsume_col()] = subsume;
1676 } else {
1677 assert!(
1678 !self.subsume,
1679 "subsume flag must be provided if subsumption is enabled"
1680 );
1681 }
1682 }
1683
1684 fn num_keys(&self) -> usize {
1685 self.func_cols - 1
1686 }
1687
1688 fn table_columns(&self) -> usize {
1689 self.func_cols + 1 + if self.subsume { 1 } else { 0 }
1690 }
1691
1692 fn ret_val_col(&self) -> usize {
1693 self.func_cols - 1
1694 }
1695
1696 fn ts_col(&self) -> usize {
1697 self.func_cols
1698 }
1699
1700 #[track_caller]
1701 fn subsume_col(&self) -> usize {
1702 assert!(self.subsume);
1703 self.func_cols + 1
1704 }
1705}
1706
/// Error wrapping a panic message; it displays as `Panic: <message>`.
#[derive(Error, Debug)]
#[error("Panic: {0}")]
struct PanicError(String);
1710
/// Abstraction over indexable, resizable buffers of `T`.
///
/// It is implemented in this file for `Vec` and `SmallVec`, which lets
/// `SchemaMath::write_table_row` fill either kind of row buffer.
trait HasResizeWith<T>:
    AsMut<[T]> + AsRef<[T]> + Index<usize, Output = T> + IndexMut<usize, Output = T>
{
    /// Resizes the buffer to `new_size` elements, calling `f` to produce each
    /// newly added element.
    fn resize_with<F>(&mut self, new_size: usize, f: F)
    where
        F: FnMut() -> T;
}
1720
1721impl<T> HasResizeWith<T> for Vec<T> {
1722 fn resize_with<F>(&mut self, new_size: usize, f: F)
1723 where
1724 F: FnMut() -> T,
1725 {
1726 self.resize_with(new_size, f);
1727 }
1728}
1729
1730impl<T, A: smallvec::Array<Item = T>> HasResizeWith<T> for SmallVec<A> {
1731 fn resize_with<F>(&mut self, new_size: usize, f: F)
1732 where
1733 F: FnMut() -> T,
1734 {
1735 self.resize_with(new_size, f);
1736 }
1737}