egglog_core_relations/table/
mod.rs

1//! A generic table implementation supporting sorted writes.
2//!
3//! The primary difference between this table and the `Function` implementation
4//! in egglog is that high level concepts like "timestamp" and "merge function"
5//! are abstracted away from the core functionality of the table.
6
7use std::{
8    any::Any,
9    cmp,
10    hash::Hasher,
11    mem,
12    sync::{
13        Arc, Weak,
14        atomic::{AtomicUsize, Ordering},
15    },
16};
17
18use crate::numeric_id::{DenseIdMap, NumericId};
19use crossbeam_queue::SegQueue;
20use hashbrown::HashTable;
21use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
22use rustc_hash::FxHasher;
23use sharded_hash_table::ShardedHashTable;
24
25use crate::{
26    Pooled, TableChange, TableId,
27    action::ExecutionState,
28    common::{HashMap, ShardData, ShardId, SubsetTracker, Value},
29    hash_index::{ColumnIndex, Index},
30    offsets::{OffsetRange, Offsets, RowId, Subset, SubsetRef},
31    parallel_heuristics::parallelize_table_op,
32    pool::with_pool_set,
33    row_buffer::{ParallelRowBufWriter, RowBuffer},
34    table_spec::{
35        ColumnId, Constraint, Generation, MutationBuffer, Offset, Row, Table, TableSpec,
36        TableVersion,
37    },
38};
39
40mod rebuild;
41mod sharded_hash_table;
42#[cfg(test)]
43mod tests;
44
45// NB: Having this type def lets us switch between 64 and 32 bits of hashcode.
46//
47// We should consider just using u64 everywhere though. Hashbrown doesn't play nicely with 32-bit
48// hashcodes because it uses both the high and low bits of a 64-bit code.
49
50type HashCode = u64;
51
52/// A pointer to a row in the table.
53#[derive(Clone, Debug)]
54pub(crate) struct TableEntry {
55    hashcode: HashCode,
56    row: RowId,
57}
58
59impl TableEntry {
60    fn hashcode(&self) -> u64 {
61        // We keep the cast here to make it easy to switch to HashCode=u32.
62        #[allow(clippy::unnecessary_cast)]
63        {
64            self.hashcode as u64
65        }
66    }
67}
68
69/// The core data for a table.
70///
71/// This type is a thin wrapper around `RowBuffer`. The big difference is that
72/// it keeps track of how many stale rows are present.
73#[derive(Clone)]
74struct Rows {
75    data: RowBuffer,
76    scratch: RowBuffer,
77    stale_rows: usize,
78}
79
80impl Rows {
81    fn new(data: RowBuffer) -> Rows {
82        let arity = data.arity();
83        Rows {
84            data,
85            scratch: RowBuffer::new(arity),
86            stale_rows: 0,
87        }
88    }
89    fn clear(&mut self) {
90        self.data.clear();
91        self.stale_rows = 0;
92    }
93    fn next_row(&self) -> RowId {
94        RowId::from_usize(self.data.len())
95    }
96    fn set_stale(&mut self, row: RowId) {
97        if !self.data.set_stale(row) {
98            self.stale_rows += 1;
99        }
100    }
101
102    fn get_row(&self, row: RowId) -> Option<&[Value]> {
103        let row = self.data.get_row(row);
104        if row[0].is_stale() { None } else { Some(row) }
105    }
106
107    /// A variant of `get_row` without bounds-checking on `row`.
108    unsafe fn get_row_unchecked(&self, row: RowId) -> Option<&[Value]> {
109        let row = unsafe { self.data.get_row_unchecked(row) };
110        if row[0].is_stale() { None } else { Some(row) }
111    }
112
113    fn add_row(&mut self, row: &[Value]) -> RowId {
114        if row[0].is_stale() {
115            self.stale_rows += 1;
116        }
117        self.data.add_row(row)
118    }
119
120    fn remove_stale(&mut self, remap: impl FnMut(&[Value], RowId, RowId)) {
121        self.data.remove_stale(remap);
122        self.stale_rows = 0;
123    }
124}
125
126/// The type of closures that are used to merge values in a [`SortedWritesTable`].
127///
128/// The first argument grants access to database using an [`ExecutionState`], the second argument
129/// is the current value of the tuple. The third argument is the new, or "incoming" value of the
130/// tuple. The fourth argument is a mutable reference to a vector that will be used to store the
131/// output of the merge function _if_ it changes the value of the tuple. If it does not, then the
132/// merge function should return `false`.
133pub type MergeFn =
134    dyn Fn(&mut ExecutionState, &[Value], &[Value], &mut Vec<Value>) -> bool + Send + Sync;
135
/// A table keyed by its first `n_keys` columns, supporting sorted writes.
///
/// Mutations are staged through [`MutationBuffer`]s (see `new_buffer`) and applied by
/// `merge`. Conflicting inserts for the same key are resolved with the table's [`MergeFn`].
pub struct SortedWritesTable {
    /// Major component of the table's version; bumped by `clear`.
    generation: Generation,
    /// Row storage, including stale rows until they are compacted.
    data: Rows,
    /// Maps a key to the (non-stale) row that currently holds it.
    hash: ShardedHashTable<TableEntry>,

    /// Number of leading columns that form the primary key.
    n_keys: usize,
    /// Total number of columns per row.
    n_columns: usize,
    /// Optional column whose values must be non-decreasing across inserts (asserted on insert).
    sort_by: Option<ColumnId>,
    /// For sorted tables: `(sort value, first row with that value)` pairs, in increasing order
    /// of sort value.
    offsets: Vec<(Value, RowId)>,

    /// Staged inserts/removals. Buffers hold only a `Weak` reference to this.
    pending_state: Arc<PendingState>,
    /// Resolves conflicts between an existing row and an incoming row with the same key.
    merge: Arc<MergeFn>,
    /// Columns passed to `rebuild_index` at construction.
    to_rebuild: Vec<ColumnId>,
    /// Column index built over `to_rebuild`.
    rebuild_index: Index<ColumnIndex>,
    // Used to manage incremental rebuilds.
    subset_tracker: SubsetTracker,
}
153
154impl Clone for SortedWritesTable {
155    fn clone(&self) -> SortedWritesTable {
156        SortedWritesTable {
157            generation: self.generation,
158            data: self.data.clone(),
159            hash: self.hash.clone(),
160            n_keys: self.n_keys,
161            n_columns: self.n_columns,
162            sort_by: self.sort_by,
163            offsets: self.offsets.clone(),
164            pending_state: Arc::new(self.pending_state.deep_copy()),
165            merge: self.merge.clone(),
166            to_rebuild: self.to_rebuild.clone(),
167            rebuild_index: Index::new(self.to_rebuild.clone(), ColumnIndex::new()),
168            subset_tracker: Default::default(),
169        }
170    }
171}
172
173/// A variant of [`RowBuffer`] that can handle arity 0.
174///
175/// We use this to handle empty keys, where the deletion API needs to handle "row buffers of empty
176/// rows". The goal here is to keep most of the API RowBuffer-centric and avoid complicating the
177/// code too much: actual code that was optimized to handle arity 0 would look a bit different.
178#[derive(Clone)]
179enum ArbitraryRowBuffer {
180    NonEmpty(RowBuffer),
181    Empty { rows: usize },
182}
183
184impl ArbitraryRowBuffer {
185    fn new(arity: usize) -> ArbitraryRowBuffer {
186        if arity == 0 {
187            ArbitraryRowBuffer::Empty { rows: 0 }
188        } else {
189            ArbitraryRowBuffer::NonEmpty(RowBuffer::new(arity))
190        }
191    }
192
193    fn add_row(&mut self, row: &[Value]) {
194        match self {
195            ArbitraryRowBuffer::NonEmpty(buf) => {
196                buf.add_row(row);
197            }
198            ArbitraryRowBuffer::Empty { rows } => {
199                *rows += 1;
200            }
201        }
202    }
203
204    fn len(&self) -> usize {
205        match self {
206            ArbitraryRowBuffer::NonEmpty(buf) => buf.len(),
207            ArbitraryRowBuffer::Empty { rows } => *rows,
208        }
209    }
210
211    fn for_each(&self, mut f: impl FnMut(&[Value])) {
212        match self {
213            ArbitraryRowBuffer::NonEmpty(buf) => {
214                for row in buf.iter() {
215                    f(row);
216                }
217            }
218            ArbitraryRowBuffer::Empty { rows } => {
219                for _ in 0..*rows {
220                    f(&[]);
221                }
222            }
223        }
224    }
225}
226
/// A [`MutationBuffer`] that stages inserts and removals locally, grouped by shard.
///
/// When dropped, it hands its staged buffers to the table's `PendingState`, if that still
/// exists.
struct Buffer {
    /// Staged insertions, one buffer per shard.
    pending_rows: DenseIdMap<ShardId, RowBuffer>,
    /// Staged removals (keys only), one buffer per shard.
    pending_removals: DenseIdMap<ShardId, ArbitraryRowBuffer>,
    /// Weak handle to the table's pending state; upgraded on drop.
    state: Weak<PendingState>,
    /// Number of columns in a full row.
    n_cols: u32,
    /// Number of leading key columns (used for hashing and for removal keys).
    n_keys: u32,
    /// Shard layout used to route each row or key to its shard.
    shard_data: ShardData,
}
235
236impl MutationBuffer for Buffer {
237    fn stage_insert(&mut self, row: &[Value]) {
238        let (shard, _) = hash_code(self.shard_data, row, self.n_keys as _);
239        self.pending_rows
240            .get_or_insert(shard, || RowBuffer::new(self.n_cols as _))
241            .add_row(row);
242    }
243    fn stage_remove(&mut self, key: &[Value]) {
244        let (shard, _) = hash_code(self.shard_data, key, self.n_keys as _);
245        self.pending_removals
246            .get_or_insert(shard, || ArbitraryRowBuffer::new(self.n_keys as _))
247            .add_row(key);
248    }
249    fn fresh_handle(&self) -> Box<dyn MutationBuffer> {
250        Box::new(Buffer {
251            pending_rows: Default::default(),
252            pending_removals: Default::default(),
253            state: self.state.clone(),
254            n_cols: self.n_cols,
255            n_keys: self.n_keys,
256            shard_data: self.shard_data,
257        })
258    }
259}
260
261impl Drop for Buffer {
262    fn drop(&mut self) {
263        if let Some(state) = self.state.upgrade() {
264            let mut rows = 0;
265            for shard_id in 0..self.pending_rows.n_ids() {
266                let shard = ShardId::from_usize(shard_id);
267                let Some(buf) = self.pending_rows.take(shard) else {
268                    continue;
269                };
270                rows += buf.len();
271                state.pending_rows[shard].push(buf);
272            }
273            state.total_rows.fetch_add(rows, Ordering::Relaxed);
274
275            let mut rows = 0;
276            for shard_id in 0..self.pending_removals.n_ids() {
277                let shard = ShardId::from_usize(shard_id);
278                let Some(buf) = self.pending_removals.take(shard) else {
279                    continue;
280                };
281                rows += buf.len();
282                state.pending_removals[shard].push(buf);
283            }
284            state.total_removals.fetch_add(rows, Ordering::Relaxed);
285        }
286    }
287}
288
289impl Table for SortedWritesTable {
290    fn dyn_clone(&self) -> Box<dyn Table> {
291        Box::new(self.clone())
292    }
293    fn as_any(&self) -> &dyn Any {
294        self
295    }
296    fn clear(&mut self) {
297        self.pending_state.clear();
298        if self.data.data.len() == 0 {
299            return;
300        }
301        self.offsets.clear();
302        self.data.clear();
303        self.hash.clear();
304        self.generation = Generation::from_usize(self.version().major.index() + 1);
305    }
306
307    fn spec(&self) -> TableSpec {
308        TableSpec {
309            n_keys: self.n_keys,
310            n_vals: self.n_columns - self.n_keys,
311            uncacheable_columns: Default::default(),
312            allows_delete: true,
313        }
314    }
315
316    fn apply_rebuild(
317        &mut self,
318        table_id: TableId,
319        table: &crate::WrappedTable,
320        next_ts: Value,
321        exec_state: &mut ExecutionState,
322    ) -> bool {
323        self.do_rebuild(table_id, table, next_ts, exec_state)
324    }
325
326    fn refresh_rows_for_values(&mut self, dirty_ids: &[Value], next_ts: Value) -> bool {
327        SortedWritesTable::refresh_rows_for_values(self, dirty_ids, next_ts)
328    }
329
330    fn version(&self) -> TableVersion {
331        TableVersion {
332            major: self.generation,
333            minor: Offset::from_usize(self.data.next_row().index()),
334        }
335    }
336
337    fn updates_since(&self, offset: Offset) -> Subset {
338        Subset::Dense(OffsetRange::new(
339            RowId::from_usize(offset.index()),
340            self.data.next_row(),
341        ))
342    }
343
344    fn all(&self) -> Subset {
345        Subset::Dense(OffsetRange::new(RowId::new(0), self.data.next_row()))
346    }
347
348    fn has_stale_rows(&self) -> bool {
349        self.data.stale_rows > 0
350    }
351
352    fn len(&self) -> usize {
353        self.data.data.len() - self.data.stale_rows
354    }
355
356    fn scan_generic(&self, subset: SubsetRef, mut f: impl FnMut(RowId, &[Value]))
357    where
358        Self: Sized,
359    {
360        let Some((_low, hi)) = subset.bounds() else {
361            // Empty subset
362            return;
363        };
364        assert!(
365            hi.index() <= self.data.data.len(),
366            "{} vs. {}",
367            hi.index(),
368            self.data.data.len()
369        );
370        if self.data.stale_rows == 0 {
371            // Fast path: no stale rows, skip is_stale check per row.
372            // SAFETY: subsets are sorted, low must be at most hi, and hi is less
373            // than the length of the table.
374            // TODO: provide a safe API for this `get_row_unchecked` usage since we have
375            // checked the full bounds above.
376            subset.offsets(|row| unsafe { f(row, self.data.data.get_row_unchecked(row)) })
377        } else {
378            // SAFETY: same as above.
379            subset.offsets(|row| unsafe {
380                if let Some(vals) = self.data.get_row_unchecked(row) {
381                    f(row, vals)
382                }
383            })
384        }
385    }
386
387    fn scan_generic_bounded(
388        &self,
389        subset: SubsetRef,
390        start: Offset,
391        n: usize,
392        cs: &[Constraint],
393        mut f: impl FnMut(RowId, &[Value]),
394    ) -> Option<Offset>
395    where
396        Self: Sized,
397    {
398        let Some((_low, hi)) = subset.bounds() else {
399            // Empty subset
400            return None;
401        };
402        assert!(
403            hi.index() <= self.data.data.len(),
404            "{} vs. {}",
405            hi.index(),
406            self.data.data.len()
407        );
408        if cs.is_empty() {
409            if self.data.stale_rows == 0 {
410                // Fast path: no stale rows, skip bounds check and is_stale check.
411                // SAFETY: all row IDs are in-bounds.
412                subset
413                    .iter_bounded(start.index(), start.index() + n, |row| {
414                        let entry = unsafe { self.data.data.get_row_unchecked(row) };
415                        f(row, entry);
416                    })
417                    .map(Offset::from_usize)
418            } else {
419                subset
420                    .iter_bounded(start.index(), start.index() + n, |row| {
421                        // SAFETY: all row IDs are in-bounds.
422                        let Some(entry) = (unsafe { self.data.get_row_unchecked(row) }) else {
423                            return;
424                        };
425                        f(row, entry);
426                    })
427                    .map(Offset::from_usize)
428            }
429        } else {
430            subset
431                .iter_bounded(start.index(), start.index() + n, |row| {
432                    // SAFETY: all row IDs are in-bounds.
433                    let Some(entry) = (unsafe { self.get_if_unchecked(cs, row) }) else {
434                        return;
435                    };
436                    f(row, entry);
437                })
438                .map(Offset::from_usize)
439        }
440    }
441
442    fn fast_subset(&self, constraint: &Constraint) -> Option<Subset> {
443        let sort_by = self.sort_by?;
444        match constraint {
445            Constraint::Eq { .. } => None,
446            Constraint::EqConst { col, val } => {
447                if col == &sort_by {
448                    match self.binary_search_sort_val(*val) {
449                        Ok((found, bound)) => Some(Subset::Dense(OffsetRange::new(found, bound))),
450                        Err(_) => Some(Subset::empty()),
451                    }
452                } else {
453                    None
454                }
455            }
456            Constraint::LtConst { col, val } => {
457                if col == &sort_by {
458                    match self.binary_search_sort_val(*val) {
459                        Ok((found, _)) => {
460                            Some(Subset::Dense(OffsetRange::new(RowId::new(0), found)))
461                        }
462                        Err(next) => Some(Subset::Dense(OffsetRange::new(RowId::new(0), next))),
463                    }
464                } else {
465                    None
466                }
467            }
468            Constraint::GtConst { col, val } => {
469                if col == &sort_by {
470                    match self.binary_search_sort_val(*val) {
471                        Ok((_, bound)) => {
472                            Some(Subset::Dense(OffsetRange::new(bound, self.data.next_row())))
473                        }
474                        Err(next) => {
475                            Some(Subset::Dense(OffsetRange::new(next, self.data.next_row())))
476                        }
477                    }
478                } else {
479                    None
480                }
481            }
482            Constraint::LeConst { col, val } => {
483                if col == &sort_by {
484                    match self.binary_search_sort_val(*val) {
485                        Ok((_, bound)) => {
486                            Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
487                        }
488                        Err(next) => Some(Subset::Dense(OffsetRange::new(RowId::new(0), next))),
489                    }
490                } else {
491                    None
492                }
493            }
494            Constraint::GeConst { col, val } => {
495                if col == &sort_by {
496                    match self.binary_search_sort_val(*val) {
497                        Ok((found, _)) => {
498                            Some(Subset::Dense(OffsetRange::new(found, self.data.next_row())))
499                        }
500                        Err(next) => {
501                            Some(Subset::Dense(OffsetRange::new(next, self.data.next_row())))
502                        }
503                    }
504                } else {
505                    None
506                }
507            }
508        }
509    }
510
511    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
512        // NB: we aren't using any of the `fast_subset` tricks here. We may want
513        // to if the higher-level implementations end up using it directly.
514        subset.retain(|row| self.eval(std::slice::from_ref(c), row));
515        subset
516    }
517
518    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
519        let n_shards = self.hash.shard_data().n_shards();
520        Box::new(Buffer {
521            pending_rows: DenseIdMap::with_capacity(n_shards),
522            pending_removals: DenseIdMap::with_capacity(n_shards),
523            state: Arc::downgrade(&self.pending_state),
524            n_keys: u32::try_from(self.n_keys).expect("n_keys should fit in u32"),
525            n_cols: u32::try_from(self.n_columns).expect("n_columns should fit in u32"),
526            shard_data: self.hash.shard_data(),
527        })
528    }
529
530    fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
531        let removed = self.do_delete();
532        let added = self.do_insert(exec_state);
533        self.maybe_rehash();
534        TableChange { removed, added }
535    }
536
537    fn get_row(&self, key: &[Value]) -> Option<Row> {
538        let id = get_entry(key, self.n_keys, &self.hash, |row| {
539            &self.data.get_row(row).unwrap()[0..self.n_keys] == key
540        })?;
541        let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
542        vals.extend_from_slice(self.data.get_row(id).unwrap());
543        Some(Row { id, vals })
544    }
545
546    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
547        let id = get_entry(key, self.n_keys, &self.hash, |row| {
548            &self.data.get_row(row).unwrap()[0..self.n_keys] == key
549        })?;
550        Some(self.data.get_row(id).unwrap()[col.index()])
551    }
552}
553
554impl SortedWritesTable {
555    /// Create a new [`SortedWritesTable`] with the given number of keys,
556    /// columns, and an optional sort column.
557    ///
558    /// The `merge_fn` is used to evaluate conflicts when more than one row is
559    /// inserted with the same primary key. The old and new proposed values are
560    /// passed as the second and third arguments, respectively, with the
561    /// function filling the final argument with the contents of the new row.
562    /// The return value indicates whether or not the contents of the vector
563    /// should be used.
564    ///
565    /// Merge functions can access the database via [`ExecutionState`].
566    pub fn new(
567        n_keys: usize,
568        n_columns: usize,
569        sort_by: Option<ColumnId>,
570        to_rebuild: Vec<ColumnId>,
571        merge_fn: Box<MergeFn>,
572    ) -> Self {
573        let hash = ShardedHashTable::<TableEntry>::default();
574        let shard_data = hash.shard_data();
575        let rebuild_index = Index::new(to_rebuild.clone(), ColumnIndex::new());
576        SortedWritesTable {
577            generation: Generation::new(0),
578            data: Rows::new(RowBuffer::new(n_columns)),
579            hash,
580            n_keys,
581            n_columns,
582            sort_by,
583            offsets: Default::default(),
584            pending_state: Arc::new(PendingState::new(shard_data)),
585            merge: merge_fn.into(),
586            to_rebuild,
587            rebuild_index,
588            subset_tracker: Default::default(),
589        }
590    }
591
    /// Flush all pending removals, in parallel.
    ///
    /// Each hash-table shard is handled by (at most) one rayon task. Shards whose removal queue
    /// is empty are skipped. For every queued key, the matching hash entry is removed and its
    /// row is marked stale. The number of rows newly marked stale is added to
    /// `data.stale_rows`.
    ///
    /// Returns `true` if at least one row was newly marked stale.
    fn parallel_delete(&mut self) -> bool {
        let shard_data = self.hash.shard_data();
        let stale_delta: usize = self
            .hash
            .mut_shards()
            .par_iter_mut()
            .enumerate()
            // Skip shards with nothing queued for removal.
            .filter_map(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                if self.pending_state.pending_removals[shard_id].is_empty() {
                    return None;
                }
                Some((shard_id, shard))
            })
            .map(|(shard_id, shard)| {
                let queue = &self.pending_state.pending_removals[shard_id];
                let mut marked_stale = 0;
                while let Some(buf) = queue.pop() {
                    buf.for_each(|to_remove| {
                        let (actual_shard, hc) = hash_code(shard_data, to_remove, self.n_keys);
                        // Removals were staged under the shard their key hashes to.
                        assert_eq!(actual_shard, shard_id);
                        if let Ok(entry) = shard.find_entry(hc, |entry| {
                            entry.hashcode == (hc as _)
                                && &self.data.get_row(entry.row).unwrap()[0..self.n_keys]
                                    == to_remove
                        }) {
                            let (ent, _) = entry.remove();
                            // SAFETY: The safety requirements of
                            // `set_stale_shared` are that there are no
                            // concurrent accesses to `row`. No other threads
                            // can access this row within this method because
                            // different `shards` partition the space
                            // (guaranteed by the assertion above), and we
                            // launch at most one thread per shard.
                            // Only count rows for which `set_stale_shared`
                            // returned false (presumably: not already stale).
                            marked_stale +=
                                unsafe { !self.data.data.set_stale_shared(ent.row) } as usize;
                        }
                    });
                }
                marked_stale
            })
            .sum();
        // Update the stale count with the total marked stale.
        self.data.stale_rows += stale_delta;
        stale_delta > 0
    }
639    fn serial_delete(&mut self) -> bool {
640        let shard_data = self.hash.shard_data();
641        let mut changed = false;
642        self.hash
643            .mut_shards()
644            .iter_mut()
645            .enumerate()
646            .for_each(|(shard_id, shard)| {
647                let shard_id = ShardId::from_usize(shard_id);
648                let queue = &self.pending_state.pending_removals[shard_id];
649                while let Some(buf) = queue.pop() {
650                    buf.for_each(|to_remove| {
651                        let (actual_shard, hc) = hash_code(shard_data, to_remove, self.n_keys);
652                        assert_eq!(actual_shard, shard_id);
653                        if let Ok(entry) = shard.find_entry(hc, |entry| {
654                            entry.hashcode == (hc as _)
655                                && &self.data.get_row(entry.row).unwrap()[0..self.n_keys]
656                                    == to_remove
657                        }) {
658                            let (ent, _) = entry.remove();
659                            self.data.set_stale(ent.row);
660                            changed = true;
661                        }
662                    })
663                }
664            });
665        changed
666    }
667
668    fn do_delete(&mut self) -> bool {
669        let total = self.pending_state.total_removals.swap(0, Ordering::Relaxed);
670
671        if parallelize_table_op(total) {
672            self.parallel_delete()
673        } else {
674            self.serial_delete()
675        }
676    }
677
678    fn do_insert(&mut self, exec_state: &mut ExecutionState) -> bool {
679        let total = self.pending_state.total_rows.swap(0, Ordering::Relaxed);
680        self.data.data.reserve(total);
681        if parallelize_table_op(total) {
682            if let Some(col) = self.sort_by {
683                self.parallel_insert(
684                    exec_state,
685                    SortChecker {
686                        col,
687                        current: None,
688                        baseline: self.offsets.last().map(|(v, _)| *v),
689                    },
690                )
691            } else {
692                self.parallel_insert(exec_state, ())
693            }
694        } else {
695            self.serial_insert(exec_state)
696        }
697    }
698
    /// Drains every pending insert queue on the current thread.
    ///
    /// For each non-stale staged row, the table is probed for a row with the same key:
    /// * If one exists, the merge function is run on (current, incoming). When it reports a
    ///   change, the merged row is appended, the old row is marked stale and the hash entry is
    ///   repointed at the new row.
    /// * Otherwise the row is appended and a new hash entry is inserted.
    ///
    /// When `sort_by` is set, each appended row's sort value is asserted to be at least the
    /// largest value in `offsets`. A strictly larger value records `(value, new row)` as the
    /// start of a new run in `offsets`.
    ///
    /// Returns `true` if any row was appended.
    fn serial_insert(&mut self, exec_state: &mut ExecutionState) -> bool {
        let mut changed = false;
        let n_keys = self.n_keys;
        // Reused output buffer for the merge function; cleared after every merge call.
        let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
        for (_outer_shard, queue) in self.pending_state.pending_rows.iter() {
            if let Some(sort_by) = self.sort_by {
                while let Some(buf) = queue.pop() {
                    for query in buf.non_stale() {
                        let key = &query[0..n_keys];
                        let entry = get_entry_mut(query, n_keys, &mut self.hash, |row| {
                            let Some(row) = self.data.get_row(row) else {
                                return false;
                            };
                            &row[0..n_keys] == key
                        });

                        if let Some(row) = entry {
                            // First case: overwriting an existing value. Apply merge
                            // function. Insert new row and update hash table if merge
                            // changes anything.
                            let cur = self
                                .data
                                .get_row(*row)
                                .expect("table should not point to stale entry");
                            if (self.merge)(exec_state, cur, query, &mut scratch) {
                                // NOTE(review): the sort value is taken from the incoming
                                // row (`query`), but the row appended is the merge output
                                // (`scratch`). This assumes the merge function preserves the
                                // incoming sort column. Confirm.
                                let sort_val = query[sort_by.index()];
                                let new = self.data.add_row(&scratch);
                                if let Some(largest) = self.offsets.last().map(|(v, _)| *v) {
                                    assert!(
                                        sort_val >= largest,
                                        "inserting row that violates sort order ({sort_val:?} vs. {largest:?})"
                                    );
                                    // Only a strictly larger value starts a new run.
                                    if sort_val > largest {
                                        self.offsets.push((sort_val, new));
                                    }
                                } else {
                                    self.offsets.push((sort_val, new));
                                }
                                // Retire the old row and repoint the hash entry.
                                self.data.set_stale(*row);
                                *row = new;
                                changed = true;
                            }
                            scratch.clear();
                        } else {
                            let sort_val = query[sort_by.index()];
                            // New value: update invariants.
                            let new = self.data.add_row(query);
                            if let Some(largest) = self.offsets.last().map(|(v, _)| *v) {
                                assert!(
                                    sort_val >= largest,
                                    "inserting row that violates sort order {sort_val:?} vs. {largest:?}"
                                );
                                if sort_val > largest {
                                    self.offsets.push((sort_val, new));
                                }
                            } else {
                                self.offsets.push((sort_val, new));
                            }
                            let (shard, hc) = hash_code(self.hash.shard_data(), query, self.n_keys);
                            // Rows were staged under the shard their key hashes to.
                            debug_assert_eq!(shard, _outer_shard);
                            self.hash.mut_shards()[shard.index()].insert_unique(
                                hc as _,
                                TableEntry {
                                    hashcode: hc as _,
                                    row: new,
                                },
                                TableEntry::hashcode,
                            );
                            changed = true;
                        }
                    }
                }
            } else {
                // Simplified variant without the sorting constraint.
                while let Some(buf) = queue.pop() {
                    for query in buf.non_stale() {
                        let key = &query[0..n_keys];
                        let entry = get_entry_mut(query, n_keys, &mut self.hash, |row| {
                            let Some(row) = self.data.get_row(row) else {
                                return false;
                            };
                            &row[0..n_keys] == key
                        });

                        if let Some(row) = entry {
                            let cur = self
                                .data
                                .get_row(*row)
                                .expect("table should not point to stale entry");
                            if (self.merge)(exec_state, cur, query, &mut scratch) {
                                let new = self.data.add_row(&scratch);
                                self.data.set_stale(*row);
                                *row = new;
                                changed = true;
                            }
                            scratch.clear();
                        } else {
                            // New value: update invariants.
                            let new = self.data.add_row(query);
                            let (shard, hc) = hash_code(self.hash.shard_data(), query, self.n_keys);
                            debug_assert_eq!(shard, _outer_shard);
                            self.hash.mut_shards()[shard.index()].insert_unique(
                                hc as _,
                                TableEntry {
                                    hashcode: hc as _,
                                    row: new,
                                },
                                TableEntry::hashcode,
                            );
                            changed = true;
                        }
                    }
                }
            };
        }
        changed
    }
816
817    fn parallel_insert<C: OrderingChecker>(
818        &mut self,
819        exec_state: &ExecutionState,
820        checker: C,
821    ) -> bool {
822        const BATCH_SIZE: usize = 1 << 18;
823        // Parallel insert uses one giant parallel foreach. We have updates
824        // pre-sharded, and one logical thread can process updates for each
825        // shard independently. Updates happen in three phases, which comments
826        // describe below.
827        let shard_data = self.hash.shard_data();
828        let n_keys = self.n_keys;
829        let n_cols = self.n_columns;
830        let next_offset = RowId::from_usize(self.data.data.len());
831        let row_writer = self.data.data.parallel_writer();
832        let pending_adds = self
833            .hash
834            .mut_shards()
835            .par_iter_mut()
836            .enumerate()
837            .map(|(shard_id, shard)| {
838                let shard_id = ShardId::from_usize(shard_id);
839                let mut checker = checker.clone();
840                let mut exec_state = exec_state.clone();
841                let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
842                let queue = &self.pending_state.pending_rows[shard_id];
843                let mut marked_stale = 0usize;
844                let mut staged = StagedOutputs::new(n_keys, n_cols, BATCH_SIZE);
845                let mut changed = false;
846                // The core flush loop: We call once `staged` reaches `BATCH_SIZE` or
847                // when we're done.
848                macro_rules! flush_staged_outputs {
849                    () => {{
850                        // Phase 2: Write the staged rows to the row writer. This only
851                        // works due to the `ParallelRowBufWriter` machinery.
852                        let (start_row, stale) = staged.write_output(&row_writer);
853                        marked_stale += stale;
854                        // Phase 3: With the values buffered in the row buffer, we can
855                        // write them back to the shard, pointed to the correct rows.
856
857                        // In the serial implementation, we do phases 2 and 3 inline with
858                        // processing the incoming mutation, but separating them out
859                        // this way allows us to do a single write to the shared row
860                        // buffer, rather than one per row, which would cause
861                        // contention.
862                        let mut cur_row = start_row;
863                        let read_handle = row_writer.read_handle();
864                        for row in staged.rows() {
865                            if row.first().map(Value::is_stale).unwrap_or(false) {
866                                cur_row = cur_row.inc();
867                                continue;
868                            }
869                            use hashbrown::hash_table::Entry;
870                            checker.check_local(row);
871                            changed = true;
872                            let key = &row[0..n_keys];
873                            let (_actual_shard, hc) = hash_code(shard_data, row, n_keys);
874                            #[cfg(any(debug_assertions, test))]
875                            {
876                                unsafe {
877                                    // read the value we wrote at this row and
878                                    // check that it matches.
879                                    assert_eq!(read_handle.get_row_unchecked(cur_row), row);
880                                }
881                            }
882                            debug_assert_eq!(_actual_shard, shard_id);
883                            match shard.entry(
884                                hc,
885                                // SAFETY: `ent` must point to a valid row
886                                |ent| unsafe {
887                                    ent.hashcode == hc as HashCode
888                                        && &read_handle.get_row_unchecked(ent.row)[0..n_keys] == key
889                                },
890                                TableEntry::hashcode,
891                            ) {
892                                Entry::Occupied(mut occ) => {
893                                    // SAFETY: `occ` must point to a valid row: we only insert valid rows
894                                    // into the map.
895                                    let cur = unsafe { read_handle.get_row_unchecked(occ.get().row) };
896
897                                    // SAFETY: The safety requirements of
898                                    // `set_stale_shared` are that there are no
899                                    // concurrent accesses to `row`. We have
900                                    // exclusive access to any row whose hash matches this
901                                    // shard.
902                                    if (self.merge)(&mut exec_state, cur, row, &mut scratch) {
903                                        unsafe {
904                                            let _was_stale = read_handle.set_stale_shared(occ.get().row);
905                                            debug_assert!(!_was_stale);
906                                        }
907                                        occ.get_mut().row = cur_row;
908                                        changed = true;
909                                    } else {
910                                        // Mark the new row as stale: we didn't end up needing it.
911                                        unsafe {
912                                            let _was_stale = read_handle.set_stale_shared(cur_row);
913                                            debug_assert!(!_was_stale);
914                                        }
915                                    }
916                                    marked_stale += 1;
917                                    scratch.clear();
918                                }
919                                Entry::Vacant(v) => {
920                                    changed = true;
921                                    v.insert(TableEntry {
922                                        hashcode: hc as HashCode,
923                                        row: cur_row,
924                                    });
925                                }
926                            }
927
928                            cur_row = cur_row.inc();
929                        }
930                        staged.clear();
931                    }};
932                }
933                // Phase 1: process all incoming updates:
934                // * Add new values to `staged`
935                // * Removing entries in `shard` and mark them as stale in
936                // `data` if they will be overwritten.
937                while let Some(buf) = queue.pop() {
938                    // We create a read_handle once per batch to avoid blocking
939                    // too many threads if someone needs to resize the row
940                    // writer.
941                    for row in buf.non_stale() {
942                        staged.insert(row, |cur, new, out| {
943                            (self.merge)(&mut exec_state, cur, new, out)
944                        });
945                        if staged.len() >= BATCH_SIZE {
946                            flush_staged_outputs!();
947                        }
948                    }
949                }
950                flush_staged_outputs!();
951                (checker, marked_stale, changed)
952            })
953            .collect_vec_list();
954        self.data.data = row_writer.finish();
955        // Now we just need to reset our invariants.
956
957        // Confirm none of the writes violated sort order and update the
958        // `offsets` vector.
959        let checker = C::check_global(pending_adds.iter().flatten().map(|(checker, _, _)| checker));
960        checker.update_offsets(next_offset, &mut self.offsets);
961
962        // Update the staleness counters.
963        self.data.stale_rows += pending_adds
964            .iter()
965            .flatten()
966            .map(|(_, stale, _)| *stale)
967            .sum::<usize>();
968
969        // Register any changes.
970        pending_adds
971            .iter()
972            .flatten()
973            .any(|(_, _, changed)| *changed)
974    }
975
976    fn binary_search_sort_val(&self, val: Value) -> Result<(RowId, RowId), RowId> {
977        debug_assert!(
978            self.offsets.windows(2).all(|x| x[0].1 < x[1].1),
979            "{:?}",
980            self.offsets
981        );
982
983        debug_assert!(
984            self.offsets.windows(2).all(|x| x[0].0 < x[1].0),
985            "{:?}",
986            self.offsets
987        );
988        match self.offsets.binary_search_by_key(&val, |(v, _)| *v) {
989            Ok(got) => Ok((
990                self.offsets[got].1,
991                self.offsets
992                    .get(got + 1)
993                    .map(|(_, r)| *r)
994                    .unwrap_or(self.data.next_row()),
995            )),
996            Err(next) => Err(self
997                .offsets
998                .get(next)
999                .map(|(_, id)| *id)
1000                .unwrap_or(self.data.next_row())),
1001        }
1002    }
1003    fn eval(&self, cs: &[Constraint], row: RowId) -> bool {
1004        self.get_if(cs, row).is_some()
1005    }
1006
1007    fn eval_constraints(cs: &[Constraint], row: &[Value]) -> bool {
1008        cs.iter().all(|constraint| match constraint {
1009            Constraint::Eq { l_col, r_col } => row[l_col.index()] == row[r_col.index()],
1010            Constraint::EqConst { col, val } => row[col.index()] == *val,
1011            Constraint::LtConst { col, val } => row[col.index()] < *val,
1012            Constraint::GtConst { col, val } => row[col.index()] > *val,
1013            Constraint::LeConst { col, val } => row[col.index()] <= *val,
1014            Constraint::GeConst { col, val } => row[col.index()] >= *val,
1015        })
1016    }
1017
1018    unsafe fn get_if_unchecked(&self, cs: &[Constraint], row: RowId) -> Option<&[Value]> {
1019        let row = unsafe { self.data.data.get_row_unchecked(row) };
1020        if Self::eval_constraints(cs, row) {
1021            Some(row)
1022        } else {
1023            None
1024        }
1025    }
1026
1027    fn get_if(&self, cs: &[Constraint], row: RowId) -> Option<&[Value]> {
1028        let row = self.data.get_row(row)?;
1029        if Self::eval_constraints(cs, row) {
1030            Some(row)
1031        } else {
1032            None
1033        }
1034    }
1035
1036    fn maybe_rehash(&mut self) {
1037        if self.data.stale_rows <= cmp::max(16, self.data.data.len() / 2) {
1038            return;
1039        }
1040
1041        if parallelize_table_op(self.data.data.len()) {
1042            self.parallel_rehash();
1043        } else {
1044            self.rehash();
1045        }
1046    }
    /// Compacts the table in parallel, dropping stale rows.
    ///
    /// Only sorted tables (`self.sort_by` is `Some`) take the parallel path; unsorted tables
    /// fall back to the serial `rehash`.
    ///
    /// For sorted tables this:
    /// * bumps the generation;
    /// * computes, for every sort value in `self.offsets`, a per-shard histogram of live rows;
    /// * turns the histograms into start offsets;
    /// * copies each live row into `self.data.scratch`, grouped by sort value and, within
    ///   a sort value, into one contiguous chunk per shard;
    /// * repoints the hash entries at the new rows;
    /// * rebuilds `self.offsets` and swaps `scratch` in as the new `data`, resetting
    ///   `stale_rows` to 0.
    ///
    /// Panics if `self.offsets` is empty.
    fn parallel_rehash(&mut self) {
        use rayon::prelude::*;
        // Parallel rehashes go "hash-first" rather than "rows-first".
        //
        // We iterate over each shard and then write out new contents to a fresh row buffer, in parallel.
        let Some(sort_by) = self.sort_by else {
            // Just do a serial rehash for now. We currently do not have a use-case for parallel
            // compaction of unsorted tables.
            //
            // Implementing parallel compaction for an unsorted table is much easier: each shard
            // can write to a contiguous chunk of the `scratch` buffer, with the offsets being
            // pre-chunked based on the size of each shard.
            self.rehash();
            return;
        };
        self.generation = self.generation.inc();
        assert!(!self.offsets.is_empty());
        // Per-sort-value statistics: the sort value, its number of live rows, and the
        // number of those rows per shard (later rewritten in place into start offsets).
        struct TimestampStats {
            value: Value,
            count: usize,
            histogram: Pooled<DenseIdMap<ShardId, usize>>,
        }
        impl Default for TimestampStats {
            fn default() -> TimestampStats {
                TimestampStats {
                    value: Value::stale(),
                    count: 0,
                    histogram: with_pool_set(|ps| ps.get()),
                }
            }
        }
        let mut results = Vec::<TimestampStats>::with_capacity(self.offsets.len());
        results.resize_with(self.offsets.len() - 1, Default::default);
        // Use a macro rather than a lambda to avoid borrow issues.
        // Counts the live rows in `[$start_row, $end_row)`, bucketed by the shard their keys hash to.
        macro_rules! compute_hist {
            ($start_val: expr, $start_row: expr, $end_row: expr) => {{
                let mut histogram: Pooled<DenseIdMap<ShardId, usize>> =
                    with_pool_set(|ps| ps.get());
                let mut cur_row = $start_row;
                let mut count = 0;
                while cur_row < $end_row {
                    if let Some(row) = self.data.get_row(cur_row) {
                        count += 1;
                        let (shard, _) = hash_code(self.hash.shard_data(), row, self.n_keys);
                        *histogram.get_or_default(shard) += 1;
                    }
                    cur_row = cur_row.inc();
                }
                TimestampStats {
                    value: $start_val,
                    count,
                    histogram,
                }
            }};
        }
        let mut last: TimestampStats = Default::default();
        rayon::join(
            || {
                // This closure handles computing all timestamps but the last one.
                self.offsets
                    .windows(2)
                    .zip(results.iter_mut())
                    .par_bridge()
                    .for_each(|(xs, res)| {
                        let [(start_val, start_row), (_, end_row)] = xs else {
                            unreachable!()
                        };
                        *res = compute_hist!(*start_val, *start_row, *end_row);
                    })
            },
            || {
                // And here we handle the final one.
                let (start_val, start_row) = self.offsets.last().unwrap();
                let end_row = self.data.next_row();
                last = compute_hist!(*start_val, *start_row, end_row);
            },
        );
        results.push(last);
        // Now we need to compute cumulative statistics on the row layouts here.
        // We do this serially as we currently don't have a ton of use for cases with thousands
        // of timestamps or more. There are well-known algorithms for computing these
        // cumulative statistics in parallel, but they aren't currently all that well-suited
        // for rayon.
        let mut prev_count = 0;
        self.offsets.clear();
        for stats in results.iter_mut() {
            // Sort values with no surviving rows get no entry in the rebuilt `offsets`.
            if stats.count == 0 {
                continue;
            }
            self.offsets
                .push((stats.value, RowId::from_usize(prev_count)));
            let mut inner = prev_count;
            for (_, count) in stats.histogram.iter_mut() {
                // Each entry in the histogram now points to the start row for that shard's
                // rows for a given timestamp.
                let tmp = *count;
                *count = inner;
                inner += tmp;
            }
            prev_count += stats.count;
            debug_assert_eq!(inner, prev_count)
        }

        // Now the part with some unsafe code.
        // We will iterate over each shard and use the statistics in `results` to guide where
        // each row will go.
        //
        // This involves doing unsynchronized writes to the table (ptr::copy_nonoverlapping)
        // followed by a set_len. The safety of these operations relies on the fact that:
        // * No one grabs a reference to the interior of `scratch` until these operations have
        //   finished.
        // * `scratch` does not overlap `data`.
        // * The sharding function completely partitions the set of objects in the table: one
        //   shard's writes will never stomp on those of another.

        self.data.scratch.clear();
        self.data.scratch.reserve(prev_count);
        self.hash
            .mut_shards()
            .par_iter_mut()
            .with_max_len(1)
            .enumerate()
            .for_each(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                let scratch_ptr = self.data.scratch.raw_rows();
                // For this shard: the next destination row for each sort value.
                let mut progress =
                    HashMap::<Value /* timestamp */, RowId /* next row */>::default();
                progress.reserve(results.len());
                for stats in &results {
                    let Some(start) = stats.histogram.get(shard_id) else {
                        continue;
                    };
                    progress.insert(stats.value, RowId::from_usize(*start));
                }
                for TableEntry { row: row_id, .. } in shard.iter_mut() {
                    let row = self
                        .data
                        .get_row(*row_id)
                        .expect("shard should not map to a stale value");
                    let val = row[sort_by.index()];
                    let next = progress[&val];
                    // SAFETY: see above longer comment.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            row.as_ptr(),
                            scratch_ptr.add(next.index() * self.n_columns) as *mut Value,
                            self.n_columns,
                        )
                    }
                    // Repoint the hash entry at the row's new location.
                    *row_id = next;
                    progress.insert(val, next.inc());
                }
            });
        // SAFETY: see above longer comment.
        unsafe { self.data.scratch.set_len(prev_count) };
        mem::swap(&mut self.data.data, &mut self.data.scratch);
        self.data.stale_rows = 0;
    }
1205    fn rehash_impl(
1206        sort_by: Option<ColumnId>,
1207        n_keys: usize,
1208        rows: &mut Rows,
1209        offsets: &mut Vec<(Value, RowId)>,
1210        hash: &mut ShardedHashTable<TableEntry>,
1211    ) {
1212        if let Some(sort_by) = sort_by {
1213            offsets.clear();
1214            rows.remove_stale(|row, old, new| {
1215                let stale_entry = get_entry_mut(row, n_keys, hash, |x| x == old)
1216                    .expect("non-stale entry not mapped in hash");
1217                *stale_entry = new;
1218                let sort_col = row[sort_by.index()];
1219                if let Some((max, _)) = offsets.last() {
1220                    if sort_col > *max {
1221                        offsets.push((sort_col, new));
1222                    }
1223                } else {
1224                    offsets.push((sort_col, new));
1225                }
1226            })
1227        } else {
1228            rows.remove_stale(|row, old, new| {
1229                let stale_entry = get_entry_mut(row, n_keys, hash, |x| x == old)
1230                    .expect("non-stale entry not mapped in hash");
1231                *stale_entry = new;
1232            })
1233        }
1234    }
1235
1236    fn rehash(&mut self) {
1237        self.generation = self.generation.inc();
1238        Self::rehash_impl(
1239            self.sort_by,
1240            self.n_keys,
1241            &mut self.data,
1242            &mut self.offsets,
1243            &mut self.hash,
1244        )
1245    }
1246}
1247
1248fn get_entry(
1249    row: &[Value],
1250    n_keys: usize,
1251    table: &ShardedHashTable<TableEntry>,
1252    test: impl Fn(RowId) -> bool,
1253) -> Option<RowId> {
1254    let (shard, hash) = hash_code(table.shard_data(), row, n_keys);
1255    table
1256        .get_shard(shard)
1257        .find(hash, |ent| {
1258            ent.hashcode == hash as HashCode && test(ent.row)
1259        })
1260        .map(|ent| ent.row)
1261}
1262
1263fn get_entry_mut<'a>(
1264    row: &[Value],
1265    n_keys: usize,
1266    table: &'a mut ShardedHashTable<TableEntry>,
1267    test: impl Fn(RowId) -> bool,
1268) -> Option<&'a mut RowId> {
1269    let (shard, hash) = hash_code(table.shard_data(), row, n_keys);
1270    table.mut_shards()[shard.index()]
1271        .find_mut(hash, |ent| {
1272            ent.hashcode == hash as HashCode && test(ent.row)
1273        })
1274        .map(|ent| &mut ent.row)
1275}
1276
1277fn hash_code(shard_data: ShardData, row: &[Value], n_keys: usize) -> (ShardId, u64) {
1278    let mut hasher = FxHasher::default();
1279    for val in &row[0..n_keys] {
1280        hasher.write_usize(val.index());
1281    }
1282    let full_code = hasher.finish();
1283    // We keep this cast here to allow for experimenting with HashCode=u32.
1284    #[allow(clippy::unnecessary_cast)]
1285    (shard_data.shard_id(full_code), full_code as HashCode as u64)
1286}
1287
1288/// A simple struct for packaging up pending mutations to a `SortedWritesTable`.
1289struct PendingState {
1290    pending_rows: DenseIdMap<ShardId, SegQueue<RowBuffer>>,
1291    pending_removals: DenseIdMap<ShardId, SegQueue<ArbitraryRowBuffer>>,
1292    total_removals: AtomicUsize,
1293    total_rows: AtomicUsize,
1294}
1295
1296impl PendingState {
1297    fn new(shard_data: ShardData) -> PendingState {
1298        let n_shards = shard_data.n_shards();
1299        let mut pending_rows = DenseIdMap::with_capacity(n_shards);
1300        let mut pending_removals = DenseIdMap::with_capacity(n_shards);
1301        for i in 0..n_shards {
1302            pending_rows.insert(ShardId::from_usize(i), SegQueue::default());
1303            pending_removals.insert(ShardId::from_usize(i), SegQueue::default());
1304        }
1305
1306        PendingState {
1307            pending_rows,
1308            pending_removals,
1309            total_removals: AtomicUsize::new(0),
1310            total_rows: AtomicUsize::new(0),
1311        }
1312    }
1313    fn clear(&self) {
1314        for (_, queue) in self.pending_rows.iter() {
1315            while queue.pop().is_some() {}
1316        }
1317
1318        for (_, queue) in self.pending_removals.iter() {
1319            while queue.pop().is_some() {}
1320        }
1321    }
1322
1323    /// This is only really used in debugging, but it's annoying enough to write
1324    /// that it may help to have around.
1325    ///
1326    /// We also, however, use it in the clone impl (which should only be called when pending state
1327    /// is empty).
1328    fn deep_copy(&self) -> PendingState {
1329        let mut pending_rows = DenseIdMap::new();
1330        let mut pending_removals = DenseIdMap::new();
1331        fn drain_queue<T>(queue: &SegQueue<T>) -> Vec<T> {
1332            let mut res = Vec::new();
1333            while let Some(x) = queue.pop() {
1334                res.push(x);
1335            }
1336            res
1337        }
1338        for (shard, queue) in self.pending_rows.iter() {
1339            let contents = drain_queue(queue);
1340            let new_queue = SegQueue::default();
1341            for x in contents {
1342                new_queue.push(x.clone());
1343                queue.push(x);
1344            }
1345            pending_rows.insert(shard, new_queue);
1346        }
1347
1348        for (shard, queue) in self.pending_removals.iter() {
1349            let contents = drain_queue(queue);
1350            let new_queue = SegQueue::default();
1351            for x in contents {
1352                new_queue.push(x.clone());
1353                queue.push(x);
1354            }
1355            pending_removals.insert(shard, new_queue);
1356        }
1357
1358        PendingState {
1359            pending_rows,
1360            pending_removals,
1361            total_removals: AtomicUsize::new(self.total_removals.load(Ordering::Acquire)),
1362            total_rows: AtomicUsize::new(self.total_rows.load(Ordering::Acquire)),
1363        }
1364    }
1365}
1366
/// A trait that encapsulates the logic of potentially checking that written
/// columns appear in sorted order.
///
/// For tables sorted by a column, an OrderingChecker asserts that all new rows
/// share the same value in that column, and that this value is greater than or
/// equal to the checker's baseline (presumably the largest sort value already in
/// the table — confirm at construction sites). For unsorted tables, these checks
/// become no-ops.
trait OrderingChecker: Clone + Send + Sync {
    /// Check any invariants locally, updating the state of the checker when
    /// doing so.
    fn check_local(&mut self, row: &[Value]);
    /// Combine the states of multiple checkers, returning a new checker with
    /// all information assimilated. This is the checker that is suitable for
    /// calling `update_offsets` with.
    fn check_global<'a>(checkers: impl Iterator<Item = &'a Self>) -> Self
    where
        Self: 'a;
    /// Update the sorted offset vector with the current state of the checker.
    /// `start` is the first row id written by the batch being recorded.
    fn update_offsets(&self, start: RowId, offsets: &mut Vec<(Value, RowId)>);
}
1387
1388impl OrderingChecker for () {
1389    fn check_local(&mut self, _: &[Value]) {}
1390    fn check_global<'a>(_: impl Iterator<Item = &'a ()>) {}
1391    fn update_offsets(&self, _: RowId, _: &mut Vec<(Value, RowId)>) {}
1392}
1393
1394#[derive(Copy, Clone)]
1395struct SortChecker {
1396    col: ColumnId,
1397    baseline: Option<Value>,
1398    current: Option<Value>,
1399}
1400
1401impl OrderingChecker for SortChecker {
1402    fn check_local(&mut self, row: &[Value]) {
1403        let val = row[self.col.index()];
1404        if let Some(cur) = self.current {
1405            assert_eq!(
1406                cur, val,
1407                "concurrently inserting rows with different sort keys"
1408            );
1409        } else {
1410            self.current = Some(val);
1411            if let Some(baseline) = self.baseline {
1412                assert!(val >= baseline, "inserted row violates sort order");
1413            }
1414        }
1415    }
1416
1417    fn check_global<'a>(mut checkers: impl Iterator<Item = &'a Self>) -> Self {
1418        let Some(start) = checkers.next() else {
1419            return SortChecker {
1420                col: ColumnId::new(!0),
1421                baseline: None,
1422                current: None,
1423            };
1424        };
1425        let mut expected = start.current;
1426        for checker in checkers {
1427            assert_eq!(checker.baseline, start.baseline);
1428            match (&mut expected, checker.current) {
1429                (None, None) => {}
1430                (cur @ None, Some(x)) => {
1431                    *cur = Some(x);
1432                }
1433                (Some(_), None) => {}
1434                (Some(x), Some(y)) => {
1435                    assert_eq!(
1436                        *x, y,
1437                        "concurrently inserting rows with different sort keys"
1438                    );
1439                }
1440            }
1441        }
1442        SortChecker {
1443            col: start.col,
1444            baseline: start.baseline,
1445            current: expected,
1446        }
1447    }
1448
1449    fn update_offsets(&self, start: RowId, offsets: &mut Vec<(Value, RowId)>) {
1450        if let Some(cur) = self.current {
1451            if let Some((max, _)) = offsets.last() {
1452                if cur > *max {
1453                    offsets.push((cur, start));
1454                }
1455            } else {
1456                offsets.push((cur, start));
1457            }
1458        }
1459    }
1460}
1461
1462/// A type similar to a SortedWritesTable used to buffer outputs. The main thing
1463/// that StagedOutputs handles is running the merge function for a table on
1464/// multiple updates to the same key that show up in the same round of
1465/// insertions.
1466struct StagedOutputs {
1467    shard_data: ShardData,
1468    n_keys: usize,
1469    hash: Pooled<HashTable<TableEntry>>,
1470    rows: RowBuffer,
1471    n_stale: usize,
1472    scratch: Pooled<Vec<Value>>,
1473}
1474
1475impl StagedOutputs {
1476    fn rows(&self) -> impl Iterator<Item = &[Value]> {
1477        self.rows.iter()
1478    }
1479    fn new(n_keys: usize, n_cols: usize, capacity: usize) -> Self {
1480        let mut res = with_pool_set(|ps| StagedOutputs {
1481            shard_data: ShardData::new(1),
1482            n_keys,
1483            n_stale: 0,
1484            hash: ps.get(),
1485            rows: RowBuffer::new(n_cols),
1486            scratch: ps.get(),
1487        });
1488        res.hash.reserve(capacity, TableEntry::hashcode);
1489        res.rows.reserve(capacity);
1490        res
1491    }
1492    fn clear(&mut self) {
1493        self.hash.clear();
1494        self.rows.clear();
1495        self.n_stale = 0;
1496    }
1497    fn len(&self) -> usize {
1498        self.rows.len() - self.n_stale
1499    }
1500
1501    fn insert(
1502        &mut self,
1503        row: &[Value],
1504        mut merge_fn: impl FnMut(&[Value], &[Value], &mut Vec<Value>) -> bool,
1505    ) {
1506        if row[0].is_stale() {
1507            return;
1508        }
1509        use hashbrown::hash_table::Entry;
1510        let (_, hc) = hash_code(self.shard_data, row, self.n_keys);
1511        let entry = self.hash.entry(
1512            hc,
1513            |te| {
1514                te.hashcode() == hc
1515                    && self.rows.get_row(te.row)[0..self.n_keys] == row[0..self.n_keys]
1516            },
1517            TableEntry::hashcode,
1518        );
1519        match entry {
1520            Entry::Occupied(mut occupied_entry) => {
1521                let cur = self.rows.get_row(occupied_entry.get().row);
1522                if merge_fn(cur, row, &mut self.scratch) {
1523                    let new = self.rows.add_row(&self.scratch);
1524                    self.rows.set_stale(occupied_entry.get().row);
1525                    self.n_stale += 1;
1526                    occupied_entry.get_mut().row = new;
1527                }
1528                self.scratch.clear();
1529            }
1530            Entry::Vacant(vacant_entry) => {
1531                let next = self.rows.add_row(row);
1532                vacant_entry.insert(TableEntry {
1533                    hashcode: hc as _,
1534                    row: next,
1535                });
1536            }
1537        }
1538    }
1539
1540    /// Write the contents of the staged outputs to the given writer, returning the initial RowId
1541    /// of the new output. Returns the number of stale values in the buffer that was appended.
1542    fn write_output(&self, output: &ParallelRowBufWriter) -> (RowId, usize) {
1543        (output.append_contents(&self.rows), self.n_stale)
1544    }
1545}