egglog_core_relations/uf/
mod.rs

1//! A table implementation backed by a union-find.
2
3use std::{
4    any::Any,
5    mem::{self, ManuallyDrop},
6    sync::{Arc, Weak},
7};
8
9use crate::numeric_id::{DenseIdMap, NumericId};
10use crossbeam_queue::SegQueue;
11
12use crate::{
13    TableChange, TaggedRowBuffer,
14    action::ExecutionState,
15    common::{HashMap, Value},
16    offsets::{OffsetRange, RowId, Subset, SubsetRef},
17    pool::with_pool_set,
18    row_buffer::RowBuffer,
19    table_spec::{
20        ColumnId, Constraint, Generation, MutationBuffer, Offset, Rebuilder, Row, Table, TableSpec,
21        TableVersion, WrappedTableRef,
22    },
23};
24
// Unit tests for this module live in a separate `tests` module file.
#[cfg(test)]
mod tests;

/// Union-find over [`Value`]s; the structure that backs [`DisplacedTable`].
type UnionFind = crate::union_find::UnionFind<Value>;
29
30/// A special table backed by a union-find used to efficiently implement
31/// egglog-style canonicaliztion.
32///
33/// To canonicalize columns, we need to efficiently discover values that have
34/// ceased to be canonical. To do that we keep a table of _displaced_ values:
35///
36/// This table has three columns:
37/// 1. (the only key): a value that is _no longer canonical_ in the equivalence relation.
38/// 2. The canonical value of the equivalence class.
39/// 3. The timestamp at which the key stopped being canonical.
40///
41/// We do not store the second value explicitly: instead, we compute it
42/// on-the-fly using a union-find data-structure.
43///
44/// This is related to the 'Leader' encoding in some versions of egglog:
45/// Displaced is a version of Leader that _only_ stores ids when they cease to
46/// be canonical. Rows are also "automatically updated" with the current leader,
47/// rather than requiring the DB to replay history or canonicalize redundant
48/// values in the table.
49///
50/// To union new ids `l`, and `r`, stage an update `Displaced(l, r, ts)` where
51/// `ts` is the current timestamp. Note that all tie-breaks and other encoding
52/// decisions are made internally, so there may not literally be a row added
53/// with this value.
54pub struct DisplacedTable {
55    uf: UnionFind,
56    displaced: Vec<(Value, Value)>,
57    changed: bool,
58    lookup_table: HashMap<Value, RowId>,
59    buffered_writes: Arc<SegQueue<RowBuffer>>,
60}
61
62struct Canonicalizer<'a> {
63    cols: Vec<ColumnId>,
64    table: &'a DisplacedTable,
65}
66
/// Canonicalization driven by the union-find of a [`DisplacedTable`].
///
/// Every lookup uses `find_naive`, which takes `&self` on the union-find, so
/// rebuilding never mutates the table.
impl Rebuilder for Canonicalizer<'_> {
    /// Always suggests column 0 of the displaced table.
    // NOTE(review): the meaning of the hint is defined by the `Rebuilder`
    // trait; presumably the key column (displaced values) is the useful one
    // to probe. Confirm against the trait docs.
    fn hint_col(&self) -> Option<ColumnId> {
        Some(ColumnId::new(0))
    }
    /// Returns the current canonical representative of `val`.
    fn rebuild_val(&self, val: Value) -> Value {
        self.table.uf.find_naive(val)
    }
    /// Canonicalizes `self.cols` for rows `start..end` of `buf`.
    ///
    /// Only rows where at least one selected column changed are written to
    /// `out`, tagged with their original row id. Panics if `end` is past the
    /// end of `buf`.
    fn rebuild_buf(
        &self,
        buf: &RowBuffer,
        start: RowId,
        end: RowId,
        out: &mut TaggedRowBuffer,
        _exec_state: &mut ExecutionState,
    ) {
        if start >= end {
            return;
        }
        // This single check is what makes every `get_row_unchecked` below sound.
        assert!(end.index() <= buf.len());
        let mut cur = start;
        // Reusable row-sized scratch space, borrowed from the pool.
        let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
        // SAFETY: `cur` is always in-bounds, guaranteed by the above assertion.
        // Special-case small columns: this gives us a modest speedup on rebuilding-heavy
        // workloads.
        match self.cols.as_slice() {
            [c] => {
                while cur < end {
                    // SAFETY: start <= cur < end <= buf.len() (asserted above).
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let to_canon = row[c.index()];
                    let canon = self.table.uf.find_naive(to_canon);
                    if canon != to_canon {
                        // Copy the row only when it actually needs rewriting.
                        scratch.extend_from_slice(row);
                        scratch[c.index()] = canon;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            [c1, c2] => {
                while cur < end {
                    // SAFETY: start <= cur < end <= buf.len() (asserted above).
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let v1 = row[c1.index()];
                    let v2 = row[c2.index()];
                    let ca1 = self.table.uf.find_naive(v1);
                    let ca2 = self.table.uf.find_naive(v2);
                    if ca1 != v1 || ca2 != v2 {
                        scratch.extend_from_slice(row);
                        scratch[c1.index()] = ca1;
                        scratch[c2.index()] = ca2;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            [c1, c2, c3] => {
                while cur < end {
                    // SAFETY: start <= cur < end <= buf.len() (asserted above).
                    let row = unsafe { buf.get_row_unchecked(cur) };
                    let v1 = row[c1.index()];
                    let v2 = row[c2.index()];
                    let v3 = row[c3.index()];
                    let ca1 = self.table.uf.find_naive(v1);
                    let ca2 = self.table.uf.find_naive(v2);
                    let ca3 = self.table.uf.find_naive(v3);
                    if ca1 != v1 || ca2 != v2 || ca3 != v3 {
                        scratch.extend_from_slice(row);
                        scratch[c1.index()] = ca1;
                        scratch[c2.index()] = ca2;
                        scratch[c3.index()] = ca3;
                        out.add_row(cur, &scratch);
                        scratch.clear();
                    }
                    cur = cur.inc();
                }
            }
            // General case (including zero columns): copy each row into
            // scratch first, canonicalize in place, emit if anything changed.
            cs => {
                while cur < end {
                    // SAFETY: start <= cur < end <= buf.len() (asserted above).
                    scratch.extend_from_slice(unsafe { buf.get_row_unchecked(cur) });
                    let mut changed = false;
                    for c in cs {
                        let to_canon = scratch[c.index()];
                        let canon = self.table.uf.find_naive(to_canon);
                        scratch[c.index()] = canon;
                        changed |= canon != to_canon;
                    }
                    if changed {
                        out.add_row(cur, &scratch);
                    }
                    scratch.clear();
                    cur = cur.inc();
                }
            }
        }
    }
    /// Scans `subset` of `other` into `out` and canonicalizes `self.cols` of
    /// the appended rows in place.
    ///
    /// Appended rows whose selected columns were already canonical are marked
    /// stale in `out`.
    fn rebuild_subset(
        &self,
        other: WrappedTableRef,
        subset: SubsetRef,
        out: &mut TaggedRowBuffer,
        _exec_state: &mut ExecutionState,
    ) {
        // Only rows appended by this call are rewritten.
        let old_len = u32::try_from(out.len()).expect("row buffer sizes should fit in a u32");
        let _next = other.scan_bounded(subset, Offset::new(0), usize::MAX, out);
        // An unbounded scan is expected to finish in one call.
        debug_assert!(_next.is_none());
        for i in old_len..u32::try_from(out.len()).expect("row buffer sizes should fit in a u32") {
            let i = RowId::new(i);
            let (_id, row) = out.get_row_mut(i);
            let mut changed = false;
            for col in &self.cols {
                let to_canon = row[col.index()];
                let canon = self.table.uf.find_naive(to_canon);
                changed |= canon != to_canon;
                row[col.index()] = canon;
            }
            if !changed {
                out.set_stale(i);
            }
        }
    }
    /// Canonicalizes every value of `vals` in place; returns whether any changed.
    fn rebuild_slice(&self, vals: &mut [Value]) -> bool {
        let mut changed = false;
        for val in vals {
            let canon = self.table.uf.find_naive(*val);
            changed |= canon != *val;
            *val = canon;
        }
        changed
    }
}
197
198impl Default for DisplacedTable {
199    fn default() -> Self {
200        Self {
201            uf: UnionFind::default(),
202            displaced: Vec::new(),
203            changed: false,
204            lookup_table: HashMap::default(),
205            buffered_writes: Arc::new(SegQueue::new()),
206        }
207    }
208}
209
210impl Clone for DisplacedTable {
211    fn clone(&self) -> Self {
212        DisplacedTable {
213            uf: self.uf.clone(),
214            displaced: self.displaced.clone(),
215            changed: self.changed,
216            lookup_table: self.lookup_table.clone(),
217            buffered_writes: Default::default(),
218        }
219    }
220}
221
222struct UfBuffer {
223    to_insert: ManuallyDrop<RowBuffer>,
224    buffered_writes: Weak<SegQueue<RowBuffer>>,
225}
226
227impl Drop for UfBuffer {
228    fn drop(&mut self) {
229        let Some(buffered_writes) = self.buffered_writes.upgrade() else {
230            // SAFETY: If we can't write updates, manually drop to_insert
231            unsafe {
232                ManuallyDrop::drop(&mut self.to_insert);
233            }
234            return;
235        };
236        // SAFETY: self.to_insert will not be used again after this point.
237        //
238        // This avoids creating a fresh row buffer via `mem::take` or `mem::swap` and
239        // dropping it immediately.
240        let to_insert = unsafe { ManuallyDrop::take(&mut self.to_insert) };
241        buffered_writes.push(to_insert);
242    }
243}
244
245impl MutationBuffer for UfBuffer {
246    fn stage_insert(&mut self, row: &[Value]) {
247        self.to_insert.add_row(row);
248    }
249    fn stage_remove(&mut self, _: &[Value]) {
250        panic!("attempting to remove data from a DisplacedTable")
251    }
252    fn fresh_handle(&self) -> Box<dyn MutationBuffer> {
253        Box::new(UfBuffer {
254            to_insert: ManuallyDrop::new(RowBuffer::new(self.to_insert.arity())),
255            buffered_writes: self.buffered_writes.clone(),
256        })
257    }
258}
259
260impl Table for DisplacedTable {
261    fn dyn_clone(&self) -> Box<dyn Table> {
262        Box::new(self.clone())
263    }
264    fn as_any(&self) -> &dyn Any {
265        self
266    }
267    fn spec(&self) -> TableSpec {
268        let mut uncacheable_columns = DenseIdMap::default();
269        // The second column of this table is determined dynamically by the union-find.
270        uncacheable_columns.insert(ColumnId::new(1), true);
271        TableSpec {
272            n_keys: 1,
273            n_vals: 2,
274            uncacheable_columns,
275            allows_delete: false,
276        }
277    }
278
279    fn rebuilder<'a>(&'a self, cols: &[ColumnId]) -> Option<Box<dyn Rebuilder + 'a>> {
280        Some(Box::new(Canonicalizer {
281            cols: cols.to_vec(),
282            table: self,
283        }))
284    }
285
286    fn clear(&mut self) {
287        self.uf.reset();
288        self.displaced.clear();
289    }
290
291    fn all(&self) -> Subset {
292        Subset::Dense(OffsetRange::new(
293            RowId::new(0),
294            RowId::from_usize(self.displaced.len()),
295        ))
296    }
297
298    fn len(&self) -> usize {
299        self.displaced.len()
300    }
301
302    fn version(&self) -> TableVersion {
303        TableVersion {
304            major: Generation::new(0),
305            minor: Offset::from_usize(self.displaced.len()),
306        }
307    }
308
309    fn updates_since(&self, offset: Offset) -> Subset {
310        Subset::Dense(OffsetRange::new(
311            RowId::from_usize(offset.index()),
312            RowId::from_usize(self.displaced.len()),
313        ))
314    }
315
316    fn scan_generic_bounded(
317        &self,
318        subset: SubsetRef,
319        start: Offset,
320        n: usize,
321        cs: &[Constraint],
322        mut f: impl FnMut(RowId, &[Value]),
323    ) -> Option<Offset>
324    where
325        Self: Sized,
326    {
327        if cs.is_empty() {
328            let start = start.index();
329            subset
330                .iter_bounded(start, start + n, |row| {
331                    f(row, self.expand(row).as_slice());
332                })
333                .map(Offset::from_usize)
334        } else {
335            let start = start.index();
336            subset
337                .iter_bounded(start, start + n, |row| {
338                    if cs.iter().all(|c| self.eval(c, row)) {
339                        f(row, self.expand(row).as_slice());
340                    }
341                })
342                .map(Offset::from_usize)
343        }
344    }
345
346    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
347        subset.retain(|row| self.eval(c, row));
348        subset
349    }
350
351    fn fast_subset(&self, constraint: &Constraint) -> Option<Subset> {
352        let ts = ColumnId::new(2);
353        match constraint {
354            Constraint::Eq { .. } => None,
355            Constraint::EqConst { col, val } => {
356                if *col == ColumnId::new(1) {
357                    return None;
358                }
359                if *col == ColumnId::new(0) {
360                    return Some(match self.lookup_table.get(val) {
361                        Some(row) => Subset::Dense(OffsetRange::new(
362                            *row,
363                            RowId::from_usize(row.index() + 1),
364                        )),
365                        None => Subset::empty(),
366                    });
367                }
368                match self.timestamp_bounds(*val) {
369                    Ok((start, end)) => Some(Subset::Dense(OffsetRange::new(start, end))),
370                    Err(_) => None,
371                }
372            }
373            Constraint::LtConst { col, val } => {
374                if *col != ts {
375                    return None;
376                }
377                match self.timestamp_bounds(*val) {
378                    Err(bound) | Ok((bound, _)) => {
379                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
380                    }
381                }
382            }
383            Constraint::GtConst { col, val } => {
384                if *col != ts {
385                    return None;
386                }
387
388                match self.timestamp_bounds(*val) {
389                    Err(bound) | Ok((_, bound)) => Some(Subset::Dense(OffsetRange::new(
390                        bound,
391                        RowId::from_usize(self.displaced.len()),
392                    ))),
393                }
394            }
395            Constraint::LeConst { col, val } => {
396                if *col != ts {
397                    return None;
398                }
399
400                match self.timestamp_bounds(*val) {
401                    Err(bound) | Ok((_, bound)) => {
402                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
403                    }
404                }
405            }
406            Constraint::GeConst { col, val } => {
407                if *col != ts {
408                    return None;
409                }
410
411                match self.timestamp_bounds(*val) {
412                    Err(bound) | Ok((bound, _)) => Some(Subset::Dense(OffsetRange::new(
413                        bound,
414                        RowId::from_usize(self.displaced.len()),
415                    ))),
416                }
417            }
418        }
419    }
420
421    fn get_row(&self, key: &[Value]) -> Option<Row> {
422        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
423        let row_id = *self.lookup_table.get(&key[0])?;
424        let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
425        vals.extend_from_slice(self.expand(row_id).as_slice());
426        Some(Row { id: row_id, vals })
427    }
428
429    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
430        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
431        if col == ColumnId::new(1) {
432            Some(self.uf.find_naive(key[0]))
433        } else {
434            let row_id = *self.lookup_table.get(&key[0])?;
435            Some(self.expand(row_id)[col.index()])
436        }
437    }
438
439    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
440        Box::new(UfBuffer {
441            to_insert: ManuallyDrop::new(RowBuffer::new(3)),
442            buffered_writes: Arc::downgrade(&self.buffered_writes),
443        })
444    }
445
446    fn merge(&mut self, _: &mut ExecutionState) -> TableChange {
447        while let Some(rowbuf) = self.buffered_writes.pop() {
448            for row in rowbuf.iter() {
449                self.changed |= self.insert_impl(row).is_some();
450            }
451        }
452        let changed = mem::take(&mut self.changed);
453        // UF table rows can be updated "in place", we count both added and removed as changed in
454        // this case.
455        TableChange {
456            added: changed,
457            removed: changed,
458        }
459    }
460}
461
462impl DisplacedTable {
463    pub fn underlying_uf(&self) -> &UnionFind {
464        &self.uf
465    }
466    fn expand(&self, row: RowId) -> [Value; 3] {
467        let (child, ts) = self.displaced[row.index()];
468        [child, self.uf.find_naive(child), ts]
469    }
470    fn timestamp_bounds(&self, val: Value) -> Result<(RowId, RowId), RowId> {
471        match self.displaced.binary_search_by_key(&val, |(_, ts)| *ts) {
472            Ok(mut off) => {
473                let mut next = off;
474                while off > 0 && self.displaced[off - 1].1 == val {
475                    off -= 1;
476                }
477                while next < self.displaced.len() && self.displaced[next].1 == val {
478                    next += 1;
479                }
480                Ok((RowId::from_usize(off), RowId::from_usize(next)))
481            }
482            Err(off) => Err(RowId::from_usize(off)),
483        }
484    }
485    fn eval(&self, constraint: &Constraint, row: RowId) -> bool {
486        let vals = self.expand(row);
487        eval_constraint(&vals, constraint)
488    }
489    fn insert_impl(&mut self, row: &[Value]) -> Option<(Value, Value)> {
490        assert_eq!(row.len(), 3, "attempt to insert a row with the wrong arity");
491        if self.uf.find(row[0]) == self.uf.find(row[1]) {
492            return None;
493        }
494        let (parent, child) = self.uf.union(row[0], row[1]);
495
496        // Compress paths somewhat, given that we perform naive finds everywhere else.
497        let _ = self.uf.find(parent);
498        let _ = self.uf.find(child);
499        let ts = row[2];
500        if let Some((_, highest)) = self.displaced.last() {
501            assert!(
502                *highest <= ts,
503                "must insert rows with increasing timestamps"
504            );
505        }
506        let next = RowId::from_usize(self.displaced.len());
507        self.displaced.push((child, ts));
508        self.lookup_table.insert(child, next);
509        Some((parent, child))
510    }
511}
512
513fn eval_constraint<const N: usize>(vals: &[Value; N], constraint: &Constraint) -> bool {
514    match constraint {
515        Constraint::Eq { l_col, r_col } => vals[l_col.index()] == vals[r_col.index()],
516        Constraint::EqConst { col, val } => vals[col.index()] == *val,
517        Constraint::LtConst { col, val } => vals[col.index()] < *val,
518        Constraint::GtConst { col, val } => vals[col.index()] > *val,
519        Constraint::LeConst { col, val } => vals[col.index()] <= *val,
520        Constraint::GeConst { col, val } => vals[col.index()] >= *val,
521    }
522}