egglog_core_relations/uf/
mod.rs

1//! A table implementation backed by a union-find.
2
3use std::{
4    any::Any,
5    mem::{self, ManuallyDrop},
6    sync::{Arc, Weak},
7};
8
9use crate::numeric_id::{DenseIdMap, NumericId};
10use crossbeam_queue::SegQueue;
11use indexmap::IndexMap;
12use petgraph::{Direction, Graph, algo::dijkstra, graph::NodeIndex, visit::EdgeRef};
13
14use crate::{
15    TableChange, TaggedRowBuffer,
16    action::ExecutionState,
17    common::{HashMap, IndexSet, Value},
18    offsets::{OffsetRange, RowId, Subset, SubsetRef},
19    pool::with_pool_set,
20    row_buffer::RowBuffer,
21    table_spec::{
22        ColumnId, Constraint, Generation, MutationBuffer, Offset, Rebuilder, Row, Table, TableSpec,
23        TableVersion, WrappedTableRef,
24    },
25};
26
27#[cfg(test)]
28mod tests;
29
30type UnionFind = crate::union_find::UnionFind<Value>;
31
/// A special table backed by a union-find used to efficiently implement
/// egglog-style canonicalization.
///
/// To canonicalize columns, we need to efficiently discover values that have
/// ceased to be canonical. To do that we keep a table of _displaced_ values:
///
/// This table has three columns:
/// 1. (the only key): a value that is _no longer canonical_ in the equivalence relation.
/// 2. The canonical value of the equivalence class.
/// 3. The timestamp at which the key stopped being canonical.
///
/// We do not store the second value explicitly: instead, we compute it
/// on-the-fly using a union-find data-structure.
///
/// This is related to the 'Leader' encoding in some versions of egglog:
/// Displaced is a version of Leader that _only_ stores ids when they cease to
/// be canonical. Rows are also "automatically updated" with the current leader,
/// rather than requiring the DB to replay history or canonicalize redundant
/// values in the table.
///
/// To union new ids `l`, and `r`, stage an update `Displaced(l, r, ts)` where
/// `ts` is the current timestamp. Note that all tie-breaks and other encoding
/// decisions are made internally, so there may not literally be a row added
/// with this value.
pub struct DisplacedTable {
    // The union-find supplying the current canonical value for each id
    // (i.e. the implicit second column of the table).
    uf: UnionFind,
    // One entry per displaced value: `(displaced value, timestamp)`.
    // Kept in nondecreasing timestamp order; `insert_impl` asserts this,
    // which is what makes the binary search in `timestamp_bounds` valid.
    displaced: Vec<(Value, Value)>,
    // Set when a merge actually changed the equivalence relation; taken
    // (and reset) by `merge` to build the returned `TableChange`.
    changed: bool,
    // Maps a displaced value to the row in `displaced` recording it.
    lookup_table: HashMap<Value, RowId>,
    // Staged insertions pushed by `UfBuffer` handles on drop; drained by
    // `merge`. Shared via `Arc` so buffers can outlive borrows of the table.
    buffered_writes: Arc<SegQueue<RowBuffer>>,
}
63
/// A [`Rebuilder`] that rewrites the given columns of rows to their current
/// canonical values, as determined by the union-find in a [`DisplacedTable`].
struct Canonicalizer<'a> {
    // The columns to canonicalize in each row handed to `rebuild_*`.
    cols: Vec<ColumnId>,
    // The table whose union-find supplies canonical values.
    table: &'a DisplacedTable,
}
68
impl Rebuilder for Canonicalizer<'_> {
    // NOTE(review): this appears to advertise the displaced-key column (0) as
    // the column for callers to index on — confirm against the `Rebuilder`
    // trait's `hint_col` contract.
    fn hint_col(&self) -> Option<ColumnId> {
        Some(ColumnId::new(0))
    }
    /// Map a single value to its current canonical representative.
    fn rebuild_val(&self, val: Value) -> Value {
        self.table.uf.find_naive(val)
    }
    /// Canonicalize rows `start..end` of `buf`, appending only the rows that
    /// actually changed (tagged with their original `RowId`) to `out`.
    /// Rows whose tracked columns are already canonical are skipped.
    fn rebuild_buf(
        &self,
        buf: &RowBuffer,
        start: RowId,
        end: RowId,
        out: &mut TaggedRowBuffer,
        _exec_state: &mut ExecutionState,
    ) {
        if start >= end {
            return;
        }
        assert!(end.index() <= buf.len());
        let mut cur = start;
        // Pooled scratch row, reused across iterations to avoid reallocating.
        let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
        // SAFETY: `cur` is always in-bounds, guaranteed by the above assertion.
        // Special-case small columns: this gives us a modest speedup on rebuilding-heavy
        // workloads.
        match self.cols.as_slice() {
            // One tracked column: copy the row only if that column changed.
            [c] => {
                while cur < end {
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let to_canon = row[c.index()];
                    let canon = self.table.uf.find_naive(to_canon);
                    if canon != to_canon {
                        scratch.extend_from_slice(row);
                        scratch[c.index()] = canon;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            // Two tracked columns.
            [c1, c2] => {
                while cur < end {
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let v1 = row[c1.index()];
                    let v2 = row[c2.index()];
                    let ca1 = self.table.uf.find_naive(v1);
                    let ca2 = self.table.uf.find_naive(v2);
                    if ca1 != v1 || ca2 != v2 {
                        scratch.extend_from_slice(row);
                        scratch[c1.index()] = ca1;
                        scratch[c2.index()] = ca2;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            // Three tracked columns.
            [c1, c2, c3] => {
                while cur < end {
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let v1 = row[c1.index()];
                    let v2 = row[c2.index()];
                    let v3 = row[c3.index()];
                    let ca1 = self.table.uf.find_naive(v1);
                    let ca2 = self.table.uf.find_naive(v2);
                    let ca3 = self.table.uf.find_naive(v3);
                    if ca1 != v1 || ca2 != v2 || ca3 != v3 {
                        scratch.extend_from_slice(row);
                        scratch[c1.index()] = ca1;
                        scratch[c2.index()] = ca2;
                        scratch[c3.index()] = ca3;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            // General case: copy first, canonicalize in place, emit on change.
            cs => {
                while cur < end {
                    scratch.extend_from_slice(unsafe { buf.get_row_unchecked(cur) });
                    let mut changed = false;
                    for c in cs {
                        let to_canon = scratch[c.index()];
                        let canon = self.table.uf.find_naive(to_canon);
                        scratch[c.index()] = canon;
                        changed |= canon != to_canon;
                    }
                    if changed {
                        out.add_row(cur, &scratch);
                    }
                    scratch.clear();
                    cur = cur.inc();
                }
            }
        }
    }
    /// Scan `subset` of `other` into `out`, canonicalize the tracked columns
    /// in place, and mark rows that did not change as stale so consumers can
    /// skip them.
    fn rebuild_subset(
        &self,
        other: WrappedTableRef,
        subset: SubsetRef,
        out: &mut TaggedRowBuffer,
        _exec_state: &mut ExecutionState,
    ) {
        // Remember where the newly-scanned rows begin: we only rewrite those.
        let old_len = u32::try_from(out.len()).expect("row buffer sizes should fit in a u32");
        let _next = other.scan_bounded(subset, Offset::new(0), usize::MAX, out);
        debug_assert!(_next.is_none());
        for i in old_len..u32::try_from(out.len()).expect("row buffer sizes should fit in a u32") {
            let i = RowId::new(i);
            let (_id, row) = out.get_row_mut(i);
            let mut changed = false;
            for col in &self.cols {
                let to_canon = row[col.index()];
                let canon = self.table.uf.find_naive(to_canon);
                changed |= canon != to_canon;
                row[col.index()] = canon;
            }
            if !changed {
                out.set_stale(i);
            }
        }
    }
    /// Canonicalize every value in `vals` in place; returns true if any value
    /// was rewritten.
    fn rebuild_slice(&self, vals: &mut [Value]) -> bool {
        let mut changed = false;
        for val in vals {
            let canon = self.table.uf.find_naive(*val);
            changed |= canon != *val;
            *val = canon;
        }
        changed
    }
}
199
200impl Default for DisplacedTable {
201    fn default() -> Self {
202        Self {
203            uf: UnionFind::default(),
204            displaced: Vec::new(),
205            changed: false,
206            lookup_table: HashMap::default(),
207            buffered_writes: Arc::new(SegQueue::new()),
208        }
209    }
210}
211
212impl Clone for DisplacedTable {
213    fn clone(&self) -> Self {
214        DisplacedTable {
215            uf: self.uf.clone(),
216            displaced: self.displaced.clone(),
217            changed: self.changed,
218            lookup_table: self.lookup_table.clone(),
219            buffered_writes: Default::default(),
220        }
221    }
222}
223
/// A [`MutationBuffer`] for [`DisplacedTable`] (and its provenance variant):
/// rows staged here are handed back to the owning table's queue when the
/// buffer is dropped.
struct UfBuffer {
    // Rows staged via `stage_insert`. Wrapped in `ManuallyDrop` so `drop` can
    // move the buffer out and push it onto the queue without replacing it.
    to_insert: ManuallyDrop<RowBuffer>,
    // Weak handle to the owning table's queue; if the table is gone by the
    // time this buffer drops, the staged rows are simply discarded.
    buffered_writes: Weak<SegQueue<RowBuffer>>,
}
228
impl Drop for UfBuffer {
    /// Flush the staged rows to the owning table's queue so the next `merge`
    /// picks them up; if the table has already been dropped, just drop the
    /// rows instead.
    fn drop(&mut self) {
        let Some(buffered_writes) = self.buffered_writes.upgrade() else {
            // The owning table is gone: nobody will ever consume these rows.
            // SAFETY: If we can't write updates, manually drop to_insert
            unsafe {
                ManuallyDrop::drop(&mut self.to_insert);
            }
            return;
        };
        // SAFETY: self.to_insert will not be used again after this point.
        //
        // This avoids creating a fresh row buffer via `mem::take` or `mem::swap` and
        // dropping it immediately.
        let to_insert = unsafe { ManuallyDrop::take(&mut self.to_insert) };
        buffered_writes.push(to_insert);
    }
}
246
impl MutationBuffer for UfBuffer {
    /// Stage a row for insertion; it reaches the table on the next `merge`
    /// after this buffer is dropped.
    fn stage_insert(&mut self, row: &[Value]) {
        self.to_insert.add_row(row);
    }
    /// Deletions are unsupported: the table's spec sets `allows_delete: false`,
    /// so reaching this is a caller bug.
    fn stage_remove(&mut self, _: &[Value]) {
        panic!("attempting to remove data from a DisplacedTable")
    }
    /// Create a new, empty buffer targeting the same table (same arity, same
    /// write queue).
    fn fresh_handle(&self) -> Box<dyn MutationBuffer> {
        Box::new(UfBuffer {
            to_insert: ManuallyDrop::new(RowBuffer::new(self.to_insert.arity())),
            buffered_writes: self.buffered_writes.clone(),
        })
    }
}
261
262impl Table for DisplacedTable {
263    fn dyn_clone(&self) -> Box<dyn Table> {
264        Box::new(self.clone())
265    }
266    fn as_any(&self) -> &dyn Any {
267        self
268    }
269    fn spec(&self) -> TableSpec {
270        let mut uncacheable_columns = DenseIdMap::default();
271        // The second column of this table is determined dynamically by the union-find.
272        uncacheable_columns.insert(ColumnId::new(1), true);
273        TableSpec {
274            n_keys: 1,
275            n_vals: 2,
276            uncacheable_columns,
277            allows_delete: false,
278        }
279    }
280
281    fn rebuilder<'a>(&'a self, cols: &[ColumnId]) -> Option<Box<dyn Rebuilder + 'a>> {
282        Some(Box::new(Canonicalizer {
283            cols: cols.to_vec(),
284            table: self,
285        }))
286    }
287
288    fn clear(&mut self) {
289        self.uf.reset();
290        self.displaced.clear();
291    }
292
293    fn all(&self) -> Subset {
294        Subset::Dense(OffsetRange::new(
295            RowId::new(0),
296            RowId::from_usize(self.displaced.len()),
297        ))
298    }
299
300    fn len(&self) -> usize {
301        self.displaced.len()
302    }
303
304    fn version(&self) -> TableVersion {
305        TableVersion {
306            major: Generation::new(0),
307            minor: Offset::from_usize(self.displaced.len()),
308        }
309    }
310
311    fn updates_since(&self, offset: Offset) -> Subset {
312        Subset::Dense(OffsetRange::new(
313            RowId::from_usize(offset.index()),
314            RowId::from_usize(self.displaced.len()),
315        ))
316    }
317
318    fn scan_generic_bounded(
319        &self,
320        subset: SubsetRef,
321        start: Offset,
322        n: usize,
323        cs: &[Constraint],
324        mut f: impl FnMut(RowId, &[Value]),
325    ) -> Option<Offset>
326    where
327        Self: Sized,
328    {
329        if cs.is_empty() {
330            let start = start.index();
331            subset
332                .iter_bounded(start, start + n, |row| {
333                    f(row, self.expand(row).as_slice());
334                })
335                .map(Offset::from_usize)
336        } else {
337            let start = start.index();
338            subset
339                .iter_bounded(start, start + n, |row| {
340                    if cs.iter().all(|c| self.eval(c, row)) {
341                        f(row, self.expand(row).as_slice());
342                    }
343                })
344                .map(Offset::from_usize)
345        }
346    }
347
348    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
349        subset.retain(|row| self.eval(c, row));
350        subset
351    }
352
353    fn fast_subset(&self, constraint: &Constraint) -> Option<Subset> {
354        let ts = ColumnId::new(2);
355        match constraint {
356            Constraint::Eq { .. } => None,
357            Constraint::EqConst { col, val } => {
358                if *col == ColumnId::new(1) {
359                    return None;
360                }
361                if *col == ColumnId::new(0) {
362                    return Some(match self.lookup_table.get(val) {
363                        Some(row) => Subset::Dense(OffsetRange::new(
364                            *row,
365                            RowId::from_usize(row.index() + 1),
366                        )),
367                        None => Subset::empty(),
368                    });
369                }
370                match self.timestamp_bounds(*val) {
371                    Ok((start, end)) => Some(Subset::Dense(OffsetRange::new(start, end))),
372                    Err(_) => None,
373                }
374            }
375            Constraint::LtConst { col, val } => {
376                if *col != ts {
377                    return None;
378                }
379                match self.timestamp_bounds(*val) {
380                    Err(bound) | Ok((bound, _)) => {
381                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
382                    }
383                }
384            }
385            Constraint::GtConst { col, val } => {
386                if *col != ts {
387                    return None;
388                }
389
390                match self.timestamp_bounds(*val) {
391                    Err(bound) | Ok((_, bound)) => Some(Subset::Dense(OffsetRange::new(
392                        bound,
393                        RowId::from_usize(self.displaced.len()),
394                    ))),
395                }
396            }
397            Constraint::LeConst { col, val } => {
398                if *col != ts {
399                    return None;
400                }
401
402                match self.timestamp_bounds(*val) {
403                    Err(bound) | Ok((_, bound)) => {
404                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
405                    }
406                }
407            }
408            Constraint::GeConst { col, val } => {
409                if *col != ts {
410                    return None;
411                }
412
413                match self.timestamp_bounds(*val) {
414                    Err(bound) | Ok((bound, _)) => Some(Subset::Dense(OffsetRange::new(
415                        bound,
416                        RowId::from_usize(self.displaced.len()),
417                    ))),
418                }
419            }
420        }
421    }
422
423    fn get_row(&self, key: &[Value]) -> Option<Row> {
424        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
425        let row_id = *self.lookup_table.get(&key[0])?;
426        let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
427        vals.extend_from_slice(self.expand(row_id).as_slice());
428        Some(Row { id: row_id, vals })
429    }
430
431    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
432        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
433        if col == ColumnId::new(1) {
434            Some(self.uf.find_naive(key[0]))
435        } else {
436            let row_id = *self.lookup_table.get(&key[0])?;
437            Some(self.expand(row_id)[col.index()])
438        }
439    }
440
441    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
442        Box::new(UfBuffer {
443            to_insert: ManuallyDrop::new(RowBuffer::new(3)),
444            buffered_writes: Arc::downgrade(&self.buffered_writes),
445        })
446    }
447
448    fn merge(&mut self, _: &mut ExecutionState) -> TableChange {
449        while let Some(rowbuf) = self.buffered_writes.pop() {
450            for row in rowbuf.iter() {
451                self.changed |= self.insert_impl(row).is_some();
452            }
453        }
454        let changed = mem::take(&mut self.changed);
455        // UF table rows can be updated "in place", we count both added and removed as changed in
456        // this case.
457        TableChange {
458            added: changed,
459            removed: changed,
460        }
461    }
462}
463
impl DisplacedTable {
    /// Access the union-find backing this table.
    pub fn underlying_uf(&self) -> &UnionFind {
        &self.uf
    }
    /// Materialize the full 3-column row for `row`: the stored (displaced
    /// value, timestamp) pair plus the current canonical value, computed on
    /// the fly from the union-find.
    fn expand(&self, row: RowId) -> [Value; 3] {
        let (child, ts) = self.displaced[row.index()];
        [child, self.uf.find_naive(child), ts]
    }
    /// Locate the run of rows whose timestamp equals `val`.
    ///
    /// Returns `Ok((start, end))` (half-open) when such rows exist, widening
    /// the binary-search hit to the whole run of equal timestamps, or
    /// `Err(insertion_point)` when none do. Relies on `displaced` being sorted
    /// by timestamp, which `insert_impl` asserts.
    fn timestamp_bounds(&self, val: Value) -> Result<(RowId, RowId), RowId> {
        match self.displaced.binary_search_by_key(&val, |(_, ts)| *ts) {
            Ok(mut off) => {
                let mut next = off;
                // Binary search lands on an arbitrary row of the run; widen
                // left to its start and right to one past its end.
                while off > 0 && self.displaced[off - 1].1 == val {
                    off -= 1;
                }
                while next < self.displaced.len() && self.displaced[next].1 == val {
                    next += 1;
                }
                Ok((RowId::from_usize(off), RowId::from_usize(next)))
            }
            Err(off) => Err(RowId::from_usize(off)),
        }
    }
    /// Evaluate `constraint` against the expanded row.
    fn eval(&self, constraint: &Constraint, row: RowId) -> bool {
        let vals = self.expand(row);
        eval_constraint(&vals, constraint)
    }
    /// Union `row[0]` and `row[1]` at timestamp `row[2]`.
    ///
    /// Returns `None` if the two values were already equal (no row is added);
    /// otherwise records the displaced value and returns `(parent, child)` as
    /// chosen by the union-find's internal tie-breaking.
    ///
    /// Panics if `row` does not have exactly 3 columns, or if `row[2]` is
    /// smaller than the timestamp of the last recorded row (timestamps must be
    /// nondecreasing so `timestamp_bounds` can binary-search).
    fn insert_impl(&mut self, row: &[Value]) -> Option<(Value, Value)> {
        assert_eq!(row.len(), 3, "attempt to insert a row with the wrong arity");
        if self.uf.find(row[0]) == self.uf.find(row[1]) {
            return None;
        }
        let (parent, child) = self.uf.union(row[0], row[1]);

        // Compress paths somewhat, given that we perform naive finds everywhere else.
        let _ = self.uf.find(parent);
        let _ = self.uf.find(child);
        let ts = row[2];
        if let Some((_, highest)) = self.displaced.last() {
            assert!(
                *highest <= ts,
                "must insert rows with increasing timestamps"
            );
        }
        let next = RowId::from_usize(self.displaced.len());
        self.displaced.push((child, ts));
        self.lookup_table.insert(child, next);
        Some((parent, child))
    }
}
514
/// A variant of `DisplacedTable` that also stores "provenance" information that
/// can be used to generate proofs of equality.
///
/// This table expects a fourth "proof" column, though the values it hands back
/// _are not_ the proofs that come in and generally should not be used directly.
/// To generate a proof that two values are equal, this table exports a separate
/// `get_proof` method.
#[derive(Clone, Default)]
pub struct DisplacedTableWithProvenance {
    // The plain displaced table this type decorates with provenance.
    base: DisplacedTable,
    /// Added context for a given "displaced" row. We use this to store "proofs
    /// that x = y".
    ///
    /// N.B. We currently only use the first proof that we find. The remaining
    /// proofs are used for debugging. With some further refactoring we should
    /// be able to remove this field entirely, as complete proof information is
    /// now available through `proof_graph`.
    context: HashMap<(Value, Value), IndexSet<Value>>,
    // Nodes are values; each union adds a forward and a backward edge (see
    // `insert_impl`). Searched by `get_proof` via Dijkstra.
    proof_graph: Graph<Value, ProofEdge>,
    // Maps a value to its node in `proof_graph`.
    node_map: HashMap<Value, NodeIndex>,
    /// The value that was displaced, the value _immediately_ displacing it.
    /// NB: this is different from the 'displaced' table in 'base', which holds
    /// a timestamp.
    displaced: Vec<(Value, Value)>,
    // Staged 4-column insertions (key, other id, timestamp, proof); drained by
    // `merge`. Separate from `base`'s queue, which holds 3-column rows.
    buffered_writes: Arc<SegQueue<RowBuffer>>,
}
541
/// Edge payload in the proof graph: why the endpoints are equal, and when they
/// became so.
#[derive(Copy, Clone, Eq, PartialEq)]
struct ProofEdge {
    // The (possibly direction-reversed) reason the two endpoint values are equal.
    reason: ProofReason,
    // Timestamp of the union that created this edge; used by `get_proof` to
    // exclude edges added after the queried values became equal.
    ts: Value,
}
547
/// One step in a proof of equality produced by
/// [`DisplacedTableWithProvenance::get_proof`]: asserts `lhs = rhs` for the
/// given reason.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProofStep {
    pub lhs: Value,
    pub rhs: Value,
    pub reason: ProofReason,
}
554
/// Why two values in a [`ProofStep`] are equal. The payload is the proof value
/// supplied in the fourth column of the staged union.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProofReason {
    /// The union as originally asserted (step runs in the same direction).
    Forward(Value),
    /// The same union traversed in the reverse direction.
    Backward(Value),
}
560
impl DisplacedTableWithProvenance {
    /// Materialize the full 4-column row for `row`: the base table's 3 columns
    /// plus the first recorded proof that the displaced value equals the value
    /// that displaced it.
    fn expand(&self, row: RowId) -> [Value; 4] {
        let [v1, v2, v3] = self.base.expand(row);
        let (child, parent) = self.displaced[row.index()];
        debug_assert_eq!(child, v1);
        // We only ever hand back the *first* proof we learned (see the note on
        // `context`).
        let proof = *self.context[&(child, parent)].get_index(0).unwrap();
        [v1, v2, v3, proof]
    }

    /// Evaluate `constraint` against the expanded 4-column row.
    fn eval(&self, constraint: &Constraint, row: RowId) -> bool {
        eval_constraint(&self.expand(row), constraint)
    }

    /// Return the timestamp when `l` and `r` became equal.
    ///
    /// This is used to filter possible paths in the proof graph. The algorithm
    /// we use here is a variant of the classic algorithm in "Proof-Producing
    /// Congruence Closure" by Nieuwenhuis and Oliveras for reconstructing a
    /// proof.
    fn timestamp_when_equal(&self, l: Value, r: Value) -> Option<u32> {
        if l == r {
            return Some(0);
        }
        // Maps each value on the path toward `canon` to the timestamp of the
        // step that reached it, in path order (IndexMap preserves insertion
        // order).
        let mut l_proofs = IndexMap::new();
        let mut r_proofs = IndexMap::new();
        if self.base.uf.find_naive(l) != self.base.uf.find_naive(r) {
            // The two values aren't equal.
            return None;
        }
        let canon = self.base.uf.find_naive(l);

        // General case: collect individual equality proofs that point from `l`
        // (sim. `r`) and move towards canon. We stop early and don't always go
        // to `canon`. To see why consider the following sequences of unions.
        // For simplicity, we'll assume that the "leader" (or new canonical id)
        // is always the second argument to `union`.
        // * left:  A: union(0,2), B: union(2,4), C: union(4,6)
        // * right: D: union(1,3), E: union(3,5), F: union(5,4), C: union(4,6)
        // Where `l` `r` are 0 and 1, and their canonical value is `6`.
        // A simple approach here would be to simply glue the proofs that `l=6`
        // and `r=6` together, something like:
        //
        //    [A;B;C;rev(C);rev(F);rev(E);rev(D)]
        //
        // The code below avoids the redundant common suffix (i.e. `C;rev(C)`)
        // and just uses A,B,D,E, and F.
        //
        // In addition to allowing us to generate smaller proofs, this sort of
        // algorithm also ensures that we are returning the first proof of `l =
        // r` that we learned about, which is important for avoiding cycles when
        // reconstructing a proof.

        // General case: create a proof  that l = canon, then compose it with
        // the proof that r = canon, reversed.
        for (mut cur, steps) in [(l, &mut l_proofs), (r, &mut r_proofs)] {
            while cur != canon {
                // Find where cur became non-canonical.
                let row = *self.base.lookup_table.get(&cur).unwrap();
                let (_, ts) = self.base.displaced[row.index()];
                let (child, parent) = self.displaced[row.index()];
                debug_assert_eq!(child, cur);
                steps.insert(parent, ts);
                cur = parent;
            }
        }

        // Index (into l_proofs / r_proofs) of the first value the two paths
        // have in common; at most one side needs to walk past that point.
        let mut l_end = None;
        let mut r_start = None;

        if let Some(i) = r_proofs.get_index_of(&l) {
            // `l` itself lies on `r`'s path: no steps from `l`'s side needed.
            r_start = Some(i);
        } else {
            for (i, (next_id, _)) in l_proofs.iter().enumerate() {
                if *next_id == r {
                    // `r` lies on `l`'s path.
                    l_end = Some(i);
                    break;
                }
                if let Some(j) = r_proofs.get_index_of(next_id) {
                    // Found the meeting point of the two paths.
                    l_end = Some(i);
                    r_start = Some(j);
                    break;
                }
            }
        }
        // The answer is the latest timestamp among the steps actually used:
        // `l` and `r` became equal exactly when the last of those unions ran.
        match (l_end, r_start) {
            (None, Some(start)) => r_proofs.as_slice()[..=start]
                .iter()
                .map(|(_, ts)| ts.rep())
                .max(),
            (Some(end), None) => l_proofs.as_slice()[..=end]
                .iter()
                .map(|(_, ts)| ts.rep())
                .max(),
            (Some(end), Some(start)) => l_proofs.as_slice()[..=end]
                .iter()
                .map(|(_, ts)| ts.rep())
                .chain(r_proofs.as_slice()[..=start].iter().map(|(_, ts)| ts.rep()))
                .max(),
            (None, None) => {
                panic!(
                    "did not find common id, despite the values being equivalent {l:?} / {r:?}, l_proofs={l_proofs:?}, r_proofs={r_proofs:?}"
                )
            }
        }
    }

    /// A simple proof generation algorithm that searches for the shortest path
    /// in the proof graph between `l` and `r`.
    ///
    /// The path in the graph is restricted to the timestamps at or before `l`
    /// and `r` first became equal. This is to avoid cycles during proof
    /// reconstruction.
    pub fn get_proof(&self, l: Value, r: Value) -> Option<Vec<ProofStep>> {
        let ts = self.timestamp_when_equal(l, r)?;
        let start = self.node_map[&l];
        let goal = self.node_map[&r];
        let costs = dijkstra(&self.proof_graph, self.node_map[&l], Some(goal), |edge| {
            if edge.weight().ts.rep() > ts {
                // avoid edges added after the two became equal.
                // NOTE(review): infinite-cost edges are discouraged rather than
                // removed from the search — confirm this cannot select a
                // too-late edge when no finite-cost path exists.
                f64::INFINITY
            } else {
                1.0f64
            }
        });
        // Reconstruct the proof steps from the cost map returned from petgraph.
        // Start at the end and then work backwards along the shortest path.
        let mut path = Vec::new();
        let mut cur = goal;
        while cur != start {
            // Among incoming edges whose source has a known cost, pick the
            // predecessor with the smallest cost: that edge is on a shortest
            // path to `cur`.
            let (_, step, next) = self
                .proof_graph
                .edges_directed(cur, Direction::Incoming)
                .filter_map(|edge| {
                    let source = edge.source();
                    let cost = costs.get(&source)?;
                    let step = ProofStep {
                        lhs: *self.proof_graph.node_weight(source).unwrap(),
                        rhs: *self.proof_graph.node_weight(edge.target()).unwrap(),
                        reason: edge.weight().reason,
                    };
                    Some((cost, step, source))
                })
                .fold(None, |acc, cur| {
                    // Manually implement 'min' because we are using f64 for costs.
                    // We should probably switch these edge costs over to NotNan
                    // or a custom type.
                    let Some(acc) = acc else {
                        return Some(cur);
                    };
                    Some(if acc.0 > cur.0 { cur } else { acc })
                })
                .unwrap();
            path.push(step);
            cur = next;
        }
        // We built the path goal-to-start; flip it to run `l` -> `r`.
        path.reverse();
        Some(path)
    }
    /// Look up the proof-graph node for `val`, creating it on first use.
    fn get_or_create_node(&mut self, val: Value) -> NodeIndex {
        *self
            .node_map
            .entry(val)
            .or_insert_with(|| self.proof_graph.add_node(val))
    }

    /// Insert a staged `(a, b, ts, reason)` row: union `a` and `b` in the base
    /// table and, if that actually merged two classes, record provenance and
    /// add forward/backward edges to the proof graph. If `a` and `b` were
    /// already equal, only the extra proof is remembered.
    fn insert_impl(&mut self, row: &[Value]) {
        let [a, b, ts, reason] = row else {
            panic!("attempt to insert a row with the wrong arity ({row:?})");
        };
        match self.base.insert_impl(&[*a, *b, *ts]) {
            Some((parent, child)) => {
                // Keep `self.displaced` index-aligned with `base.displaced`.
                self.displaced.push((child, parent));
                self.context
                    .entry((child, parent))
                    .or_default()
                    .insert(*reason);
                self.base.changed = true;

                let a_node = self.get_or_create_node(*a);
                let b_node = self.get_or_create_node(*b);
                self.proof_graph.add_edge(
                    a_node,
                    b_node,
                    ProofEdge {
                        reason: ProofReason::Forward(*reason),
                        ts: *ts,
                    },
                );
                self.proof_graph.add_edge(
                    b_node,
                    a_node,
                    ProofEdge {
                        reason: ProofReason::Backward(*reason),
                        ts: *ts,
                    },
                );
            }
            None => {
                self.context.entry((*a, *b)).or_default().insert(*reason);
                // We don't register a change, even if we learned a new proof.
                // We may want to change this behavior in order to search for
                // smaller proofs.
            }
        }
    }
}
767
768impl Table for DisplacedTableWithProvenance {
769    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
770        subset.retain(|row| self.eval(c, row));
771        subset
772    }
773    fn scan_generic_bounded(
774        &self,
775        subset: SubsetRef,
776        start: Offset,
777        n: usize,
778        cs: &[Constraint],
779        mut f: impl FnMut(RowId, &[Value]),
780    ) -> Option<Offset>
781    where
782        Self: Sized,
783    {
784        if cs.is_empty() {
785            let start = start.index();
786            subset
787                .iter_bounded(start, start + n, |row| {
788                    f(row, self.expand(row).as_slice());
789                })
790                .map(Offset::from_usize)
791        } else {
792            let start = start.index();
793            subset
794                .iter_bounded(start, start + n, |row| {
795                    if cs.iter().all(|c| self.eval(c, row)) {
796                        f(row, self.expand(row).as_slice());
797                    }
798                })
799                .map(Offset::from_usize)
800        }
801    }
802
803    fn spec(&self) -> TableSpec {
804        TableSpec {
805            n_vals: 3,
806            ..self.base.spec()
807        }
808    }
809
810    fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
811        while let Some(rowbuf) = self.buffered_writes.pop() {
812            for row in rowbuf.iter() {
813                self.insert_impl(row);
814            }
815        }
816
817        self.base.merge(exec_state)
818    }
819
820    fn get_row(&self, key: &[Value]) -> Option<Row> {
821        let mut inner = self.base.get_row(key)?;
822        let (child, parent) = self.displaced[inner.id.index()];
823        debug_assert_eq!(child, inner.vals[0]);
824        let proof = *self.context[&(child, parent)].get_index(0).unwrap();
825        inner.vals.push(proof);
826        Some(inner)
827    }
828
829    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
830        if col == ColumnId::new(3) {
831            let row = *self.base.lookup_table.get(&key[0])?;
832            Some(self.expand(row)[3])
833        } else {
834            self.base.get_row_column(key, col)
835        }
836    }
837
838    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
839        Box::new(UfBuffer {
840            to_insert: ManuallyDrop::new(RowBuffer::new(4)),
841            buffered_writes: Arc::downgrade(&self.buffered_writes),
842        })
843    }
844
845    // Many of these methods just delgate to `base`:
846
847    fn dyn_clone(&self) -> Box<dyn Table> {
848        Box::new(self.clone())
849    }
850    fn as_any(&self) -> &dyn Any {
851        self
852    }
853    fn clear(&mut self) {
854        self.base.clear()
855    }
856    fn all(&self) -> Subset {
857        self.base.all()
858    }
859    fn len(&self) -> usize {
860        self.base.len()
861    }
862    fn updates_since(&self, offset: Offset) -> Subset {
863        self.base.updates_since(offset)
864    }
865    fn version(&self) -> TableVersion {
866        self.base.version()
867    }
868    fn fast_subset(&self, c: &Constraint) -> Option<Subset> {
869        self.base.fast_subset(c)
870    }
871}
872
873fn eval_constraint<const N: usize>(vals: &[Value; N], constraint: &Constraint) -> bool {
874    match constraint {
875        Constraint::Eq { l_col, r_col } => vals[l_col.index()] == vals[r_col.index()],
876        Constraint::EqConst { col, val } => vals[col.index()] == *val,
877        Constraint::LtConst { col, val } => vals[col.index()] < *val,
878        Constraint::GtConst { col, val } => vals[col.index()] > *val,
879        Constraint::LeConst { col, val } => vals[col.index()] <= *val,
880        Constraint::GeConst { col, val } => vals[col.index()] >= *val,
881    }
882}