egglog_bridge/
lib.rs

//! An implementation of egglog-style queries on top of core-relations.
//!
//! This module translates a well-typed egglog-esque query into the abstractions
//! from the `core-relations` crate. The main pieces of higher-level
//! functionality that it implements are seminaive evaluation, default values,
//! and merge functions.
//!
//! This crate is essentially involved in desugaring: it elaborates the encoding
//! of core egglog functionality, but it does not implement algorithms for
//! joins, union-finds, etc.
10
11use std::{
12    fmt::Debug,
13    hash::Hash,
14    iter, mem,
15    ops::{Index, IndexMut},
16    sync::{Arc, Mutex},
17};
18
19use crate::core_relations::{
20    BaseValue, BaseValueId, BaseValues, ColumnId, Constraint, ContainerValue, ContainerValues,
21    CounterId, Database, DisplacedTable, ExecutionState, ExternalFunction, ExternalFunctionId,
22    MergeVal, Offset, PlanStrategy, SortedWritesTable, TableId, TaggedRowBuffer, Value,
23    WrappedTable,
24};
25use crate::numeric_id::{DenseIdMap, DenseIdMapWithReuse, NumericId, define_id};
26use egglog_core_relations as core_relations;
27use egglog_numeric_id as numeric_id;
28use egglog_reports::{IterationReport, ReportLevel, RuleSetReport};
29use hashbrown::HashMap;
30use indexmap::IndexSet;
31use log::info;
32use once_cell::sync::Lazy;
33use smallvec::SmallVec;
34use web_time::{Duration, Instant};
35
36pub mod macros;
37pub(crate) mod rule;
38#[cfg(test)]
39mod tests;
40
41pub use rule::{Function, QueryEntry, RuleBuilder};
42use thiserror::Error;
43
44/// A live registry of action handles for use by typed primitives.
45///
46/// Maps table name to [`TableAction`] (plus the shared [`UnionAction`]
47/// and the default-panic external-function id) and is owned by the
48/// bridge `EGraph`. The state wrappers (`PureState`/`ReadState`/
49/// `WriteState`/`FullState`) live in the `egglog` crate; they read
50/// from this registry at invoke time to back name-indexed action
51/// methods. Held by the bridge `EGraph` inside an `Arc<RwLock<_>>`.
52#[derive(Clone)]
53pub struct ActionRegistry {
54    table_actions: hashbrown::HashMap<String, TableAction>,
55    union_action: UnionAction,
56    default_panic_id: ExternalFunctionId,
57}
58
59impl ActionRegistry {
60    pub(crate) fn new(union_action: UnionAction, default_panic_id: ExternalFunctionId) -> Self {
61        Self {
62            table_actions: hashbrown::HashMap::new(),
63            union_action,
64            default_panic_id,
65        }
66    }
67
68    pub(crate) fn register_table(&mut self, name: String, action: TableAction) {
69        self.table_actions.insert(name, action);
70    }
71
72    /// Look up the [`TableAction`] for a table by name, or `None` if
73    /// no table with that name has been registered.
74    pub fn lookup_table(&self, name: &str) -> Option<&TableAction> {
75        self.table_actions.get(name)
76    }
77
78    /// Snapshot the registered table names and their current row counts.
79    pub fn table_sizes(&self, state: &ExecutionState) -> Vec<(&str, usize)> {
80        self.table_actions
81            .iter()
82            .map(|(name, action)| (name.as_str(), action.row_count(state)))
83            .collect()
84    }
85
86    /// The shared [`UnionAction`] for this EGraph's union-find.
87    pub fn union_action(&self) -> &UnionAction {
88        &self.union_action
89    }
90
91    /// The default panic external function id, used by the egglog
92    /// crate's `ActionView::panic`.
93    pub fn default_panic_id(&self) -> ExternalFunctionId {
94        self.default_panic_id
95    }
96}
97
98#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
99pub enum ColumnTy {
100    Id,
101    Base(BaseValueId),
102}
103
104define_id!(pub RuleId, u32, "An egglog-style rule");
105define_id!(pub FunctionId, u32, "An id representing an egglog function");
106define_id!(pub(crate) Timestamp, u32, "An abstract timestamp used to track execution of egglog rules");
107impl Timestamp {
108    fn to_value(self) -> Value {
109        Value::new(self.rep())
110    }
111}
112
113/// The state associated with an egglog program.
114#[derive(Clone)]
115pub struct EGraph {
116    db: Database,
117    uf_table: TableId,
118    id_counter: CounterId,
119    timestamp_counter: CounterId,
120    rules: DenseIdMapWithReuse<RuleId, RuleInfo>,
121    funcs: DenseIdMap<FunctionId, FunctionInfo>,
122    panic_message: SideChannel<String>,
123    /// This is a cache of all the different panic messages that we may use while executing rules
124    /// against the EGraph. Oftentimes, these messages are generated dynamically: keeping this map
125    /// around allows us to cache external function ids with repeat panic messages and they can
126    /// also serve as a debugging tool in the case that the number of panic messages grows without
127    /// bound.
128    panic_funcs: HashMap<String, ExternalFunctionId>,
129    report_level: ReportLevel,
130    /// Live registry of name-indexed action handles. Shared (via
131    /// `Arc<RwLock<_>>`) with state wrappers and primitive callbacks
132    /// in the egglog crate so name-indexed action methods on
133    /// [`WriteState`] / [`FullState`] can resolve table actions at
134    /// invoke time. Mutated in place from [`add_table`](EGraph::add_table).
135    action_registry: Arc<std::sync::RwLock<ActionRegistry>>,
136}
137
138pub type Result<T> = std::result::Result<T, anyhow::Error>;
139
140impl Default for EGraph {
141    fn default() -> Self {
142        let mut db = Database::new();
143        let uf_table = db.add_table_named(
144            DisplacedTable::default(),
145            "$uf".into(),
146            iter::empty(),
147            iter::empty(),
148        );
149        let id_counter = db.add_counter();
150        let ts_counter = db.add_counter();
151        // Start the timestamp counter at 1.
152        db.inc_counter(ts_counter);
153
154        // Register a default panic external function so the typed
155        // state wrappers' `panic()` method has an id to call. This
156        // also seeds `panic_funcs` so a later `new_panic` with the
157        // same message reuses the id.
158        let panic_message: SideChannel<String> = Default::default();
159        let mut panic_funcs: HashMap<String, ExternalFunctionId> = Default::default();
160        let default_panic_msg = "primitive panicked".to_string();
161        let default_panic_id = db.add_external_function(Box::new(Panic(
162            default_panic_msg.clone(),
163            panic_message.clone(),
164        )));
165        panic_funcs.insert(default_panic_msg, default_panic_id);
166
167        let union_action = UnionAction {
168            table: uf_table,
169            timestamp: ts_counter,
170        };
171        let action_registry = Arc::new(std::sync::RwLock::new(ActionRegistry::new(
172            union_action,
173            default_panic_id,
174        )));
175
176        Self {
177            db,
178            uf_table,
179            id_counter,
180            timestamp_counter: ts_counter,
181            rules: Default::default(),
182            funcs: Default::default(),
183            panic_message,
184            panic_funcs,
185            report_level: Default::default(),
186            action_registry,
187        }
188    }
189}
190
191/// Properties of a function added to an [`EGraph`].
192pub struct FunctionConfig {
193    /// The function's schema. The last column in the schema is the return type.
194    pub schema: Vec<ColumnTy>,
195    /// The behavior of the function when lookups are made on keys not currently present.
196    pub default: DefaultVal,
197    /// How to resolve FD conflicts for the function.
198    pub merge: MergeFn,
199    /// The function's name
200    pub name: String,
201    /// Whether or not subsumption is enabled for this function.
202    pub can_subsume: bool,
203}
204
205impl EGraph {
206    fn next_ts(&self) -> Timestamp {
207        Timestamp::from_usize(self.db.read_counter(self.timestamp_counter))
208    }
209
210    fn inc_ts(&mut self) {
211        self.db.inc_counter(self.timestamp_counter);
212    }
213
214    /// Get a mutable reference to the underlying table of base values for this
215    /// `EGraph`.
216    pub fn base_values_mut(&mut self) -> &mut BaseValues {
217        self.db.base_values_mut()
218    }
219
220    /// Get a mutable reference to the underlying table of containers for this
221    /// `EGraph`.
222    pub fn container_values_mut(&mut self) -> &mut ContainerValues {
223        self.db.container_values_mut()
224    }
225
226    /// Get a reference to the underlying table of containers for this `EGraph`.
227    pub fn container_values(&self) -> &ContainerValues {
228        self.db.container_values()
229    }
230
231    /// Intern the given container value into the EGraph.
232    pub fn get_container_value<C: ContainerValue>(&mut self, val: C) -> Value {
233        self.register_container_ty::<C>();
234        self.db
235            .with_execution_state(|state| state.clone().container_values().register_val(val, state))
236    }
237
238    /// Register the given [`ContainerValue`] type with this EGraph.
239    ///
240    /// The given container will use the EGraph's union-find to manage rebuilding and the merging
241    /// of containers with a common id.
242    pub fn register_container_ty<C: ContainerValue>(&mut self) {
243        let uf_table = self.uf_table;
244        let ts_counter = self.timestamp_counter;
245        self.db.container_values_mut().register_type::<C>(
246            self.id_counter,
247            move |state, old, new| {
248                if old != new {
249                    let next_ts = Value::from_usize(state.read_counter(ts_counter));
250                    state.stage_insert(uf_table, &[old, new, next_ts]);
251                    std::cmp::min(old, new)
252                } else {
253                    old
254                }
255            },
256        );
257    }
258
259    /// Get a reference to the underlying table of base values for this `EGraph`.
260    pub fn base_values(&self) -> &BaseValues {
261        self.db.base_values()
262    }
263
264    /// Create a [`QueryEntry`] for a base value.
265    pub fn base_value_constant<T>(&self, x: T) -> QueryEntry
266    where
267        T: BaseValue,
268    {
269        QueryEntry::Const {
270            val: self.base_values().get(x),
271            ty: ColumnTy::Base(self.base_values().get_ty::<T>()),
272        }
273    }
274
275    /// Register a low-level external function. The callback receives a
276    /// raw `&mut ExecutionState`.
277    ///
278    /// # Seminaive-safety trust boundary
279    ///
280    /// Like [`EGraph::with_execution_state`], this is a raw escape —
281    /// the registered function has unrestricted access and is not
282    /// tracked by the per-context validity system. Prefer building
283    /// primitives via the higher-level `egglog::Primitive` /
284    /// `egglog::EGraph::add_primitive` API, which enforces #772's
285    /// seminaive-safety contract.
286    pub fn register_external_func(
287        &mut self,
288        func: Box<dyn ExternalFunction + 'static>,
289    ) -> ExternalFunctionId {
290        self.db.add_external_function(func)
291    }
292
293    pub fn free_external_func(&mut self, func: ExternalFunctionId) {
294        self.db.free_external_function(func)
295    }
296
297    /// Generate a fresh id.
298    pub fn fresh_id(&mut self) -> Value {
299        Value::from_usize(self.db.inc_counter(self.id_counter))
300    }
301
302    /// Look up the canonical value for `val` in the union-find.
303    ///
304    /// If the value has never been inserted into the union-find, `val` is returned.
305    fn get_canon_in_uf(&self, val: Value) -> Value {
306        let table = self.db.get_table(self.uf_table);
307        let row = table.get_row(&[val]);
308        row.map(|row| row.vals[1]).unwrap_or(val)
309    }
310
311    /// Get the canonical representation for `val` based on type.
312    ///
313    /// For [`ColumnTy::Id`], it looks up the union find; otherwise,
314    /// it returns the value itself.
315    pub fn get_canon_repr(&self, val: Value, ty: ColumnTy) -> Value {
316        match ty {
317            ColumnTy::Id => self.get_canon_in_uf(val),
318            ColumnTy::Base(_) => val,
319        }
320    }
321
322    /// Load the given values into the database.
323    ///
324    /// # Panics
325    /// This method panics if the values do not match the arity of the function.
326    ///
327    /// NB: this is not an efficient interface for bulk loading. We should add
328    /// one that allows us to pass through a series of RowBuffers before
329    /// incrementing the timestamp.
330    pub fn add_values(&mut self, values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>) {
331        let mut extended_row = Vec::<Value>::new();
332        let mut bufs = DenseIdMap::default();
333        for (func, row) in values.into_iter() {
334            let table_info = &self.funcs[func];
335            let schema_math = SchemaMath {
336                subsume: table_info.can_subsume,
337                func_cols: table_info.schema.len(),
338            };
339            let table_id = table_info.table;
340            extended_row.extend_from_slice(&row);
341            schema_math.write_table_row(
342                &mut extended_row,
343                RowVals {
344                    timestamp: self.next_ts().to_value(),
345                    subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
346                    ret_val: None, // already filled in.
347                },
348            );
349            let buf = bufs.get_or_insert(table_id, || self.db.new_buffer(table_id));
350            buf.stage_insert(&extended_row);
351            extended_row.clear();
352        }
353        // Flush the buffers.
354        mem::drop(bufs);
355        self.flush_updates();
356    }
357
358    /// A term-oriented means of adding data to the database: hand back a "term
359    /// id" for the given function and keys for the function.
360    ///
361    /// # Panics
362    /// This method panics if the values do not match the arity of the function.
363    pub fn add_term(&mut self, func: FunctionId, inputs: &[Value]) -> Value {
364        let info = &self.funcs[func];
365        let schema_math = SchemaMath {
366            subsume: info.can_subsume,
367            func_cols: info.schema.len(),
368        };
369        let mut extended_row = Vec::new();
370        extended_row.extend_from_slice(inputs);
371        let res = self.fresh_id();
372        schema_math.write_table_row(
373            &mut extended_row,
374            RowVals {
375                timestamp: self.next_ts().to_value(),
376                ret_val: Some(res),
377                subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
378            },
379        );
380        extended_row[schema_math.ret_val_col()] = res;
381        let table_id = self.funcs[func].table;
382        self.db.new_buffer(table_id).stage_insert(&extended_row);
383        self.flush_updates();
384        self.get_canon_in_uf(res)
385    }
386
387    /// Lookup the id associated with a function `func` and the given arguments
388    /// (`key`).
389    pub fn lookup_id(&self, func: FunctionId, key: &[Value]) -> Option<Value> {
390        let info = &self.funcs[func];
391        let schema_math = SchemaMath {
392            subsume: info.can_subsume,
393            func_cols: info.schema.len(),
394        };
395        let table_id = info.table;
396        let table = self.db.get_table(table_id);
397        let row = table.get_row(key)?;
398        Some(row.vals[schema_math.ret_val_col()])
399    }
400
401    pub fn approx_table_size(&self, table: FunctionId) -> usize {
402        self.db.estimate_size(self.funcs[table].table, None)
403    }
404
405    pub fn table_size(&self, table: FunctionId) -> usize {
406        self.db.get_table(self.funcs[table].table).len()
407    }
408
409    /// Remove every row from the given function's backing table.
410    ///
411    /// This is the bulk counterpart to staging a `remove` for every key in the
412    /// table: the underlying `Database::clear_table` drops the row buffer in
413    /// O(1)-in-row-count time and bumps the table's major generation, which
414    /// lazily invalidates any cached subsets or indexes a later reader might
415    /// consult. Any rows staged for this table by an in-flight
416    /// `MutationBuffer` are dropped along with the table contents.
417    ///
418    /// Callers that have staged inserts/removes for *other* tables that they
419    /// want flushed first should call [`EGraph::flush_updates`] before
420    /// clearing.
421    pub fn clear_table(&mut self, func: FunctionId) {
422        let table_id = self.funcs[func].table;
423        self.db.clear_table(table_id);
424    }
425
426    /// Read the contents of the given function.
427    ///
428    /// The callback `f` is called with each row and its subsumption status.
429    pub fn for_each(&self, table: FunctionId, mut f: impl FnMut(ScanEntry<'_>)) {
430        self.for_each_while(table, |row| {
431            f(row);
432            true
433        });
434    }
435
436    /// Iterate over the rows of a function table, calling `f` on each row. If `f` returns `false`
437    /// the function returns early and stops reading rows from the table.
438    pub fn for_each_while(&self, table: FunctionId, mut f: impl FnMut(ScanEntry<'_>) -> bool) {
439        let info = &self.funcs[table];
440        let table = self.funcs[table].table;
441        let schema_math = SchemaMath {
442            subsume: info.can_subsume,
443            func_cols: info.schema.len(),
444        };
445        let imp = self.db.get_table(table);
446        let all = imp.all();
447        let mut cur = Offset::new(0);
448        let mut buf = TaggedRowBuffer::new(imp.spec().arity());
449        // This somewhat awkward iteration strategy is forced on us by the `scan_bounded` API. We
450        // should look into ways to avoid this cludge where the loop body effectively must be
451        // repeated at the end. The obvious and idiomatic ways to do this all require
452        // `dyn`-compatibility on `Table` or dynamic dispatch per row.
453        macro_rules! drain_buf {
454            ($buf:expr) => {
455                for (_, row) in $buf.non_stale() {
456                    let subsumed =
457                        schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
458                    if !f(ScanEntry {
459                        vals: &row[0..schema_math.func_cols],
460                        subsumed,
461                    }) {
462                        return;
463                    }
464                }
465                $buf.clear();
466            };
467        }
468        while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
469            drain_buf!(buf);
470            cur = next;
471        }
472        drain_buf!(buf);
473    }
474
475    /// A basic method for dumping the state of the database to `log::info!`.
476    ///
477    /// For large tables, this is unlikely to give particularly useful output.
478    pub fn dump_debug_info(&self) {
479        info!("=== View Tables ===");
480        for (id, info) in self.funcs.iter() {
481            let table = self.db.get_table(info.table);
482            self.scan_table(table, |row| {
483                info!(
484                    "View Table {name} / {id:?} / {table:?}: {row:?}",
485                    name = info.name,
486                    table = info.table
487                )
488            });
489        }
490    }
491
492    /// A helper for scanning the entries in a table.
493    fn scan_table(&self, table: &WrappedTable, mut f: impl FnMut(&[Value])) {
494        const BATCH_SIZE: usize = 128;
495        let all = table.all();
496        let mut cur = Offset::new(0);
497        let mut out = TaggedRowBuffer::new(table.spec().arity());
498        while let Some(next) = table.scan_bounded(all.as_ref(), cur, BATCH_SIZE, &mut out) {
499            out.non_stale().for_each(|(_, row)| f(row));
500            out.clear();
501            cur = next;
502        }
503        out.non_stale().for_each(|(_, row)| f(row));
504    }
505
506    /// Register a function in this EGraph.
507    pub fn add_table(&mut self, config: FunctionConfig) -> FunctionId {
508        let FunctionConfig {
509            schema,
510            default,
511            merge,
512            name,
513            can_subsume,
514        } = config;
515        assert!(
516            !schema.is_empty(),
517            "must have at least one column in schema"
518        );
519        let to_rebuild: Vec<ColumnId> = schema
520            .iter()
521            .enumerate()
522            .filter(|(_, ty)| matches!(ty, ColumnTy::Id))
523            .map(|(i, _)| ColumnId::from_usize(i))
524            .collect();
525        let schema_math = SchemaMath {
526            subsume: can_subsume,
527            func_cols: schema.len(),
528        };
529        let n_args = schema_math.num_keys();
530        let n_cols = schema_math.table_columns();
531        let next_func_id = self.funcs.next_id();
532        let mut read_deps = IndexSet::<TableId>::new();
533        let mut write_deps = IndexSet::<TableId>::new();
534        merge.fill_deps(self, &mut read_deps, &mut write_deps);
535        let merge_fn = merge.to_callback(schema_math, &name, self);
536        let table = SortedWritesTable::new(
537            n_args,
538            n_cols,
539            Some(ColumnId::from_usize(schema.len())),
540            to_rebuild,
541            merge_fn,
542        );
543        let name: Arc<str> = name.into();
544        let table_id = self.db.add_table_named(
545            table,
546            name.clone(),
547            read_deps.iter().copied(),
548            write_deps.iter().copied(),
549        );
550
551        let res = self.funcs.push(FunctionInfo {
552            table: table_id,
553            schema: schema.clone(),
554            incremental_rebuild_rules: Default::default(),
555            nonincremental_rebuild_rule: RuleId::new(!0),
556            default_val: default,
557            can_subsume,
558            name,
559        });
560        debug_assert_eq!(res, next_func_id);
561        let incremental_rebuild_rules = self.incremental_rebuild_rules(res, &schema);
562        let nonincremental_rebuild_rule = self.nonincremental_rebuild(res, &schema);
563        let info = &mut self.funcs[res];
564        info.incremental_rebuild_rules = incremental_rebuild_rules;
565        info.nonincremental_rebuild_rule = nonincremental_rebuild_rule;
566        let action = TableAction::new(self, res);
567        let table_name = self.funcs[res].name.to_string();
568        self.action_registry
569            .write()
570            .unwrap()
571            .register_table(table_name, action);
572        res
573    }
574
575    /// A handle to the live [`ActionRegistry`] for this EGraph.
576    /// The handle is shared (`Arc<RwLock<_>>`); cloning the outer
577    /// `Arc` does not duplicate the underlying registry. Used by the
578    /// egglog crate's primitive machinery to thread the registry into
579    /// state wrappers at invoke time.
580    pub fn action_registry(&self) -> &Arc<std::sync::RwLock<ActionRegistry>> {
581        &self.action_registry
582    }
583
584    /// Run the given rules, returning whether the database changed.
585    ///
586    /// If the given rules are malformed, this method can return an error.
587    pub fn run_rules(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
588        self.run_rules_inner(rules)
589    }
590
591    fn run_rules_inner(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
592        let ts = self.next_ts();
593
594        let uf_size_before = self.db.get_table(self.uf_table).len();
595        let rule_set_report =
596            run_rules_impl(&mut self.db, &mut self.rules, rules, ts, self.report_level)?;
597        if let Some(message) = self.panic_message.lock().unwrap().take() {
598            return Err(PanicError(message).into());
599        }
600
601        let mut iteration_report = IterationReport {
602            rule_set_report,
603            rebuild_time: Duration::ZERO,
604        };
605        let uf_size_after = self.db.get_table(self.uf_table).len();
606        if uf_size_before == uf_size_after {
607            // No new unions: skip the full rebuild but still advance the
608            // timestamp so that seminaive evaluation sees a fresh epoch.
609            // Rebuilding is only necessary when new unions have been made because ids may need to be updated.
610            // Adding terms doesn't necessarily touch the union-find, only doing a union between existing ids does.
611            self.inc_ts();
612            return Ok(iteration_report);
613        }
614
615        let rebuild_timer = Instant::now();
616        self.rebuild()?;
617        iteration_report.rebuild_time = rebuild_timer.elapsed();
618
619        if let Some(message) = self.panic_message.lock().unwrap().take() {
620            return Err(PanicError(message).into());
621        }
622
623        Ok(iteration_report)
624    }
625
626    fn rebuild(&mut self) -> Result<()> {
627        let do_parallel = rayon::current_num_threads() > 1;
628        if self.db.get_table(self.uf_table).rebuilder(&[]).is_some() {
629            // The UF implementation supports "native"  rebuilding.
630            let mut tables = Vec::with_capacity(self.funcs.next_id().index());
631            for (_, func) in self.funcs.iter() {
632                tables.push(func.table);
633            }
634            loop {
635                // Order matters here: we need to rebuild containers first and then rebuild the
636                // tables. Why?
637                //
638                // Say we have a sort that can map to and from a vector containing only itself:
639                // (sort X)
640                // (function to-vec (X) (Vec X) :no-merge)
641                // (constructor from-vec (Vec X) X)
642                // (constructor Num (i64) X)
643                // (constructor Add (X X) X)
644                //
645                // Along with rules:
646                // (rule ((= x (Num i))) ((set (to-vec x) (vec-of x))))
647                // (rule ((= x (Add i j))) ((set (to-vec x) (vec-of x))))
648                // (rule ((= x (from-vec v))) ((set (to-vec x) v))
649                // (rewrite (Add (Num i) (Num j)) (Num (+ i j)))
650                //
651                // These rules, while redundant, should be safe. However, if we rebuild tables
652                // before containers some schedules can cause us to violate the `:no-merge`
653                // directive, which asserts that all values written for a key are equal.
654                //
655                // Suppose we start off with x1=(Num 1), x2=(Num 3), and x3=(Add (Num 1) (Num 2)) as
656                // expressions, with `to-vec` and `from-vec` entries for all three expressions.
657                // We'll call (to-vec xi) vi for all i.
658                //
659                // Now suppose we run the `rewrite` above: now, x3 = x2. But v3 will only equal v2
660                // _after_ we rebuild the `Vec` container. That means that if we rebuild `to-vec`
661                // we will collapse the the rows for x3 and x2, but then fail to merge v3 and v2
662                // because they are not (yet) equal.
663                //
664                // Rebuilding containers first will find that v3 and v2 are equal, and the rest of
665                // the rules can proceed.
666                let container_rebuild = self.db.rebuild_containers(self.uf_table);
667                let next_ts = self.next_ts().to_value();
668                let table_rebuild = self.db.apply_rebuild(self.uf_table, &tables, next_ts);
669                // Container rebuild can make a parent row newly matchable without
670                // changing the row's stored id. Re-timestamp those parents so
671                // seminaive sees the newly enabled match on the next pass.
672                let dirty_ids: Vec<Value> = container_rebuild.dirty_ids().iter().copied().collect();
673                let refreshed_rows = self
674                    .db
675                    .refresh_rows_for_values(&tables, &dirty_ids, next_ts);
676                self.inc_ts();
677                if !table_rebuild && !refreshed_rows && !container_rebuild.changed() {
678                    break;
679                }
680            }
681            return Ok(());
682        }
683        if do_parallel {
684            return self.rebuild_parallel();
685        }
686        let start = Instant::now();
687
688        // The database changed. Rebuild. New entries should land after the given rules.
689        let mut changed = true;
690        while changed {
691            changed = false;
692            // We need to iterate rebuilding to a fixed point. Future scans
693            // should look only at the latest updates.
694            self.inc_ts();
695            let ts = self.next_ts();
696            for (_, info) in self.funcs.iter_mut() {
697                let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
698                let table_size = self.db.estimate_size(info.table, None);
699                let uf_size = self.db.estimate_size(
700                    self.uf_table,
701                    Some(Constraint::GeConst {
702                        col: ColumnId::new(2),
703                        val: last_rebuilt_at.to_value(),
704                    }),
705                );
706                if incremental_rebuild(uf_size, table_size, false) {
707                    marker_incremental_rebuild(|| -> Result<()> {
708                        // Run each of the incremental rules serially.
709                        //
710                        // This is to avoid recanonicalizing the same row multiple
711                        // times.
712                        for rule in &info.incremental_rebuild_rules {
713                            changed |= run_rules_impl(
714                                &mut self.db,
715                                &mut self.rules,
716                                &[*rule],
717                                ts,
718                                ReportLevel::TimeOnly,
719                            )?
720                            .changed;
721                        }
722                        // Reset the rule we did not run. These two should be equivalent.
723                        self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
724                        Ok(())
725                    })?;
726                } else {
727                    marker_nonincremental_rebuild(|| -> Result<()> {
728                        changed |= run_rules_impl(
729                            &mut self.db,
730                            &mut self.rules,
731                            &[info.nonincremental_rebuild_rule],
732                            ts,
733                            ReportLevel::TimeOnly,
734                        )?
735                        .changed;
736                        for rule in &info.incremental_rebuild_rules {
737                            self.rules[*rule].last_run_at = ts;
738                        }
739                        Ok(())
740                    })?;
741                }
742            }
743        }
744        log::info!("rebuild took {:?}", start.elapsed());
745        Ok(())
746    }
747
748    /// A variant of `rebuild` that attempts to combine rebuild rules into
749    /// larger rulesets to increase parallelism. This kind of preprocessing can
750    /// slow processing down in a single-threaded setting, so it is only used
751    /// when the number of active threads is greater than 1.
752    fn rebuild_parallel(&mut self) -> Result<()> {
753        let start = Instant::now();
754        #[derive(Default)]
755        struct RebuildState {
756            nonincremental: Vec<FunctionId>,
757            incremental: DenseIdMap<usize, SmallVec<[FunctionId; 2]>>,
758        }
759
760        impl RebuildState {
761            fn clear(&mut self) {
762                self.nonincremental.clear();
763                self.incremental.iter_mut().for_each(|(_, v)| v.clear());
764            }
765        }
766
767        let mut changed = true;
768        let mut state = RebuildState::default();
769        let mut scratch = Vec::new();
770        while changed {
771            changed = false;
772            state.clear();
773            self.inc_ts();
774            // First, figure out which functions will be rebuilt nonincrementally,
775            // vs. incrementally. Group them together.
776            for (func, info) in self.funcs.iter_mut() {
777                let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
778                let table_size = self.db.estimate_size(info.table, None);
779                let uf_size = self.db.estimate_size(
780                    self.uf_table,
781                    Some(Constraint::GeConst {
782                        col: ColumnId::new(2),
783                        val: last_rebuilt_at.to_value(),
784                    }),
785                );
786                if incremental_rebuild(uf_size, table_size, true) {
787                    for (i, _) in info.incremental_rebuild_rules.iter().enumerate() {
788                        state.incremental.get_or_default(i).push(func);
789                    }
790                } else {
791                    state.nonincremental.push(func);
792                }
793            }
794            let ts = self.next_ts();
795            for func in state.nonincremental.iter().copied() {
796                scratch.push(self.funcs[func].nonincremental_rebuild_rule);
797                for rule in &self.funcs[func].incremental_rebuild_rules {
798                    self.rules[*rule].last_run_at = ts;
799                }
800            }
801            changed |= run_rules_impl(
802                &mut self.db,
803                &mut self.rules,
804                &scratch,
805                ts,
806                ReportLevel::TimeOnly,
807            )?
808            .changed;
809            scratch.clear();
810            let ts = self.next_ts();
811            for (i, funcs) in state.incremental.iter() {
812                for func in funcs.iter().copied() {
813                    let info = &mut self.funcs[func];
814                    scratch.push(info.incremental_rebuild_rules[i]);
815                    self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
816                }
817                changed |= run_rules_impl(
818                    &mut self.db,
819                    &mut self.rules,
820                    &scratch,
821                    ts,
822                    ReportLevel::TimeOnly,
823                )?
824                .changed;
825                scratch.clear();
826            }
827        }
828        log::info!("rebuild took {:?}", start.elapsed());
829        Ok(())
830    }
831
832    fn incremental_rebuild_rules(&mut self, table: FunctionId, schema: &[ColumnTy]) -> Vec<RuleId> {
833        schema
834            .iter()
835            .enumerate()
836            .filter_map(|(i, ty)| match ty {
837                ColumnTy::Id => {
838                    Some(self.incremental_rebuild_rule(table, schema, ColumnId::from_usize(i)))
839                }
840                ColumnTy::Base(_) => None,
841            })
842            .collect()
843    }
844
845    fn incremental_rebuild_rule(
846        &mut self,
847        table: FunctionId,
848        schema: &[ColumnTy],
849        col: ColumnId,
850    ) -> RuleId {
851        let subsume = self.funcs[table].can_subsume;
852        let table_id = self.funcs[table].table;
853        let uf_table = self.uf_table;
854        // Two atoms, one binding a whole tuple, one binding a displaced column
855        let mut rb = self.new_rule(&format!("incremental rebuild {table:?}, {col:?}"), true);
856        rb.set_plan_strategy(PlanStrategy::MinCover);
857        let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
858        for ty in schema {
859            vars.push(rb.new_var(*ty).into());
860        }
861        let canon_val: QueryEntry = rb.new_var(ColumnTy::Id).into();
862        let subsume_var = subsume.then(|| rb.new_var(ColumnTy::Id));
863        rb.add_atom_with_timestamp_and_func(
864            table_id,
865            Some(table),
866            subsume_var.clone().map(QueryEntry::from),
867            &vars,
868        );
869        rb.add_atom_with_timestamp_and_func(
870            uf_table,
871            None,
872            None,
873            &[vars[col.index()].clone(), canon_val.clone()],
874        );
875        rb.set_focus(1); // Set the uf atom as the sole focus.
876
877        // Now canonicalize the entire row.
878        let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
879        for (i, (var, ty)) in vars.iter().zip(schema.iter()).enumerate() {
880            canon.push(if i == col.index() {
881                canon_val.clone()
882            } else if let ColumnTy::Id = ty {
883                rb.lookup_uf(var.clone()).unwrap().into()
884            } else {
885                var.clone()
886            })
887        }
888
889        // Remove the old row and insert the new one.
890        rb.rebuild_row(table, &vars, &canon, subsume_var);
891        rb.build()
892    }
893
894    fn nonincremental_rebuild(&mut self, table: FunctionId, schema: &[ColumnTy]) -> RuleId {
895        let can_subsume = self.funcs[table].can_subsume;
896        let table_id = self.funcs[table].table;
897        let mut rb = self.new_rule(&format!("nonincremental rebuild {table:?}"), false);
898        rb.set_plan_strategy(PlanStrategy::MinCover);
899        let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
900        for ty in schema {
901            vars.push(rb.new_var(*ty).into());
902        }
903        let subsume_var = can_subsume.then(|| rb.new_var(ColumnTy::Id));
904        rb.add_atom_with_timestamp_and_func(
905            table_id,
906            Some(table),
907            subsume_var.clone().map(QueryEntry::from),
908            &vars,
909        );
910        let mut lhs = SmallVec::<[QueryEntry; 4]>::new();
911        let mut rhs = SmallVec::<[QueryEntry; 4]>::new();
912        let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
913        for (var, ty) in vars.iter().zip(schema.iter()) {
914            canon.push(if let ColumnTy::Id = ty {
915                lhs.push(var.clone());
916                let canon_var = QueryEntry::from(rb.lookup_uf(var.clone()).unwrap());
917                rhs.push(canon_var.clone());
918                canon_var
919            } else {
920                var.clone()
921            })
922        }
923        rb.check_for_update(&lhs, &rhs).unwrap();
924        rb.rebuild_row(table, &vars, &canon, subsume_var);
925        rb.build()
926    }
927
    /// Gives the user a handle to the underlying [`ExecutionState`]. Useful for staging updates
    /// to the database.
    ///
    /// The closure `f` is forwarded to the underlying database, and whatever `f` returns is
    /// returned from this method.
    ///
    /// The staged updates are not immediately reflected in the EGraph, so you may want to
    /// manually flush the updates using [`EGraph::flush_updates`].
    ///
    /// # Seminaive-safety trust boundary
    ///
    /// This method hands out a raw `&mut ExecutionState`, which bypasses
    /// the egglog crate's `Read` / `Write` capability wrappers
    /// (`PureState`, `WriteState`, `ReadState`, `FullState`) that
    /// enforce #772's seminaive-safety model. Treat it as top-level
    /// / global-action context: appropriate for one-shot database
    /// manipulation from outside any rule, not for use inside
    /// primitive implementations.
    pub fn with_execution_state<R>(&self, f: impl FnOnce(&mut ExecutionState<'_>) -> R) -> R {
        self.db.with_execution_state(f)
    }
946
    /// Like [`EGraph::with_execution_state`], but also reports whether `f`
    /// staged any mutation. A read-only closure leaves the flag `false`, so
    /// callers can skip a [`EGraph::flush_updates`] that would otherwise be a
    /// no-op merge plus a spurious timestamp bump.
    ///
    /// Returns a pair of the closure's result and the "staged a mutation"
    /// flag, exactly as produced by the underlying database call.
    pub fn with_execution_state_tracked<R>(
        &self,
        f: impl FnOnce(&mut ExecutionState<'_>) -> R,
    ) -> (R, bool) {
        self.db.with_execution_state_tracked(f)
    }
957
958    /// Flush the pending update buffers to the EGraph.
959    /// Returns `true` if the database is updated.
960    pub fn flush_updates(&mut self) -> bool {
961        let uf_size_before = self.db.get_table(self.uf_table).len();
962        let updated = self.db.merge_all();
963        self.inc_ts();
964        let uf_size_after = self.db.get_table(self.uf_table).len();
965        if uf_size_before != uf_size_after {
966            // Rebuilding is only necessary when new unions have been made because ids may need to be updated.
967            // Adding terms doesn't necessarily touch the union-find, only doing a union between existing ids does.
968            self.rebuild().unwrap();
969        }
970        updated
971    }
972
    /// Sets the [`ReportLevel`] stored on this `EGraph`.
    //
    // NOTE(review): the rebuild rules visible in this file always pass
    // `ReportLevel::TimeOnly`, so this setting presumably only affects
    // user-facing rule runs — confirm against the callers of `report_level`.
    pub fn set_report_level(&mut self, level: ReportLevel) {
        self.report_level = level;
    }
976}
977
/// Bookkeeping kept for every rule registered with the `EGraph`.
#[derive(Clone)]
struct RuleInfo {
    /// Timestamp recorded the last time this rule was run (or was treated as
    /// run). `run_rules_impl` passes it to the query when adding the rule to a
    /// rule set and then overwrites it with the new timestamp.
    last_run_at: Timestamp,
    /// The query from which the rule's plan and rule-set entries are built.
    query: rule::Query,
    /// Plan built on the first run of the rule and reused afterwards;
    /// `None` until then.
    cached_plan: Option<CachedPlanInfo>,
    /// Human-readable description, passed along when the cached plan is built.
    desc: Arc<str>,
}
985
/// A cached query plan together with the data needed to relate it back to the
/// [`rule::Query`] it was built from.
#[derive(Clone)]
struct CachedPlanInfo {
    /// The shared, reusable plan from `core-relations`.
    plan: Arc<core_relations::CachedPlan>,
    /// A mapping from index into a [`rule::Query`]'s atoms to the atoms in the underlying cached
    /// plan.
    atom_mapping: Vec<core_relations::AtomId>,
}
993
/// Metadata tracked for each function registered with the `EGraph`.
#[derive(Clone)]
struct FunctionInfo {
    /// The database table holding this function's rows.
    table: TableId,
    /// Column types of the function; the last entry is the return type
    /// (see [`FunctionInfo::ret_ty`]).
    schema: Vec<ColumnTy>,
    /// Rules used to rebuild this table incrementally.
    // NOTE(review): presumably one per `ColumnTy::Id` column, as produced by
    // `EGraph::incremental_rebuild_rules` — confirm at the construction site.
    incremental_rebuild_rules: Vec<RuleId>,
    /// The single rule used to rebuild this table nonincrementally.
    nonincremental_rebuild_rule: RuleId,
    /// How a missing row's value is produced on lookup.
    default_val: DefaultVal,
    /// Whether rows of this table carry a subsumption flag column.
    can_subsume: bool,
    /// The function's name, used in diagnostics such as merge-function panics.
    name: Arc<str>,
}
1004
1005impl FunctionInfo {
1006    fn ret_ty(&self) -> ColumnTy {
1007        self.schema.last().copied().unwrap()
1008    }
1009}
1010
/// How defaults are computed for the given function.
///
/// [`TableAction::new`] uses this to decide both the table's [`TableKind`] and
/// what (if anything) is inserted when a lookup misses.
#[derive(Copy, Clone)]
pub enum DefaultVal {
    /// Generate a fresh UF id. Tables with this default are classified as
    /// [`TableKind::Constructor`].
    FreshId,
    /// Cause an egglog-level panic if a lookup fails. No value is inserted on
    /// a miss.
    Fail,
    /// Insert a constant of some kind.
    Const(Value),
}
1021
/// How to resolve FD conflicts for a table.
///
/// A merge function receives the value currently stored for a key (the "old"
/// value) and the incoming ("new") value, and produces the value to keep.
pub enum MergeFn {
    /// Panic if the old and new values don't match.
    AssertEq,
    /// Use congruence to resolve FD conflicts: when the values differ, a union
    /// of the two ids is staged in the union-find and the smaller id is kept.
    UnionId,
    /// The output of a merge is determined by applying the given ExternalFunction to the result
    /// of the argument merge functions.
    Primitive(ExternalFunctionId, Vec<MergeFn>),
    /// The output of a merge is determined by looking up the value for the given function and the
    /// given arguments in the egraph. The number of arguments must be one less than the length
    /// of the target function's schema.
    Function(FunctionId, Vec<MergeFn>),
    /// Always return the old value for the given function.
    Old,
    /// Always return the new value for the given function.
    New,
    /// Always overwrite the new value for the given function with a constant. This is more useful
    /// as a "base case" in a more complicated merge function (e.g. one that clamps a value between
    /// 1 and 100) than it is as a standalone merge function.
    Const(Value),
}
1043
1044impl MergeFn {
1045    fn fill_deps(
1046        &self,
1047        egraph: &EGraph,
1048        read_deps: &mut IndexSet<TableId>,
1049        write_deps: &mut IndexSet<TableId>,
1050    ) {
1051        use MergeFn::*;
1052        match self {
1053            Primitive(_, args) => {
1054                args.iter()
1055                    .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
1056                write_deps.insert(egraph.uf_table);
1057            }
1058            Function(func, args) => {
1059                read_deps.insert(egraph.funcs[*func].table);
1060                write_deps.insert(egraph.funcs[*func].table);
1061                args.iter()
1062                    .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
1063            }
1064            UnionId => {
1065                write_deps.insert(egraph.uf_table);
1066            }
1067            AssertEq | Old | New | Const(..) => {}
1068        }
1069    }
1070
1071    fn to_callback(
1072        &self,
1073        schema_math: SchemaMath,
1074        function_name: &str,
1075        egraph: &mut EGraph,
1076    ) -> Box<core_relations::MergeFn> {
1077        let resolved = self.resolve(function_name, egraph);
1078
1079        Box::new(move |state, cur, new, out| {
1080            let timestamp = new[schema_math.ts_col()];
1081
1082            let mut changed = false;
1083
1084            let ret_val = {
1085                let cur = cur[schema_math.ret_val_col()];
1086                let new = new[schema_math.ret_val_col()];
1087                let out = resolved.run(state, cur, new, timestamp);
1088                changed |= cur != out;
1089                out
1090            };
1091
1092            let subsume = schema_math.subsume.then(|| {
1093                let cur = cur[schema_math.subsume_col()];
1094                let new = new[schema_math.subsume_col()];
1095                let out = combine_subsumed(cur, new);
1096                changed |= cur != out;
1097                out
1098            });
1099            if changed {
1100                out.extend_from_slice(new);
1101                schema_math.write_table_row(
1102                    out,
1103                    RowVals {
1104                        timestamp,
1105                        subsume,
1106                        ret_val: Some(ret_val),
1107                    },
1108                );
1109            }
1110
1111            changed
1112        })
1113    }
1114
1115    fn resolve(&self, function_name: &str, egraph: &mut EGraph) -> ResolvedMergeFn {
1116        match self {
1117            MergeFn::Const(v) => ResolvedMergeFn::Const(*v),
1118            MergeFn::Old => ResolvedMergeFn::Old,
1119            MergeFn::New => ResolvedMergeFn::New,
1120            MergeFn::AssertEq => ResolvedMergeFn::AssertEq {
1121                panic: egraph.new_panic(format!(
1122                    "Illegal merge attempted for function {function_name}"
1123                )),
1124            },
1125            MergeFn::UnionId => ResolvedMergeFn::UnionId {
1126                uf_table: egraph.uf_table,
1127            },
1128            // NB: The primitive and function-based merge functions heap allocate a single callback
1129            // for each layer of nesting. This introduces a bit of overhead, particularly for cases
1130            // that look like `(f old new)` or `(f new old)`. We could special-case common cases in
1131            // this function if that overhead shows up.
1132            MergeFn::Primitive(prim, args) => ResolvedMergeFn::Primitive {
1133                prim: *prim,
1134                args: args
1135                    .iter()
1136                    .map(|arg| arg.resolve(function_name, egraph))
1137                    .collect::<Vec<_>>(),
1138                panic: egraph.new_panic(format!(
1139                    "Merge function for {function_name} primitive call failed"
1140                )),
1141            },
1142            MergeFn::Function(func, args) => {
1143                let func_info = &egraph.funcs[*func];
1144                assert_eq!(
1145                    func_info.schema.len(),
1146                    args.len() + 1,
1147                    "Merge function for {function_name} must match function arity for {}",
1148                    func_info.name
1149                );
1150                ResolvedMergeFn::Function {
1151                    func: TableAction::new(egraph, *func),
1152                    panic: egraph.new_panic(format!(
1153                        "Lookup on {} failed in the merge function for {function_name}",
1154                        func_info.name
1155                    )),
1156                    args: args
1157                        .iter()
1158                        .map(|arg| arg.resolve(function_name, egraph))
1159                        .collect::<Vec<_>>(),
1160                }
1161            }
1162        }
1163    }
1164}
1165
/// This enum is taking the place of a
/// `Box<dyn Fn(&mut ExecutionState, Value, Value, Value) -> Value + Send + Sync>`
/// to avoid extra boxes. It stores the data needed to run a `MergeFn` without
/// holding onto any references, so it can be `move`d inside the `core_relations::MergeFn`.
///
/// Each variant mirrors the [`MergeFn`] variant of the same name.
enum ResolvedMergeFn {
    /// Always produce the stored constant.
    Const(Value),
    /// Always produce the current value.
    Old,
    /// Always produce the incoming value.
    New,
    /// Keep the current value, invoking `panic` when the two values differ.
    AssertEq {
        /// External function called to signal the illegal merge.
        panic: ExternalFunctionId,
    },
    /// Stage a union of the two values when they differ and keep the smaller one.
    UnionId {
        /// The union-find table that unions are staged into.
        uf_table: TableId,
    },
    /// Apply an external function to the results of the argument merge functions.
    Primitive {
        /// The external function to call.
        prim: ExternalFunctionId,
        /// Merge functions producing the call's arguments.
        args: Vec<ResolvedMergeFn>,
        /// External function called when `prim` returns `None`.
        panic: ExternalFunctionId,
    },
    /// Look up (or insert a default into) another function's table.
    Function {
        /// Handle used to perform the lookup.
        func: TableAction,
        /// Merge functions producing the lookup key.
        args: Vec<ResolvedMergeFn>,
        /// External function called when the lookup yields no value.
        panic: ExternalFunctionId,
    },
}
1191
1192impl ResolvedMergeFn {
1193    fn run(&self, state: &mut ExecutionState, cur: Value, new: Value, ts: Value) -> Value {
1194        match self {
1195            ResolvedMergeFn::Const(v) => *v,
1196            ResolvedMergeFn::Old => cur,
1197            ResolvedMergeFn::New => new,
1198            ResolvedMergeFn::AssertEq { panic } => {
1199                if cur != new {
1200                    let res = state.call_external_func(*panic, &[]);
1201                    assert_eq!(res, None);
1202                }
1203                cur
1204            }
1205            ResolvedMergeFn::UnionId { uf_table } => {
1206                if cur != new {
1207                    state.stage_insert(*uf_table, &[cur, new, ts]);
1208                    // We pick the minimum when unioning. This matches the original egglog
1209                    // behavior. THIS MUST MATCH THE UNION-FIND IMPLEMENTATION!
1210                    std::cmp::min(cur, new)
1211                } else {
1212                    cur
1213                }
1214            }
1215            // NB: The primitive and function-based merge functions heap allocate a single callback
1216            // for each layer of nesting. This introduces a bit of overhead, particularly for cases
1217            // that look like `(f old new)` or `(f new old)`. We could special-case common cases in
1218            // this function if that overhead shows up.
1219            ResolvedMergeFn::Primitive { prim, args, panic } => {
1220                let args = args
1221                    .iter()
1222                    .map(|arg| arg.run(state, cur, new, ts))
1223                    .collect::<Vec<_>>();
1224
1225                match state.call_external_func(*prim, &args) {
1226                    Some(result) => result,
1227                    None => {
1228                        let res = state.call_external_func(*panic, &[]);
1229                        assert_eq!(res, None);
1230                        cur
1231                    }
1232                }
1233            }
1234            ResolvedMergeFn::Function { func, args, panic } => {
1235                // see github.com/egraphs-good/egglog/pull/287
1236                if cur == new {
1237                    return cur;
1238                }
1239
1240                let args = args
1241                    .iter()
1242                    .map(|arg| arg.run(state, cur, new, ts))
1243                    .collect::<Vec<_>>();
1244
1245                // Merge functions dispatch to another function that may be
1246                // a constructor (mint fresh id on miss) or a custom function
1247                // (return `None` → panic). `lookup_or_insert` preserves
1248                // both behaviors; the pure-read `lookup` would skip
1249                // constructor minting.
1250                func.lookup_or_insert(state, &args).unwrap_or_else(|| {
1251                    let res = state.call_external_func(*panic, &[]);
1252                    assert_eq!(res, None);
1253                    cur
1254                })
1255            }
1256        }
1257    }
1258}
1259
/// Coarse classification of a table — `Constructor` mints a fresh
/// eclass id when a row is missed; `Function` does not. Mirrors the
/// `FunctionSubtype` split on the egglog side without dragging that
/// type into the bridge crate.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum TableKind {
    /// A table whose default is [`DefaultVal::Fail`] or [`DefaultVal::Const`];
    /// a miss never mints a fresh id.
    Function,
    /// A table whose default is [`DefaultVal::FreshId`]; a miss mints a fresh
    /// eclass id.
    Constructor,
}
1269
/// This is an intern-able struct that holds all the data needed
/// to do table operations with an [`ExecutionState`], assuming
/// that the [`FunctionId`] for the table is known ahead of time.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct TableAction {
    /// The database table backing the function.
    table: TableId,
    /// Column layout of the table: number of function columns and whether a
    /// subsume column is present.
    table_math: SchemaMath,
    /// Value inserted by [`TableAction::lookup_or_insert`] on a miss, or
    /// `None` if a miss inserts nothing.
    default: Option<MergeVal>,
    /// Counter read to obtain the timestamp written into staged rows.
    timestamp: CounterId,
    /// Whether the table is a plain function or a constructor.
    kind: TableKind,
}
1281
1282impl TableAction {
1283    /// Create a new `TableAction` to be used later.
1284    /// This requires access to the `egglog_bridge::EGraph`.
1285    pub fn new(egraph: &EGraph, func: FunctionId) -> TableAction {
1286        let func_info = &egraph.funcs[func];
1287        let kind = match &func_info.default_val {
1288            DefaultVal::FreshId => TableKind::Constructor,
1289            DefaultVal::Fail | DefaultVal::Const(_) => TableKind::Function,
1290        };
1291        TableAction {
1292            table: func_info.table,
1293            table_math: SchemaMath {
1294                func_cols: func_info.schema.len(),
1295                subsume: func_info.can_subsume,
1296            },
1297            default: match &func_info.default_val {
1298                DefaultVal::FreshId => Some(MergeVal::Counter(egraph.id_counter)),
1299                DefaultVal::Fail => None,
1300                DefaultVal::Const(val) => Some(MergeVal::Constant(*val)),
1301            },
1302            timestamp: egraph.timestamp_counter,
1303            kind,
1304        }
1305    }
1306
1307    /// Whether this table is a `Function` (no auto-insert) or a
1308    /// `Constructor` (mints a fresh eclass id on miss).
1309    pub fn kind(&self) -> TableKind {
1310        self.kind
1311    }
1312
1313    /// Number of input columns (schema minus the trailing output
1314    /// column).
1315    pub fn input_arity(&self) -> usize {
1316        self.table_math.func_cols - 1
1317    }
1318
1319    /// Look up a row and return its return-value column, or `None` if the
1320    /// key is not present. **This is a pure read**: it never inserts a row,
1321    /// regardless of the table's configured [`DefaultVal`].
1322    ///
1323    /// For the lookup-or-insert behavior that mints fresh eclass IDs for
1324    /// constructors, use [`TableAction::lookup_or_insert`].
1325    pub fn lookup(&self, state: &ExecutionState, key: &[Value]) -> Option<Value> {
1326        state
1327            .get_table(self.table)
1328            .get_row(key)
1329            .map(|row| row.vals[self.table_math.ret_val_col()])
1330    }
1331
1332    /// Return the current number of rows in this table.
1333    pub fn row_count(&self, state: &ExecutionState) -> usize {
1334        state.get_table(self.table).len()
1335    }
1336
1337    /// Iterate this table's rows, calling `f` on each function row.
1338    /// Mirrors [`EGraph::for_each`] but reaches the table through an
1339    /// [`ExecutionState`] — so it's callable from primitive bodies via
1340    /// the typed `Read`-style API.
1341    pub fn for_each(&self, state: &ExecutionState, mut f: impl FnMut(ScanEntry<'_>)) {
1342        self.for_each_while(state, |row| {
1343            f(row);
1344            true
1345        });
1346    }
1347
1348    /// Like [`TableAction::for_each`], but stops as soon as `f`
1349    /// returns `false`.
1350    pub fn for_each_while(&self, state: &ExecutionState, mut f: impl FnMut(ScanEntry<'_>) -> bool) {
1351        let schema_math = self.table_math;
1352        let imp = state.get_table(self.table);
1353        let all = imp.all();
1354        let mut cur = Offset::new(0);
1355        let mut buf = TaggedRowBuffer::new(imp.spec().arity());
1356        macro_rules! drain_buf {
1357            ($buf:expr) => {
1358                for (_, row) in $buf.non_stale() {
1359                    let subsumed =
1360                        schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
1361                    if !f(ScanEntry {
1362                        vals: &row[0..schema_math.func_cols],
1363                        subsumed,
1364                    }) {
1365                        return;
1366                    }
1367                }
1368                $buf.clear();
1369            };
1370        }
1371        while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
1372            drain_buf!(buf);
1373            cur = next;
1374        }
1375        drain_buf!(buf);
1376    }
1377
1378    /// Look up a row, inserting the configured default value if absent.
1379    /// For constructor tables this mints a fresh eclass ID; for custom
1380    /// functions (no default) this behaves identically to
1381    /// [`TableAction::lookup`].
1382    ///
1383    /// This is a write operation — only safe in action contexts. See
1384    /// issue #772.
1385    pub fn lookup_or_insert(&self, state: &mut ExecutionState, key: &[Value]) -> Option<Value> {
1386        match self.default {
1387            Some(default) => {
1388                let timestamp =
1389                    MergeVal::Constant(Value::from_usize(state.read_counter(self.timestamp)));
1390                let mut merge_vals = SmallVec::<[MergeVal; 3]>::new();
1391                SchemaMath {
1392                    func_cols: 1,
1393                    ..self.table_math
1394                }
1395                .write_table_row(
1396                    &mut merge_vals,
1397                    RowVals {
1398                        timestamp,
1399                        subsume: self
1400                            .table_math
1401                            .subsume
1402                            .then_some(MergeVal::Constant(NOT_SUBSUMED)),
1403                        ret_val: Some(default),
1404                    },
1405                );
1406                Some(
1407                    state.predict_val(self.table, key, merge_vals.iter().copied())
1408                        [self.table_math.ret_val_col()],
1409                )
1410            }
1411            None => self.lookup(state, key),
1412        }
1413    }
1414
1415    /// Insert a row into this table.
1416    pub fn insert(&self, state: &mut ExecutionState, row: impl Iterator<Item = Value>) {
1417        let ts = Value::from_usize(state.read_counter(self.timestamp));
1418        let mut scratch = row.collect::<SmallVec<[_; 8]>>();
1419        self.table_math.write_table_row(
1420            &mut scratch,
1421            RowVals {
1422                timestamp: ts,
1423                subsume: self.table_math.subsume.then_some(NOT_SUBSUMED),
1424                ret_val: None,
1425            },
1426        );
1427        state.stage_insert(self.table, &scratch);
1428    }
1429
1430    /// Delete a row from this table.
1431    pub fn remove(&self, state: &mut ExecutionState, key: &[Value]) {
1432        state.stage_remove(self.table, key);
1433    }
1434
1435    /// Subsume a row in this table.
1436    pub fn subsume(&self, state: &mut ExecutionState, key: impl Iterator<Item = Value>) {
1437        let ts = Value::from_usize(state.read_counter(self.timestamp));
1438        let mut scratch = key.collect::<SmallVec<[_; 8]>>();
1439
1440        let ret_val = self.lookup(state, &scratch).expect("subsume lookup failed");
1441
1442        self.table_math.write_table_row(
1443            &mut scratch,
1444            RowVals {
1445                timestamp: ts,
1446                subsume: Some(SUBSUMED),
1447                ret_val: Some(ret_val),
1448            },
1449        );
1450        state.stage_insert(self.table, &scratch);
1451    }
1452}
1453
/// A variant of `TableAction` for the union-find.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnionAction {
    /// The union-find table that unions are staged into.
    table: TableId,
    /// Counter read to obtain the timestamp written into staged unions.
    timestamp: CounterId,
}
1460
1461impl UnionAction {
1462    /// Create a new `UnionAction` to be used later.
1463    /// This requires access to the `egglog_bridge::EGraph`.
1464    pub fn new(egraph: &EGraph) -> UnionAction {
1465        UnionAction {
1466            table: egraph.uf_table,
1467            timestamp: egraph.timestamp_counter,
1468        }
1469    }
1470
1471    /// Union two values.
1472    pub fn union(&self, state: &mut ExecutionState, x: Value, y: Value) {
1473        let ts = Value::from_usize(state.read_counter(self.timestamp));
1474        state.stage_insert(self.table, &[x, y, ts]);
1475    }
1476}
1477
1478fn run_rules_impl(
1479    db: &mut Database,
1480    rule_info: &mut DenseIdMapWithReuse<RuleId, RuleInfo>,
1481    rules: &[RuleId],
1482    next_ts: Timestamp,
1483    report_level: ReportLevel,
1484) -> Result<RuleSetReport> {
1485    for rule in rules {
1486        let info = &mut rule_info[*rule];
1487        if info.cached_plan.is_none() {
1488            info.cached_plan = Some(info.query.build_cached_plan(db, &info.desc)?);
1489        }
1490    }
1491    let mut rsb = db.new_rule_set();
1492    for rule in rules {
1493        let info = &mut rule_info[*rule];
1494        let cached_plan = info.cached_plan.as_ref().unwrap();
1495        info.query
1496            .add_rules_from_cached(&mut rsb, info.last_run_at, cached_plan);
1497        info.last_run_at = next_ts;
1498    }
1499    let ruleset = rsb.build();
1500    Ok(db.run_rule_set(&ruleset, report_level))
1501}
1502
1503// These markers are just used to make it easy to distinguish time spent in
1504// incremental vs. nonincremental rebuilds in time-based profiles.
1505
/// Calls `f` and returns its result. Exists only so that incremental rebuilds
/// show up as a distinct frame in time-based profiles; `#[inline(never)]`
/// keeps the frame from being optimized away.
#[inline(never)]
fn marker_incremental_rebuild<R>(f: impl FnOnce() -> R) -> R {
    f()
}
1510
/// Calls `f` and returns its result. Exists only so that nonincremental
/// rebuilds show up as a distinct frame in time-based profiles;
/// `#[inline(never)]` keeps the frame from being optimized away.
#[inline(never)]
fn marker_nonincremental_rebuild<R>(f: impl FnOnce() -> R) -> R {
    f()
}
1515
/// A useful type definition for external functions that need to pass data
/// to outside code, such as `Panic`.
///
/// It is a shared, mutex-guarded slot that is `None` until a value is stored.
pub type SideChannel<T> = Arc<Mutex<Option<T>>>;
1519
/// An external function used to store a lazily-constructed message when a panic occurs.
///
/// This is a variant on [`Panic`] that avoids eager construction of the panic message.
///
/// The main thing this is used for is to avoid constructing the panic message ahead of time during
/// a call to [`RuleBuilder::call_external_func`]; these panic messages are often quite rare and
/// may never need to be constructed at all. Furthermore, a closure to produce the panic message in
/// most cases need only close over a few cheap-to-clone values.
///
/// The downside of this, and why we do not use it everywhere, is that there's no natural "key"
/// that we can use to cache duplicate panic messages. We would need a more complex API to support
/// both and fully replace our use of `Panic`.
///
/// Field `0` is the shared, lazily-evaluated message; field `1` is the side channel that
/// receives it.
//
// TODO: once we have parallelism wired in, we'll want to replace this with a
// more efficient solution (e.g. one based on crossbeam or arcswap).
struct LazyPanic<F>(Arc<Lazy<String, F>>, SideChannel<String>);
1536
1537impl<F: FnOnce() -> String + Send> ExternalFunction for LazyPanic<F> {
1538    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1539        assert!(args.is_empty());
1540        state.trigger_early_stop();
1541        let mut guard = self.1.lock().unwrap();
1542        if guard.is_none() {
1543            *guard = Some(Lazy::force(&self.0).clone());
1544        }
1545        None
1546    }
1547}
1548
1549impl<F> Clone for LazyPanic<F> {
1550    fn clone(&self) -> Self {
1551        LazyPanic(self.0.clone(), self.1.clone())
1552    }
1553}
1554
/// An external function used to store a message when a panic occurs.
///
/// Field `0` is the message itself. Field `1` is the side channel that the
/// message is copied into when the function is invoked.
//
// TODO: once we have parallelism wired in, we'll want to replace this with a
// more efficient solution (e.g. one based on crossbeam or arcswap).
#[derive(Clone)]
struct Panic(String, SideChannel<String>);
1561
1562impl EGraph {
1563    /// Create a new `ExternalFunction` that panics with the given message.
1564    pub fn new_panic(&mut self, message: String) -> ExternalFunctionId {
1565        *self
1566            .panic_funcs
1567            .entry(message.to_string())
1568            .or_insert_with(|| {
1569                let panic = Panic(message, self.panic_message.clone());
1570                self.db.add_external_function(Box::new(panic))
1571            })
1572    }
1573
1574    pub fn new_panic_lazy(
1575        &mut self,
1576        message: impl FnOnce() -> String + Send + 'static,
1577    ) -> ExternalFunctionId {
1578        let lazy = Lazy::new(message);
1579        let panic = LazyPanic(Arc::new(lazy), self.panic_message.clone());
1580        self.db.add_external_function(Box::new(panic))
1581    }
1582}
1583
1584impl ExternalFunction for Panic {
1585    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1586        // TODO (egglog feature): change this to support interpolating panic messages
1587        assert!(args.is_empty());
1588
1589        state.trigger_early_stop();
1590        let mut guard = self.1.lock().unwrap();
1591        if guard.is_none() {
1592            *guard = Some(self.0.clone());
1593        }
1594        None
1595    }
1596}
1597
/// Heuristic for deciding whether to do an incremental or nonincremental
/// rebuild for a given table.
///
/// Returns `true` when `uf_size` is no larger than a fixed fraction of
/// `table_size`: one sixteenth when `parallel` is set, one eighth otherwise.
/// The fraction is computed with integer division, so it rounds down.
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    let divisor = if parallel { 16 } else { 8 };
    uf_size <= table_size / divisor
}
1607
/// Subsumption tag built from the constant `1`.
// NOTE(review): presumably the value stored in a table's subsume column for a
// subsumed row -- confirm against the code that writes that column.
pub(crate) const SUBSUMED: Value = Value::new_const(1);
/// Subsumption tag built from the constant `0`.
// NOTE(review): presumably the value stored in a table's subsume column for a
// row that is not subsumed -- confirm against the code that writes that column.
pub(crate) const NOT_SUBSUMED: Value = Value::new_const(0);
1610fn combine_subsumed(v1: Value, v2: Value) -> Value {
1611    std::cmp::max(v1, v2)
1612}
1613
/// A struct helping with some calculations of where some information is stored at the
/// core-relations Table level for a given function.
///
/// Functions can have multiple "output columns" in the underlying core-relations layer depending
/// on whether different features are enabled. Roughly, tables are laid out as:
///
/// > `[key0, ..., keyn, return value, timestamp, subsume?]`
///
/// Where there are `n+1` key columns and columns marked with a question mark are optional,
/// depending on the egraph and table-level configuration.
///
/// Every column position is derived from `func_cols`: the return value sits at index
/// `func_cols - 1`, the timestamp at `func_cols`, and the subsume tag (when enabled) at
/// `func_cols + 1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct SchemaMath {
    /// Whether or not the table is enabled for subsumption.
    subsume: bool,
    /// The number of columns in the function (including the return value).
    func_cols: usize,
}
1631
/// A struct containing possible non-key portions of a table row. To be used with
/// [`SchemaMath::write_table_row`].
///
/// This is the write side (building a row); [`ScanEntry`] is the
/// read side (a row yielded by a table scan).
///
/// `T` is the element type of the row buffer being filled in.
struct RowVals<T> {
    /// The timestamp for the row.
    timestamp: T,
    /// The subsumption tag for the row. Only relevant if the table has subsumption enabled.
    subsume: Option<T>,
    /// The return value of the row. Return values are mandatory but callers may have already
    /// filled it in.
    ret_val: Option<T>,
}
1646
/// A raw row yielded by a table scan; `vals` includes the trailing
/// output/eclass column.
///
/// Public so the `egglog` crate can consume it, but **do not re-export
/// it from `egglog`'s public API** — the user-facing row types are
/// `egglog::FunctionEntry` and `egglog::Enode`.
#[derive(Clone, Debug)]
pub struct ScanEntry<'a> {
    /// The row's values, borrowed for `'a`; the last element is the
    /// output/eclass column.
    pub vals: &'a [Value],
    /// Whether the row is flagged as subsumed.
    pub subsumed: bool,
}
1658
1659impl SchemaMath {
1660    fn write_table_row<T: Clone>(
1661        &self,
1662        row: &mut impl HasResizeWith<T>,
1663        RowVals {
1664            timestamp,
1665            subsume,
1666            ret_val,
1667        }: RowVals<T>,
1668    ) {
1669        row.resize_with(self.table_columns(), || timestamp.clone());
1670        row[self.ts_col()] = timestamp;
1671        if let Some(ret_val) = ret_val {
1672            row[self.ret_val_col()] = ret_val;
1673        }
1674        if let Some(subsume) = subsume {
1675            row[self.subsume_col()] = subsume;
1676        } else {
1677            assert!(
1678                !self.subsume,
1679                "subsume flag must be provided if subsumption is enabled"
1680            );
1681        }
1682    }
1683
1684    fn num_keys(&self) -> usize {
1685        self.func_cols - 1
1686    }
1687
1688    fn table_columns(&self) -> usize {
1689        self.func_cols + 1 /* timestamp */ + if self.subsume { 1 } else { 0 }
1690    }
1691
1692    fn ret_val_col(&self) -> usize {
1693        self.func_cols - 1
1694    }
1695
1696    fn ts_col(&self) -> usize {
1697        self.func_cols
1698    }
1699
1700    #[track_caller]
1701    fn subsume_col(&self) -> usize {
1702        assert!(self.subsume);
1703        self.func_cols + 1
1704    }
1705}
1706
/// Error type wrapping a panic message; it is displayed as `Panic: <message>`.
#[derive(Error, Debug)]
#[error("Panic: {0}")]
struct PanicError(String);
1710
/// Basic ad-hoc polymorphism around `resize_with` in order to get [`SchemaMath::write_table_row`]
/// to work with both `Vec` and `SmallVec`.
///
/// Implementors must also be viewable as slices and indexable by `usize`, which is what lets
/// generic code assign to individual columns after resizing.
trait HasResizeWith<T>:
    AsMut<[T]> + AsRef<[T]> + Index<usize, Output = T> + IndexMut<usize, Output = T>
{
    /// Resizes the container to `new_size` elements, calling `f` to produce each new element
    /// when growing.
    fn resize_with<F: FnMut() -> T>(&mut self, new_size: usize, f: F);
}
1720
1721impl<T> HasResizeWith<T> for Vec<T> {
1722    fn resize_with<F>(&mut self, new_size: usize, f: F)
1723    where
1724        F: FnMut() -> T,
1725    {
1726        self.resize_with(new_size, f);
1727    }
1728}
1729
1730impl<T, A: smallvec::Array<Item = T>> HasResizeWith<T> for SmallVec<A> {
1731    fn resize_with<F>(&mut self, new_size: usize, f: F)
1732    where
1733        F: FnMut() -> T,
1734    {
1735        self.resize_with(new_size, f);
1736    }
1737}