egglog_bridge/
rule.rs

1//! APIs for building egglog rules.
2//!
3//! Egglog rules are ultimately just (sets of) `core-relations` rules
4//! parameterized by a range of timestamps used as constraints during seminaive
5//! evaluation.
6
7use std::sync::Arc;
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, Constraint, CounterId, ExternalFunctionId, PlanStrategy, QueryBuilder,
12    RuleBuilder as CoreRuleBuilder, RuleSetBuilder, TableId, Value, WriteVal,
13};
14use crate::numeric_id::{DenseIdMap, NumericId, define_id};
15use anyhow::Context;
16use hashbrown::HashSet;
17use log::debug;
18use smallvec::SmallVec;
19use thiserror::Error;
20
21use crate::{CachedPlanInfo, NOT_SUBSUMED, RowVals, SUBSUMED, SchemaMath};
22use crate::{ColumnTy, DefaultVal, EGraph, FunctionId, Result, RuleId, RuleInfo, Timestamp};
23
24define_id!(pub VariableId, u32, "A variable in an egglog query");
25define_id!(pub AtomId, u32, "an atom in an egglog query");
26pub(crate) type DstVar = core_relations::QueryEntry;
27
28impl VariableId {
29    fn to_var(self) -> Variable {
30        Variable {
31            id: self,
32            name: None,
33        }
34    }
35}
36
/// A variable in an egglog query: a [`VariableId`] plus an optional name.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Variable {
    /// The id of this variable within the query being built.
    pub id: VariableId,
    /// An optional name attached to the variable; `None` for variables
    /// created without one.
    pub name: Option<Box<str>>,
}
42
/// Errors reported while validating the entries passed to a [`RuleBuilder`].
#[derive(Debug, Error)]
enum RuleBuilderError {
    /// Two column types that were required to agree did not.
    #[error("type mismatch: expected {expected:?}, got {got:?}")]
    TypeMismatch { expected: ColumnTy, got: ColumnTy },
    /// The number of entries supplied did not match the number required.
    #[error("arity mismatch: expected {expected:?}, got {got:?}")]
    ArityMismatch { expected: usize, got: usize },
}
50
/// Per-variable metadata stored in a [`Query`].
#[derive(Clone)]
struct VarInfo {
    /// The column type the variable was declared with.
    ty: ColumnTy,
    /// Optional name; when present it is forwarded to the low-level query
    /// builder when the variable is created there.
    name: Option<Box<str>>,
}
56
/// An argument to an atom or action in an egglog rule: either a variable or a
/// constant value.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryEntry {
    /// A reference to a variable bound in the query.
    Var(Variable),
    /// A constant value.
    Const {
        val: Value,
        // Constants can have a type plumbed through, particularly if they
        // correspond to a base value constant in egglog.
        ty: ColumnTy,
    },
}
67
68impl From<Variable> for QueryEntry {
69    fn from(var: Variable) -> Self {
70        QueryEntry::Var(var)
71    }
72}
73
/// A reference to something callable from a rule: either a function
/// identified by a [`FunctionId`] or an external function identified by an
/// [`ExternalFunctionId`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Function {
    /// A function identified by a [`FunctionId`].
    Table(FunctionId),
    /// An external function identified by an [`ExternalFunctionId`].
    Prim(ExternalFunctionId),
}
79
80impl From<FunctionId> for Function {
81    fn from(f: FunctionId) -> Self {
82        Function::Table(f)
83    }
84}
85
86impl From<ExternalFunctionId> for Function {
87    fn from(f: ExternalFunctionId) -> Self {
88        Function::Prim(f)
89    }
90}
91
/// The shape of a rule-building callback: a cloneable, `Send + Sync` closure
/// that receives the current [`Bindings`] and the low-level rule builder and
/// adds one step to the rule under construction.
trait Brc:
    Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + dyn_clone::DynClone + Send + Sync
{
}
// Blanket impl: every `Clone + Send + Sync` closure with the right signature is a `Brc`.
impl<T: Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + Clone + Send + Sync> Brc for T {}
// Makes `Box<dyn Brc>` cloneable, which `Query`'s `#[derive(Clone)]` requires.
dyn_clone::clone_trait_object!(Brc);
/// A boxed, type-erased [`Brc`] callback.
type BuildRuleCallback = Box<dyn Brc>;
99
/// The stored description of an egglog rule: its variables, its atoms, and the
/// callbacks that rebuild the rule's actions against a low-level rule builder.
#[derive(Clone)]
pub(crate) struct Query {
    /// The union-find table: `union` inserts into it and `lookup_uf` reads
    /// from it.
    uf_table: TableId,
    /// Counter used to produce values for lookups whose default is
    /// [`DefaultVal::FreshId`].
    id_counter: CounterId,
    /// Counter read when building the rule to obtain the timestamp written by
    /// the rule's actions.
    ts_counter: CounterId,
    /// The slot reserved for this rule in the egraph.
    rule_id: RuleId,
    /// Type and name information for every variable bound in the query.
    vars: DenseIdMap<VariableId, VarInfo>,
    /// The atoms of the query: the table, the row of entries for that table,
    /// and the schema layout used to build the row.
    atoms: Vec<(TableId, Vec<QueryEntry>, SchemaMath)>,
    /// The builders for queries in this module essentially wrap the lower-level
    /// builders from the `core_relations` crate. A single egglog rule can turn
    /// into N core-relations rules. The code is structured by constructing a
    /// series of callbacks that will iteratively build up a low-level rule that
    /// looks like the high-level rule, passing along an environment that keeps
    /// track of the mappings between low and high-level variables.
    add_rule: Vec<BuildRuleCallback>,
    /// If set, execute a single rule (rather than O(atoms.len()) rules) during
    /// seminaive, with the given atom as the focus.
    sole_focus: Option<usize>,
    /// If `false`, the cached plan is run as-is with no timestamp constraints.
    seminaive: bool,
    /// The planning strategy forwarded to the low-level query builder.
    plan_strategy: PlanStrategy,
    /// If `true`, skip tree-decomposition during query planning. See
    /// [`core_relations::QueryBuilder::set_no_decomp`].
    no_decomp: bool,
}
124
/// A builder for a single egglog rule, created by [`EGraph::new_rule`] and
/// registered with the egraph by [`RuleBuilder::build`].
pub struct RuleBuilder<'a> {
    /// The egraph the rule is being built for.
    egraph: &'a mut EGraph,
    /// The description passed to [`EGraph::new_rule`].
    desc: Arc<str>,
    /// The query and actions accumulated so far.
    query: Query,
}
130
131impl EGraph {
132    /// Add a rewrite rule for this [`EGraph`] using a [`RuleBuilder`].
133    /// If you aren't sure, use `egraph.new_rule("", true)`.
134    pub fn new_rule(&mut self, desc: &str, seminaive: bool) -> RuleBuilder<'_> {
135        let uf_table = self.uf_table;
136        let id_counter = self.id_counter;
137        let ts_counter = self.timestamp_counter;
138        let rule_id = self.rules.reserve_slot();
139        RuleBuilder {
140            egraph: self,
141            desc: Arc::from(desc),
142            query: Query {
143                uf_table,
144                id_counter,
145                ts_counter,
146                rule_id,
147                seminaive,
148                sole_focus: None,
149                vars: Default::default(),
150                atoms: Default::default(),
151                add_rule: Default::default(),
152                plan_strategy: Default::default(),
153                no_decomp: false,
154            },
155        }
156    }
157
158    /// Remove a rewrite rule from this [`EGraph`].
159    pub fn free_rule(&mut self, id: RuleId) {
160        self.rules.take(id);
161    }
162}
163
164impl RuleBuilder<'_> {
165    fn add_callback(&mut self, cb: impl Brc + 'static) {
166        self.query.add_rule.push(Box::new(cb));
167    }
168
169    /// Access the underlying egraph within the builder.
170    pub fn egraph(&self) -> &EGraph {
171        self.egraph
172    }
173
174    /// Register a runtime panic with a custom message and return its
175    /// id. When called via [`call_external_func`], the panic writes
176    /// the message to the egraph's panic side channel and triggers
177    /// early stop, so `run_rules` returns an `Err` carrying the
178    /// message rather than the calling thread unwinding.
179    ///
180    /// [`call_external_func`]: Self::call_external_func
181    pub fn new_panic(&mut self, message: String) -> crate::ExternalFunctionId {
182        self.egraph.new_panic(message)
183    }
184
185    pub(crate) fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
186        self.query.plan_strategy = strategy;
187    }
188
189    /// If `true`, the query planner will skip tree-decomposition for
190    /// this rule. Mirrors
191    /// [`core_relations::QueryBuilder::set_no_decomp`]; set from the
192    /// `:no-decomp` rule option or the egglog `--no-decomp` CLI flag.
193    pub fn set_no_decomp(&mut self, no_decomp: bool) {
194        self.query.no_decomp = no_decomp;
195    }
196
197    /// Get the canonical value of an id in the union-find. An internal-only
198    /// routine used to implement rebuilding.
199    ///
200    /// Note, calling this with a non-Id entry can cause errors at rule runtime
201    /// (The derived rules will not compile).
202    pub(crate) fn lookup_uf(&mut self, entry: QueryEntry) -> Result<Variable> {
203        let res = self.new_var(ColumnTy::Id);
204        let uf_table = self.query.uf_table;
205        self.assert_has_ty(&entry, ColumnTy::Id)
206            .context("lookup_uf: ")?;
207        self.add_callback(move |inner, rb| {
208            let entry = inner.convert(&entry);
209            let res_inner = rb.lookup_with_default(uf_table, &[entry], entry, ColumnId::new(1))?;
210            inner.mapping.insert(res.id, res_inner.into());
211            Ok(())
212        });
213        Ok(res)
214    }
215
216    /// A low-level routine used in rebuilding. Halts execution if `lhs` and
217    /// `rhs` are equal (pointwise).
218    ///
219    /// Note, calling this with invalid arguments (e.g. different lengths for
220    /// `lhs` and `rhs`) can cause errors at rule runtime.
221    pub(crate) fn check_for_update(
222        &mut self,
223        lhs: &[QueryEntry],
224        rhs: &[QueryEntry],
225    ) -> Result<()> {
226        let lhs = SmallVec::<[QueryEntry; 4]>::from_iter(lhs.iter().cloned());
227        let rhs = SmallVec::<[QueryEntry; 4]>::from_iter(rhs.iter().cloned());
228        if lhs.len() != rhs.len() {
229            return Err(RuleBuilderError::ArityMismatch {
230                expected: lhs.len(),
231                got: rhs.len(),
232            }
233            .into());
234        }
235        lhs.iter().zip(rhs.iter()).try_for_each(|(l, r)| {
236            self.assert_same_ty(l, r).with_context(|| {
237                format!("check_for_update: {lhs:?} and {rhs:?}, mismatch between {l:?} and {r:?}")
238            })
239        })?;
240
241        self.add_callback(move |inner, rb| {
242            let lhs = inner.convert_all(&lhs);
243            let rhs = inner.convert_all(&rhs);
244            rb.assert_any_ne(&lhs, &rhs).context("check_for_update")
245        });
246        Ok(())
247    }
248
249    fn assert_same_ty(
250        &self,
251        l: &QueryEntry,
252        r: &QueryEntry,
253    ) -> std::result::Result<(), RuleBuilderError> {
254        match (l, r) {
255            (
256                QueryEntry::Var(Variable { id: v1, .. }),
257                QueryEntry::Var(Variable { id: v2, .. }),
258            ) => {
259                let ty1 = self.query.vars[*v1].ty;
260                let ty2 = self.query.vars[*v2].ty;
261                if ty1 != ty2 {
262                    return Err(RuleBuilderError::TypeMismatch {
263                        expected: ty1,
264                        got: ty2,
265                    });
266                }
267            }
268            // constants can be untyped
269            (QueryEntry::Const { .. }, QueryEntry::Const { .. })
270            | (QueryEntry::Var { .. }, QueryEntry::Const { .. })
271            | (QueryEntry::Const { .. }, QueryEntry::Var { .. }) => {}
272        }
273        Ok(())
274    }
275
276    fn assert_has_ty(
277        &self,
278        entry: &QueryEntry,
279        ty: ColumnTy,
280    ) -> std::result::Result<(), RuleBuilderError> {
281        if let QueryEntry::Var(Variable { id: v, .. }) = entry {
282            let var_ty = self.query.vars[*v].ty;
283            if var_ty != ty {
284                return Err(RuleBuilderError::TypeMismatch {
285                    expected: var_ty,
286                    got: ty,
287                });
288            }
289        }
290        Ok(())
291    }
292
293    /// Register the given rule with the egraph.
294    pub fn build(mut self) -> RuleId {
295        if self.query.atoms.len() == 1 {
296            self.query.plan_strategy = PlanStrategy::MinCover;
297        }
298        let res = self.query.rule_id;
299        let info = RuleInfo {
300            last_run_at: Timestamp::new(0),
301            query: self.query,
302            cached_plan: None,
303            desc: self.desc,
304        };
305        debug!("created rule {res:?} / {}", info.desc);
306        self.egraph.rules.insert(res, info);
307        res
308    }
309
310    pub(crate) fn set_focus(&mut self, focus: usize) {
311        self.query.sole_focus = Some(focus);
312    }
313
314    /// Bind a new variable of the given type in the query.
315    pub fn new_var(&mut self, ty: ColumnTy) -> Variable {
316        let res = self.query.vars.next_id();
317        let var = Variable {
318            id: res,
319            name: None,
320        };
321        self.query.vars.push(VarInfo { ty, name: None });
322        var
323    }
324
325    /// Bind a new variable of the given type in the query.
326    ///
327    /// This method attaches the given name to the [`QueryEntry`], which can
328    /// make debugging easier.
329    pub fn new_var_named(&mut self, ty: ColumnTy, name: &str) -> QueryEntry {
330        let id = self.query.vars.next_id();
331        let var = Variable {
332            id,
333            name: Some(name.into()),
334        };
335        self.query.vars.push(VarInfo {
336            ty,
337            name: Some(name.into()),
338        });
339        QueryEntry::Var(var)
340    }
341
342    /// A low-level way to add an atom to a query.
343    ///
344    /// The atom is added directly to `table`. If `func` is supplied, then metadata about the
345    /// function is used for schema validation. If `subsume_entry` is
346    /// supplied and the supplied function is enabled for subsumption, then the given
347    /// [`QueryEntry`] is used to populated the subsumption column for the table. This allows
348    /// higher-level routines to constrain the subsumption column or use it for other purposes.
349    pub(crate) fn add_atom_with_timestamp_and_func(
350        &mut self,
351        table: TableId,
352        func: Option<FunctionId>,
353        subsume_entry: Option<QueryEntry>,
354        entries: &[QueryEntry],
355    ) -> AtomId {
356        let mut atom = entries.to_vec();
357        let schema_math = if let Some(func) = func {
358            let info = &self.egraph.funcs[func];
359            assert_eq!(info.schema.len(), entries.len());
360            SchemaMath {
361                subsume: info.can_subsume,
362                func_cols: info.schema.len(),
363            }
364        } else {
365            SchemaMath {
366                subsume: subsume_entry.is_some(),
367                func_cols: entries.len(),
368            }
369        };
370        schema_math.write_table_row(
371            &mut atom,
372            RowVals {
373                timestamp: self.new_var(ColumnTy::Id).into(),
374                subsume: if schema_math.subsume {
375                    Some(subsume_entry.unwrap_or_else(|| self.new_var(ColumnTy::Id).into()))
376                } else {
377                    None
378                },
379                ret_val: None,
380            },
381        );
382        let res = AtomId::from_usize(self.query.atoms.len());
383        self.query.atoms.push((table, atom, schema_math));
384        res
385    }
386
387    pub fn call_external_func(
388        &mut self,
389        func: ExternalFunctionId,
390        args: &[QueryEntry],
391        ret_ty: ColumnTy,
392        panic_msg: impl FnOnce() -> String + 'static + Send,
393    ) -> Variable {
394        let args = args.to_vec();
395        let res = self.new_var(ret_ty);
396        // External functions that fail on the RHS of a rule should cause a panic.
397        let panic_fn = self.egraph.new_panic_lazy(panic_msg);
398        self.query.add_rule.push(Box::new(move |inner, rb| {
399            let args = inner.convert_all(&args);
400            let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?;
401            inner.mapping.insert(res.id, var.into());
402            Ok(())
403        }));
404        res
405    }
406
407    /// Add the given table atom to query. As elsewhere in the crate, the last
408    /// argument is the "return value" of the function. Can also optionally
409    /// check the subsumption bit.
410    pub fn query_table(
411        &mut self,
412        func: FunctionId,
413        entries: &[QueryEntry],
414        is_subsumed: Option<bool>,
415    ) -> Result<AtomId> {
416        let info = &self.egraph.funcs[func];
417        let schema = &info.schema;
418        if schema.len() != entries.len() {
419            return Err(anyhow::Error::from(RuleBuilderError::ArityMismatch {
420                expected: schema.len(),
421                got: entries.len(),
422            }))
423            .with_context(|| format!("query_table: mismatch between {entries:?} and {schema:?}"));
424        }
425        entries
426            .iter()
427            .zip(schema.iter())
428            .try_for_each(|(entry, ty)| {
429                self.assert_has_ty(entry, *ty)
430                    .with_context(|| format!("query_table: mismatch between {entry:?} and {ty:?}"))
431            })?;
432        Ok(self.add_atom_with_timestamp_and_func(
433            info.table,
434            Some(func),
435            is_subsumed.map(|b| QueryEntry::Const {
436                val: match b {
437                    true => SUBSUMED,
438                    false => NOT_SUBSUMED,
439                },
440                ty: ColumnTy::Id,
441            }),
442            entries,
443        ))
444    }
445
446    /// Add the given primitive atom to query. As elsewhere in the crate, the last
447    /// argument is the "return value" of the function.
448    pub fn query_prim(
449        &mut self,
450        func: ExternalFunctionId,
451        entries: &[QueryEntry],
452        // NB: not clear if we still need this now that proof checker is in a separate crate.
453        _ret_ty: ColumnTy,
454    ) -> Result<()> {
455        let entries = entries.to_vec();
456        self.query.add_rule.push(Box::new(move |inner, rb| {
457            let mut dst_vars = inner.convert_all(&entries);
458            let expected = dst_vars.pop().expect("must specify a return value");
459            let var = rb.call_external(func, &dst_vars)?;
460            match entries.last().unwrap() {
461                QueryEntry::Var(Variable { id, .. }) if !inner.grounded.contains(id) => {
462                    inner.mapping.insert(*id, var.into());
463                    inner.grounded.insert(*id);
464                }
465                _ => rb.assert_eq(var.into(), expected),
466            }
467            Ok(())
468        }));
469        Ok(())
470    }
471
472    /// Subsume the given entry in `func`.
473    ///
474    /// `entries` should match the number of keys to the function.
475    pub fn subsume(&mut self, func: FunctionId, entries: &[QueryEntry]) {
476        // First, insert a subsumed value if the tuple is new.
477        let ret = self.lookup_with_subsumed(
478            func,
479            entries,
480            QueryEntry::Const {
481                val: SUBSUMED,
482                ty: ColumnTy::Id,
483            },
484            || "subsumed a nonextestent row!".to_string(),
485        );
486        let info = &self.egraph.funcs[func];
487        let schema_math = SchemaMath {
488            subsume: info.can_subsume,
489            func_cols: info.schema.len(),
490        };
491        assert!(info.can_subsume);
492        assert_eq!(entries.len() + 1, info.schema.len());
493        let entries = entries.to_vec();
494        let table = info.table;
495
496        let ret: QueryEntry = ret.into();
497        self.add_callback(move |inner, rb| {
498            // Then, add a tuple subsuming the entry, but only if the entry isn't already subsumed.
499            // Look up the current subsume value.
500            let mut dst_entries = inner.convert_all(&entries);
501            let cur_subsume_val = rb.lookup(
502                table,
503                &dst_entries,
504                ColumnId::from_usize(schema_math.subsume_col()),
505            )?;
506            schema_math.write_table_row(
507                &mut dst_entries,
508                RowVals {
509                    timestamp: inner.next_ts(),
510                    subsume: Some(SUBSUMED.into()),
511                    ret_val: Some(inner.convert(&ret)),
512                },
513            );
514            rb.insert_if_eq(
515                table,
516                cur_subsume_val.into(),
517                NOT_SUBSUMED.into(),
518                &dst_entries,
519            )?;
520            Ok(())
521        });
522    }
523
524    pub(crate) fn lookup_with_subsumed(
525        &mut self,
526        func: FunctionId,
527        entries: &[QueryEntry],
528        subsumed: QueryEntry,
529        panic_msg: impl FnOnce() -> String + Send + 'static,
530    ) -> Variable {
531        let entries = entries.to_vec();
532        let info = &self.egraph.funcs[func];
533        let res = self
534            .query
535            .vars
536            .push(VarInfo {
537                ty: info.ret_ty(),
538                name: None,
539            })
540            .to_var();
541        let table = info.table;
542        let id_counter = self.query.id_counter;
543        let schema_math = SchemaMath {
544            subsume: info.can_subsume,
545            func_cols: info.schema.len(),
546        };
547        let cb: BuildRuleCallback = match info.default_val {
548            DefaultVal::Const(_) | DefaultVal::FreshId => {
549                let wv: WriteVal = match &info.default_val {
550                    DefaultVal::Const(c) => (*c).into(),
551                    DefaultVal::FreshId => WriteVal::IncCounter(id_counter),
552                    _ => unreachable!(),
553                };
554                let get_write_vals = move |inner: &mut Bindings| {
555                    let mut write_vals = SmallVec::<[WriteVal; 4]>::new();
556                    for i in schema_math.num_keys()..schema_math.table_columns() {
557                        if i == schema_math.ts_col() {
558                            write_vals.push(inner.next_ts().into());
559                        } else if i == schema_math.ret_val_col() {
560                            write_vals.push(wv);
561                        } else if schema_math.subsume && i == schema_math.subsume_col() {
562                            write_vals.push(inner.convert(&subsumed).into())
563                        } else {
564                            unreachable!()
565                        }
566                    }
567                    write_vals
568                };
569
570                Box::new(move |inner, rb| {
571                    let write_vals = get_write_vals(inner);
572                    let dst_vars = inner.convert_all(&entries);
573                    let var = rb.lookup_or_insert(
574                        table,
575                        &dst_vars,
576                        &write_vals,
577                        ColumnId::from_usize(schema_math.ret_val_col()),
578                    )?;
579                    inner.mapping.insert(res.id, var.into());
580                    Ok(())
581                })
582            }
583            DefaultVal::Fail => {
584                let panic_func = self.egraph.new_panic_lazy(panic_msg);
585                Box::new(move |inner, rb| {
586                    let dst_vars = inner.convert_all(&entries);
587                    let var = rb.lookup_with_fallback(
588                        table,
589                        &dst_vars,
590                        ColumnId::from_usize(schema_math.ret_val_col()),
591                        panic_func,
592                        &[],
593                    )?;
594                    inner.mapping.insert(res.id, var.into());
595                    Ok(())
596                })
597            }
598        };
599        self.query.add_rule.push(cb);
600        res
601    }
602
603    /// Look up the value of a function in the database. If the value is not
604    /// present, the configured default for the function is used.
605    ///
606    /// For functions configured with [`DefaultVal::Fail`], failing lookups will use `panic_msg` in
607    /// the panic output.
608    pub fn lookup(
609        &mut self,
610        func: FunctionId,
611        entries: &[QueryEntry],
612        panic_msg: impl FnOnce() -> String + Send + 'static,
613    ) -> Variable {
614        self.lookup_with_subsumed(
615            func,
616            entries,
617            QueryEntry::Const {
618                val: NOT_SUBSUMED,
619                ty: ColumnTy::Id,
620            },
621            panic_msg,
622        )
623    }
624
625    /// Merge the two values in the union-find.
626    pub fn union(&mut self, l: QueryEntry, r: QueryEntry) {
627        self.query.add_rule.push(Box::new(move |inner, rb| {
628            let l = inner.convert(&l);
629            let r = inner.convert(&r);
630            rb.insert(inner.uf_table, &[l, r, inner.next_ts()])
631                .context("union")
632        }));
633    }
634
635    /// This method is equivalent to `remove(table, before); set(table, after)`,
636    /// optionally propagating subsumption to the next row.
637    pub(crate) fn rebuild_row(
638        &mut self,
639        func: FunctionId,
640        before: &[QueryEntry],
641        after: &[QueryEntry],
642        // If subsumption is enabled for this function, we can optionally propagate it to the next
643        // row.
644        subsume_var: Option<Variable>,
645    ) {
646        assert_eq!(before.len(), after.len());
647        self.remove(func, &before[..before.len() - 1]);
648        if let Some(subsume_var) = subsume_var {
649            self.set_with_subsume(func, after, QueryEntry::Var(subsume_var));
650        } else {
651            self.set(func, after);
652        }
653    }
654
655    /// Set the value of a function in the database.
656    pub fn set(&mut self, func: FunctionId, entries: &[QueryEntry]) {
657        self.set_with_subsume(
658            func,
659            entries,
660            QueryEntry::Const {
661                val: NOT_SUBSUMED,
662                ty: ColumnTy::Id,
663            },
664        );
665    }
666
667    pub(crate) fn set_with_subsume(
668        &mut self,
669        func: FunctionId,
670        entries: &[QueryEntry],
671        subsume_entry: QueryEntry,
672    ) {
673        let info = &self.egraph.funcs[func];
674        let table = info.table;
675        let entries = entries.to_vec();
676        let schema_math = SchemaMath {
677            subsume: info.can_subsume,
678            func_cols: info.schema.len(),
679        };
680        self.query.add_rule.push(Box::new(move |inner, rb| {
681            let mut dst_vars = inner.convert_all(&entries);
682            schema_math.write_table_row(
683                &mut dst_vars,
684                RowVals {
685                    timestamp: inner.next_ts(),
686                    subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)),
687                    ret_val: None, // already filled in
688                },
689            );
690            rb.insert(table, &dst_vars).context("set")
691        }));
692    }
693
694    /// Remove the value of a function from the database.
695    pub fn remove(&mut self, table: FunctionId, entries: &[QueryEntry]) {
696        let table = self.egraph.funcs[table].table;
697        let entries = entries.to_vec();
698        let cb: BuildRuleCallback = Box::new(move |inner, rb| {
699            let dst_vars = inner.convert_all(&entries);
700            rb.remove(table, &dst_vars).context("remove")
701        });
702        self.query.add_rule.push(cb);
703    }
704
705    /// Panic with a given message.
706    pub fn panic(&mut self, message: String) {
707        let panic = self.egraph.new_panic(message.clone());
708        let ret_ty = ColumnTy::Id;
709        let res = self.new_var(ret_ty);
710        self.query.add_rule.push(Box::new(move |inner, rb| {
711            let var = rb.call_external(panic, &[])?;
712            inner.mapping.insert(res.id, var.into());
713            Ok(())
714        }));
715    }
716}
717
718impl Query {
719    fn query_state<'a, 'outer>(
720        &self,
721        rsb: &'a mut RuleSetBuilder<'outer>,
722    ) -> (QueryBuilder<'outer, 'a>, Bindings) {
723        let mut qb = rsb.new_rule();
724        qb.set_plan_strategy(self.plan_strategy);
725        qb.set_no_decomp(self.no_decomp);
726        let mut inner = Bindings {
727            uf_table: self.uf_table,
728            next_ts: None,
729            mapping: Default::default(),
730            grounded: Default::default(),
731        };
732        for (var, info) in self.vars.iter() {
733            let new_var = match info.name.as_ref() {
734                Some(name) => qb.new_var_named(name),
735                None => qb.new_var(),
736            };
737            inner.mapping.insert(var, DstVar::Var(new_var));
738        }
739        (qb, inner)
740    }
741
742    fn run_rules_and_build(
743        &self,
744        qb: QueryBuilder,
745        mut inner: Bindings,
746        desc: &str,
747    ) -> Result<core_relations::RuleId> {
748        let mut rb = qb.build();
749        inner.next_ts = Some(rb.read_counter(self.ts_counter).into());
750        self.add_rule
751            .iter()
752            .try_for_each(|f| f(&mut inner, &mut rb))?;
753        Ok(rb.build_with_description(desc))
754    }
755
756    pub(crate) fn build_cached_plan(
757        &self,
758        db: &mut core_relations::Database,
759        desc: &str,
760    ) -> Result<CachedPlanInfo> {
761        let mut rsb = RuleSetBuilder::new(db);
762        let (mut qb, mut inner) = self.query_state(&mut rsb);
763        let mut atom_mapping = Vec::with_capacity(self.atoms.len());
764        for (table, entries, _schema_info) in &self.atoms {
765            atom_mapping.push(add_atom(&mut qb, *table, entries, &[], &mut inner)?);
766        }
767        let rule_id = self.run_rules_and_build(qb, inner, desc)?;
768        let rs = rsb.build();
769        let plan = Arc::new(rs.build_cached_plan(rule_id));
770        Ok(CachedPlanInfo { plan, atom_mapping })
771    }
772
773    /// Add rules to the [`RuleSetBuilder`] for the query specified by the [`CachedPlanInfo`].
774    ///
775    /// A [`CachedPlanInfo`] is a compiled RHS and partial LHS for an egglog rules. In order to
776    /// implement seminaive evaluation, we run several variants of this cached plan with different
777    /// constraints on the timestamps for different atoms. This rule handles building these
778    /// variants of the base plan and adding them to `rsb`.
779    pub(crate) fn add_rules_from_cached(
780        &self,
781        rsb: &mut RuleSetBuilder,
782        mid_ts: Timestamp,
783        cached_plan: &CachedPlanInfo,
784    ) {
785        // For N atoms, we create N queries for seminaive evaluation. We can reuse the cached plan
786        // directly.
787        if !self.seminaive || (self.atoms.is_empty() && mid_ts == Timestamp::new(0)) {
788            let _ = rsb.add_rule_from_cached_plan(&cached_plan.plan, &[]);
789            return;
790        }
791        if let Some(focus_atom) = self.sole_focus {
792            // There is a single "focus" atom that we will constrain to look at new values.
793            let (_, _, schema_info) = &self.atoms[focus_atom];
794            let ts_col = ColumnId::from_usize(schema_info.ts_col());
795            let _ = rsb.add_rule_from_cached_plan(
796                &cached_plan.plan,
797                &[(
798                    cached_plan.atom_mapping[focus_atom],
799                    Constraint::GeConst {
800                        col: ts_col,
801                        val: mid_ts.to_value(),
802                    },
803                )],
804            );
805            return;
806        }
807        // Use the cached plan atoms.len() times with different constraints on each atom.
808        // The semi-naive set of queries will generate N variants of the rule, each with a different "focus" atom that looks at new values only.
809        // For a query A x B x C, we will generate the following rules:
810        //
811        //     A_new x B x C + A_old x B_new x C + A_old x B_old x C_new
812        //
813        let mut constraints: Vec<(core_relations::AtomId, Constraint)> =
814            Vec::with_capacity(self.atoms.len());
815        'outer: for focus_atom in 0..self.atoms.len() {
816            constraints.clear();
817            // start with the focus atom since `add_rule_from_cached_plan` will apply the
818            // constraints in order, and the focus atom may have an empty delta, which
819            // will let it bail early.
820            {
821                let (_, _, schema_info) = &self.atoms[focus_atom];
822                let ts_col = ColumnId::from_usize(schema_info.ts_col());
823                constraints.push((
824                    cached_plan.atom_mapping[focus_atom],
825                    Constraint::GeConst {
826                        col: ts_col,
827                        val: mid_ts.to_value(),
828                    },
829                ))
830            }
831            for (i, (_, _, schema_info)) in self.atoms[0..focus_atom].iter().enumerate() {
832                if mid_ts == Timestamp::new(0) {
833                    continue 'outer;
834                }
835                let ts_col = ColumnId::from_usize(schema_info.ts_col());
836                constraints.push((
837                    cached_plan.atom_mapping[i],
838                    Constraint::LtConst {
839                        col: ts_col,
840                        val: mid_ts.to_value(),
841                    },
842                ));
843            }
844            let _ = rsb.add_rule_from_cached_plan(&cached_plan.plan, &constraints);
845        }
846    }
847}
848
/// State that is used while building a rule to translate variables in egglog
/// rules into variables for core-relations rules.
pub(crate) struct Bindings {
    /// The union-find table, used by `union` actions.
    uf_table: TableId,
    /// The timestamp to write in actions; `None` until the rule builder has
    /// read the timestamp counter.
    next_ts: Option<DstVar>,
    /// The low-level entry each egglog variable currently maps to.
    pub(crate) mapping: DenseIdMap<VariableId, DstVar>,
    /// Variables that appear in an atom or were bound by a primitive call.
    grounded: HashSet<VariableId>,
}
857
858impl Bindings {
859    pub(crate) fn next_ts(&self) -> DstVar {
860        self.next_ts
861            .expect("ts_var should only be used in RHS of the rule")
862    }
863    pub(crate) fn convert(&self, entry: &QueryEntry) -> DstVar {
864        match entry {
865            QueryEntry::Var(Variable { id: v, .. }) => self.mapping[*v],
866            QueryEntry::Const { val, .. } => DstVar::Const(*val),
867        }
868    }
869    pub(crate) fn convert_all(&self, entries: &[QueryEntry]) -> SmallVec<[DstVar; 4]> {
870        entries.iter().map(|e| self.convert(e)).collect()
871    }
872}
873
874fn add_atom(
875    qb: &mut QueryBuilder,
876    table: TableId,
877    entries: &[QueryEntry],
878    constraints: &[Constraint],
879    inner: &mut Bindings,
880) -> Result<core_relations::AtomId> {
881    for entry in entries {
882        if let QueryEntry::Var(Variable { id, .. }) = entry {
883            inner.grounded.insert(*id);
884        }
885    }
886    let vars = inner.convert_all(entries);
887    Ok(qb.add_atom(table, &vars, constraints)?)
888}