egglog_core_relations/free_join/
execute.rs

1//! Core free join execution.
2
3use std::{
4    cmp, iter, mem,
5    ops::Range,
6    sync::{Arc, OnceLock, RwLock, atomic::AtomicUsize},
7};
8
9use crate::{
10    common::{HashMap, IndexMap},
11    free_join::plan::{JoinStages, MatId, MatScanMode, MatSpec},
12    numeric_id::{DenseIdMap, IdVec, NumericId},
13    query::Atom,
14    row_buffer::{RowBuffer, SmallValueVec},
15};
16use crossbeam::utils::CachePadded;
17use dashmap::mapref::entry::Entry;
18use dashmap::mapref::one::RefMut;
19use egglog_reports::{ReportLevel, RuleReport, RuleSetReport};
20use smallvec::SmallVec;
21use web_time::Instant;
22
23use crate::{
24    Constraint, OffsetRange, Pool, SubsetRef,
25    action::{Bindings, ExecutionState},
26    common::{DashMap, Value},
27    free_join::{
28        frame_update::{FrameUpdates, UpdateInstr},
29        get_index_from_tableinfo,
30    },
31    hash_index::{ColumnIndex, IndexBase, TupleIndex},
32    offsets::{Offsets, RowId, SortedOffsetSlice, SortedOffsetVector, Subset},
33    parallel_heuristics::parallelize_db_level_op,
34    pool::Pooled,
35    query::RuleSet,
36    row_buffer::TaggedRowBuffer,
37    table_spec::{ColumnId, Offset, WrappedTableRef},
38};
39
40use super::{
41    ActionId, AtomId, Database, HashColumnIndex, HashIndex, TableInfo, Variable,
42    get_column_index_from_tableinfo,
43    plan::{JoinHeader, JoinStage, Plan},
44    with_pool_set,
45};
46
/// Capacity of the fixed-size arrays inside `SparseColumnIndex`: the maximum
/// number of rows (and therefore distinct keys) such an index can hold.
const SMALL_RESIDUAL: usize = 8;
48
/// A small, fixed-capacity index over a single column, stored entirely inline
/// (no heap allocation). Holds at most `SMALL_RESIDUAL` rows.
///
/// Rows are grouped by key: `keys[..n_keys]` holds the distinct keys in
/// ascending order, and the rows for `keys[i]` are
/// `subset_ids[offsets[i]..offsets[i + 1]]` (or up to `n_subsets` for the last
/// key). Row ids within each group are in ascending order.
struct SparseColumnIndex {
    /// Number of distinct keys; only the first `n_keys` slots of `keys` and
    /// `offsets` are meaningful.
    n_keys: usize,
    /// Number of rows stored; only the first `n_subsets` slots of `subset_ids`
    /// are meaningful.
    n_subsets: usize,
    /// Distinct keys, sorted ascending.
    keys: [Value; SMALL_RESIDUAL],
    /// `offsets[i]` is the index in `subset_ids` where the rows for `keys[i]` begin.
    offsets: [usize; SMALL_RESIDUAL],
    /// Row ids, grouped by key in the same order as `keys`.
    subset_ids: [RowId; SMALL_RESIDUAL],
}
56
57/// Return a SubsetRef for the given range of rows in a SparseColumnIndex.
58/// Single-row ranges become Dense to skip pool allocation in to_owned.
59///
60/// # Safety
61/// `ids[range]` must be sorted in non-decreasing order. The wider `ids` slice
62/// need not be sorted as a whole; only the indicated sub-range. This is the
63/// invariant of `SortedOffsetSlice::new_unchecked`.
64#[inline]
65unsafe fn sparse_subset_ref(ids: &[RowId], range: Range<usize>) -> SubsetRef<'_> {
66    if range.len() == 1 {
67        let row = ids[range.start];
68        SubsetRef::Dense(OffsetRange::new(row, row.inc()))
69    } else {
70        // SAFETY: caller guarantees `ids[range]` is sorted.
71        SubsetRef::Sparse(unsafe { SortedOffsetSlice::new_unchecked(&ids[range]) })
72    }
73}
74
75impl SparseColumnIndex {
76    fn keys(&self) -> &[Value] {
77        &self.keys[..self.n_keys]
78    }
79
80    fn get_offset_for(&self, i: usize) -> Range<usize> {
81        let lo = self.offsets[i];
82        let hi = if i + 1 < self.n_keys {
83            self.offsets[i + 1]
84        } else {
85            self.n_subsets
86        };
87        lo..hi
88    }
89
90    fn new(table: WrappedTableRef<'_>, subset: SubsetRef<'_>, col: ColumnId) -> Self {
91        let mut rows = [(Value::new_const(0), RowId::new_const(0)); SMALL_RESIDUAL];
92        let mut pos = 0;
93        table.for_each_col(subset, col, &mut |row_id, val| {
94            rows[pos] = (val, row_id);
95            pos += 1;
96        });
97        let n_subsets = pos;
98
99        rows[..pos].sort_unstable();
100
101        let mut n_keys = 0;
102        let mut keys = [Value::new_const(0); SMALL_RESIDUAL];
103        let mut offsets = [0; SMALL_RESIDUAL];
104        let mut subset_ids = [RowId::new_const(0); SMALL_RESIDUAL];
105        offsets[0] = 0;
106
107        for (i, &(key, row_id)) in rows[..n_subsets].iter().enumerate() {
108            let is_new_key = n_keys == 0 || keys[n_keys - 1] != key;
109            if is_new_key {
110                offsets[n_keys] = i;
111                keys[n_keys] = key;
112                n_keys += 1;
113            }
114            subset_ids[i] = row_id;
115        }
116
117        SparseColumnIndex {
118            n_keys,
119            n_subsets,
120            keys,
121            offsets,
122            subset_ids,
123        }
124    }
125
126    fn get_subset(&self, key: Value) -> Option<SubsetRef<'_>> {
127        if self.n_keys == 0 {
128            return None;
129        }
130        let found = self.keys().binary_search(&key).ok()?;
131        let range = self.get_offset_for(found);
132        // SAFETY: `subset_ids` was populated from rows sorted by (Value, RowId),
133        // so RowIds within any single per-key range (as returned by
134        // `get_offset_for`) are in non-decreasing order.
135        Some(unsafe { sparse_subset_ref(&self.subset_ids, range) })
136    }
137
138    fn for_each(&self, mut f: impl FnMut(&[Value], SubsetRef)) {
139        if self.n_keys == 0 {
140            return;
141        }
142        for i in 0..self.n_keys {
143            let range = self.get_offset_for(i);
144            // SAFETY: see `get_subset` — each per-key range of `subset_ids` is sorted.
145            let subset = unsafe { sparse_subset_ref(&self.subset_ids, range) };
146            f(&self.keys[i..i + 1], subset);
147        }
148    }
149
150    fn len(&self) -> usize {
151        self.n_keys
152    }
153}
154
/// The index a `Prober` consults to find the rows matching a key.
///
/// Subsets read from the `Cached*` variants are tagged as possibly stale by
/// `Prober`; subsets from the other variants are tagged as not stale.
enum DynamicIndex {
    /// A `HashIndex`, probed with a slice of values as the key.
    Cached {
        /// When Some(range), intersect each subset from the index with this dense range.
        /// The range is the Dense outer subset known at Prober construction time.
        intersect_outer: Option<OffsetRange>,
        table: HashIndex,
    },
    /// A `HashColumnIndex`, probed with a single value as the key.
    CachedColumn {
        /// When Some(range), intersect each subset from the index with this dense range.
        /// The range is the Dense outer subset known at Prober construction time.
        intersect_outer: Option<OffsetRange>,
        table: HashColumnIndex,
    },
    /// A `TupleIndex`, probed with a slice of values as the key.
    Dynamic(TupleIndex),
    /// A shared `ColumnIndex`, probed with a single value as the key.
    DynamicColumn(Arc<ColumnIndex>),
    /// A small inline single-column index holding at most `SMALL_RESIDUAL` rows.
    SparseColumn(SparseColumnIndex),
}
172
/// Wrapper that records whether a subset may contain stale entries.
///
/// Staleness depends on the kind of index the subset came from: indexes that
/// come from a table may contain stale entries, while indexes built on the fly
/// will not.
struct PotentiallyStale<T> {
    /// The wrapped subset.
    inner: T,
    /// `true` if `inner` may contain stale entries.
    can_be_stale: bool,
}

impl<T> PotentiallyStale<T> {
    /// Wrap `inner` with an explicit staleness flag.
    fn with_staleness(inner: T, can_be_stale: bool) -> Self {
        PotentiallyStale {
            inner,
            can_be_stale,
        }
    }

    /// Wrap a value that may contain stale entries.
    fn maybe_stale(inner: T) -> Self {
        Self::with_staleness(inner, true)
    }

    /// Wrap a value that is known not to contain stale entries.
    fn not_stale(inner: T) -> Self {
        Self::with_staleness(inner, false)
    }
}
197
198impl PotentiallyStale<SubsetRef<'_>> {
199    fn size(&self) -> usize {
200        self.inner.size()
201    }
202
203    fn to_owned(&self, pool: &Pool<SortedOffsetVector>) -> PotentiallyStale<Subset> {
204        PotentiallyStale {
205            inner: self.inner.to_owned(pool),
206            can_be_stale: self.can_be_stale,
207        }
208    }
209}
210
211/// Intersect a `SubsetRef` with a dense `OffsetRange` and return the result as a
212/// borrowed `SubsetRef`, or `None` if the intersection is empty.
213///
214/// This function never allocates — it borrows into
215/// the source data via `subslice`. Use this in `for_each` paths where the result
216/// may be discarded (e.g., empty after refinement), to avoid pool allocations.
217#[inline]
218fn intersect_with_dense_ref<'a>(v: SubsetRef<'a>, range: OffsetRange) -> Option<SubsetRef<'a>> {
219    match v {
220        SubsetRef::Dense(r) => {
221            let resl = cmp::max(r.start, range.start);
222            let resr = cmp::min(r.end, range.end);
223            if resl >= resr {
224                None
225            } else {
226                Some(SubsetRef::Dense(OffsetRange::new(resl, resr)))
227            }
228        }
229        SubsetRef::Sparse(s) => {
230            let l = s.binary_search_by_id(range.start);
231            let r = s.binary_search_by_id(range.end);
232            if l >= r {
233                None
234            } else {
235                Some(SubsetRef::Sparse(s.subslice(l, r)))
236            }
237        }
238    }
239}
240
/// Pairs an index (`ix`) with a trie node, and exposes key lookup and
/// iteration over that index.
struct Prober {
    // NOTE(review): not read by the methods visible here; presumably the trie
    // node the index was derived from — confirm against the constructor.
    node: Arc<TrieNode>,
    /// The index consulted by `get_subset`, `for_each` and `len`.
    ix: DynamicIndex,
}
245
246impl Prober {
247    fn get_subset<'a>(&'a self, key: &'a [Value]) -> Option<PotentiallyStale<SubsetRef<'a>>> {
248        match &self.ix {
249            DynamicIndex::Cached {
250                intersect_outer,
251                table,
252            } => {
253                let subset_ref = table.get().unwrap().get_subset(key)?;
254                let subset = if let Some(range) = intersect_outer {
255                    intersect_with_dense_ref(subset_ref, *range)?
256                } else {
257                    subset_ref
258                };
259                Some(PotentiallyStale::maybe_stale(subset))
260            }
261            DynamicIndex::CachedColumn {
262                intersect_outer,
263                table,
264            } => {
265                debug_assert_eq!(key.len(), 1);
266                let subset_ref = table.get().unwrap().get_subset(&key[0])?;
267                let subset = if let Some(range) = intersect_outer {
268                    intersect_with_dense_ref(subset_ref, *range)?
269                } else {
270                    subset_ref
271                };
272                Some(PotentiallyStale::maybe_stale(subset))
273            }
274            DynamicIndex::Dynamic(tab) => tab.get_subset(key).map(PotentiallyStale::not_stale),
275            DynamicIndex::DynamicColumn(tab) => {
276                tab.get_subset(&key[0]).map(PotentiallyStale::not_stale)
277            }
278            DynamicIndex::SparseColumn(tab) => {
279                debug_assert_eq!(key.len(), 1);
280                tab.get_subset(key[0]).map(PotentiallyStale::not_stale)
281            }
282        }
283    }
284    fn for_each(&self, mut f: impl FnMut(&[Value], PotentiallyStale<SubsetRef>)) {
285        match &self.ix {
286            DynamicIndex::Cached {
287                intersect_outer: Some(range),
288                table,
289            } => {
290                let range = *range;
291                table.get().unwrap().for_each(|k, v| {
292                    if let Some(res) = intersect_with_dense_ref(v, range) {
293                        f(k, PotentiallyStale::maybe_stale(res))
294                    }
295                });
296            }
297            DynamicIndex::Cached {
298                intersect_outer: None,
299                table,
300            } => table
301                .get()
302                .unwrap()
303                .for_each(|k, v| f(k, PotentiallyStale::maybe_stale(v))),
304            DynamicIndex::CachedColumn {
305                intersect_outer: Some(range),
306                table,
307            } => {
308                let range = *range;
309                table.get().unwrap().for_each(|k, v| {
310                    if let Some(res) = intersect_with_dense_ref(v, range) {
311                        f(&[*k], PotentiallyStale::maybe_stale(res))
312                    }
313                });
314            }
315            DynamicIndex::CachedColumn {
316                intersect_outer: None,
317                table,
318            } => {
319                table
320                    .get()
321                    .unwrap()
322                    .for_each(|k, v| f(&[*k], PotentiallyStale::maybe_stale(v)));
323            }
324            DynamicIndex::Dynamic(tab) => {
325                tab.for_each(|k, v| f(k, PotentiallyStale::not_stale(v)));
326            }
327            DynamicIndex::DynamicColumn(tab) => tab.for_each(|k, v| {
328                f(&[*k], PotentiallyStale::not_stale(v));
329            }),
330            DynamicIndex::SparseColumn(tab) => {
331                tab.for_each(|k, v| f(k, PotentiallyStale::not_stale(v)));
332            }
333        }
334    }
335
336    fn len(&self) -> usize {
337        match &self.ix {
338            DynamicIndex::Cached { table, .. } => table.get().unwrap().len(),
339            DynamicIndex::CachedColumn { table, .. } => table.get().unwrap().len(),
340            DynamicIndex::Dynamic(tab) => tab.len(),
341            DynamicIndex::DynamicColumn(tab) => tab.len(),
342            DynamicIndex::SparseColumn(tab) => tab.len(),
343        }
344    }
345}
346
impl Database {
    /// Runs every plan in `rule_set` against this database, then calls
    /// `merge_all` and returns a report with timings and per-rule match counts.
    ///
    /// Returns a default report immediately if the rule set has no plans.
    /// Depending on `parallelize_db_level_op(self.total_size_estimate)`, plans
    /// are either run as one rayon task each, or sequentially on the calling
    /// thread with a single in-place action buffer.
    ///
    /// `report_level` controls whether each `RuleReport` carries a rendered
    /// plan: `TimeOnly` omits it, `WithPlan` and `StageInfo` include it.
    ///
    /// # Panics
    /// Panics if an internal invariant is violated, e.g. a trie node or a
    /// materialization is unexpectedly shared when exclusive access is needed.
    pub fn run_rule_set(&mut self, rule_set: &RuleSet, report_level: ReportLevel) -> RuleSetReport {
        if rule_set.plans.is_empty() {
            return RuleSetReport::default();
        }
        let match_counter = Arc::new(MatchCounter::new(rule_set.actions.n_ids()));

        // Measures the whole search-and-apply phase across all plans.
        let search_and_apply_timer = Instant::now();
        let mut rule_reports: HashMap<Arc<str>, Vec<RuleReport>>;
        let exec_state = ExecutionState::new(self.read_only_view(), Default::default());
        if parallelize_db_level_op(self.total_size_estimate) {
            // Parallel path: one task per plan; reports are gathered in a
            // concurrent map and copied into `rule_reports` afterwards.
            let dash_rule_reports: Arc<DashMap<Arc<str>, Vec<RuleReport>>> =
                Arc::new(DashMap::default());
            // Shared reborrow of `self` for the spawned closures to capture.
            let db: &Database = self;
            rayon::in_place_scope(|scope| {
                for (plan, desc, symbol_map) in rule_set.plans.values() {
                    // TODO: add stats
                    let report_plan = match report_level {
                        ReportLevel::TimeOnly => None,
                        ReportLevel::WithPlan | ReportLevel::StageInfo => {
                            Some(plan.to_report(symbol_map))
                        }
                    };

                    let dash_rule_reports = dash_rule_reports.clone();
                    let desc = desc.clone();
                    let exec_state = exec_state.clone();
                    let match_counter = match_counter.clone();
                    scope.spawn(move |rule_scope| {
                        let join_state = JoinState::new(db, exec_state.clone());
                        let mut binding_info = BindingInfo::default();
                        // Every atom starts out bound to all rows of its table.
                        for (id, info) in plan.atoms().iter() {
                            let table = join_state.db.get_table(info.table);
                            binding_info.insert_subset(id, table.all());
                        }
                        let mut action_buf =
                            ScopedActionBuffer::new(rule_scope, rule_set, match_counter.clone());
                        // Per-rule timer (shadows the rule-set-wide one).
                        let search_and_apply_timer = Instant::now();

                        'eval: {
                            // Apply the plan header: narrow each listed atom to the
                            // header's subset. An empty subset on either side means
                            // the rule cannot match, so evaluation is skipped.
                            for JoinHeader { atom, subset, .. } in plan.header() {
                                if subset.is_empty() {
                                    break 'eval;
                                }
                                // Before query execution, the Arc<TrieNode> is owned
                                // exclusively by the binding info, so unwrapping succeeds.
                                let mut cur =
                                    Arc::try_unwrap(binding_info.unwrap_val(*atom)).unwrap();
                                debug_assert!(cur.cached_subsets.get().is_none());
                                cur.subset.intersect(subset.as_ref(), &join_state.pool);
                                if cur.subset.is_empty() {
                                    break 'eval;
                                }
                                binding_info.move_back_node(*atom, Arc::new(cur));
                            }

                            match plan {
                                Plan::SinglePlan(plan) => {
                                    join_state.run_join_stages(
                                        &plan.stages,
                                        &plan.atoms,
                                        plan.actions,
                                        &mut binding_info,
                                        &mut action_buf,
                                    );
                                }
                                Plan::DecomposedPlan(plan) => {
                                    // One (initially empty) materialization per stage block.
                                    let mut materializations: DenseIdMap<
                                        MatId,
                                        Arc<DashMap<Vec<Value>, RowBuffer>>,
                                    > = DenseIdMap::with_capacity(plan.stages.blocks.len());
                                    for i in 0..plan.stages.blocks.len() {
                                        materializations.insert(
                                            MatId::from_usize(i),
                                            Arc::new(Default::default()),
                                        );
                                    }
                                    let specs: Arc<DenseIdMap<MatId, MatSpec>> = Arc::new(
                                        plan.stages
                                            .blocks
                                            .iter()
                                            .enumerate()
                                            .map(|(i, block)| {
                                                (MatId::from_usize(i), block.1.clone())
                                            })
                                            .collect(),
                                    );
                                    let mut materializations = Arc::new(materializations);

                                    for (mat_id, stage_block) in
                                        plan.stages.blocks.iter().enumerate()
                                    {
                                        let mat_id = MatId::from_usize(mat_id);
                                        // Run this block inside its own scope; the scope
                                        // returns only once work spawned on it has finished.
                                        rayon::in_place_scope(|stage_scope| {
                                            let mut materializer = ScopedMaterializer {
                                                scope: stage_scope,
                                                specs: specs.clone(),
                                                materializations: materializations.clone(),
                                                scratch_key: Default::default(),
                                                scratch_val: Default::default(),
                                            };
                                            join_state.run_join_stages(
                                                &stage_block.0,
                                                &plan.atoms,
                                                mat_id,
                                                &mut binding_info,
                                                &mut materializer,
                                            );
                                        });
                                        // An empty materialization ends evaluation of this
                                        // rule: later blocks and the result block are skipped.
                                        if materializations[mat_id].is_empty() {
                                            break 'eval;
                                        }
                                        // With the scope finished, this must be the only
                                        // handle left, so the map can be taken apart.
                                        assert_eq!(Arc::strong_count(&materializations), 1);
                                        let mut materializations_dearc =
                                            Arc::unwrap_or_clone(materializations);
                                        // Move the finished materialization out (leaving an
                                        // empty map behind) and convert it to an IndexMap.
                                        let materialization = mem::take(
                                            Arc::get_mut(&mut materializations_dearc[mat_id])
                                                .unwrap(),
                                        )
                                        .into_iter()
                                        .collect::<IndexMap<_, _>>();
                                        binding_info
                                            .materializations
                                            .insert(mat_id, Arc::new(materialization));
                                        materializations = Arc::new(materializations_dearc);
                                    }
                                    join_state.run_join_stages(
                                        &plan.result_block,
                                        &plan.atoms,
                                        plan.actions,
                                        &mut binding_info,
                                        &mut action_buf,
                                    );
                                }
                            }
                        }
                        // Note: the time is sampled before the final flush below.
                        let search_and_apply_time = search_and_apply_timer.elapsed();
                        if action_buf.needs_flush {
                            action_buf.flush(&mut exec_state.clone());
                        }
                        let mut rule_report: RefMut<'_, Arc<str>, Vec<RuleReport>> =
                            dash_rule_reports.entry(desc).or_default();
                        rule_report.value_mut().push(RuleReport {
                            plan: report_plan,
                            search_and_apply_time,
                            // Placeholder; replaced with the real count further down.
                            num_matches: usize::MAX,
                        });
                    });
                }
            });
            rule_reports = dash_rule_reports
                .iter()
                .map(|entry| (entry.key().clone(), entry.value().clone()))
                .collect();
        } else {
            rule_reports = HashMap::default();
            let join_state = JoinState::new(self, exec_state.clone());
            // Just run all of the plans in order with a single in-place action
            // buffer.
            let mut action_buf = InPlaceActionBuffer {
                rule_set,
                match_counter: match_counter.as_ref(),
                batches: Default::default(),
            };
            for (plan, desc, symbol_map) in rule_set.plans.values() {
                let report_plan = match report_level {
                    ReportLevel::TimeOnly => None,
                    ReportLevel::WithPlan | ReportLevel::StageInfo => {
                        Some(plan.to_report(symbol_map))
                    }
                };
                let mut binding_info = BindingInfo::default();

                // Every atom starts out bound to all rows of its table.
                for (id, info) in plan.atoms().iter() {
                    let table = join_state.db.get_table(info.table);
                    binding_info.insert_subset(id, table.all());
                }

                // Per-rule timer (shadows the rule-set-wide one).
                let search_and_apply_timer = Instant::now();
                'eval: {
                    // Apply the plan header: narrow each listed atom to the header's
                    // subset, skipping the rule as soon as anything becomes empty.
                    for JoinHeader { atom, subset, .. } in plan.header() {
                        if subset.is_empty() {
                            break 'eval;
                        }
                        // Before query execution, Arc<TrieNode> is owned exclusively by the binding info.
                        // Trie nodes are only shared once we start running the query.
                        let mut cur = Arc::try_unwrap(binding_info.unwrap_val(*atom)).unwrap();
                        debug_assert!(cur.cached_subsets.get().is_none());
                        cur.subset.intersect(subset.as_ref(), &join_state.pool);
                        if cur.subset.is_empty() {
                            break 'eval;
                        }
                        binding_info.move_back_node(*atom, Arc::new(cur));
                    }
                    match plan {
                        Plan::SinglePlan(plan) => {
                            join_state.run_join_stages(
                                &plan.stages,
                                &plan.atoms,
                                plan.actions,
                                &mut binding_info,
                                &mut action_buf,
                            );
                        }
                        Plan::DecomposedPlan(plan) => {
                            // One (initially empty) materialization per stage block.
                            let mut materializations =
                                DenseIdMap::with_capacity(plan.stages.blocks.len());
                            for i in 0..plan.stages.blocks.len() {
                                materializations.insert(MatId::from_usize(i), Default::default());
                            }
                            let mut materializer = InPlaceMaterializer {
                                specs: &plan
                                    .stages
                                    .blocks
                                    .iter()
                                    .enumerate()
                                    .map(|(i, block)| (MatId::from_usize(i), block.1.clone()))
                                    .collect(),
                                materializations,
                                scratch_key: Default::default(),
                                scratch_val: Default::default(),
                            };

                            for (mat_id, stage_block) in plan.stages.blocks.iter().enumerate() {
                                let mat_id = MatId::from_usize(mat_id);
                                join_state.run_join_stages(
                                    &stage_block.0,
                                    &plan.atoms,
                                    mat_id,
                                    &mut binding_info,
                                    &mut materializer,
                                );
                                // An empty materialization ends evaluation of this rule:
                                // later blocks and the result block are skipped.
                                if materializer.materializations[mat_id].is_empty() {
                                    break 'eval;
                                }
                                // Hand the finished materialization to the binding info.
                                binding_info.materializations.insert(
                                    mat_id,
                                    Arc::new(materializer.materializations.take(mat_id).unwrap()),
                                );
                            }
                            join_state.run_join_stages(
                                &plan.result_block,
                                &plan.atoms,
                                plan.actions,
                                &mut binding_info,
                                &mut action_buf,
                            );
                        }
                    }
                }
                let search_and_apply_time = search_and_apply_timer.elapsed();

                // TODO: unnecessary cloning in many cases
                let rule_report = rule_reports.entry(desc.clone()).or_default();
                rule_report.push(RuleReport {
                    plan: report_plan,
                    search_and_apply_time,
                    // Placeholder; replaced with the real count further down.
                    num_matches: usize::MAX,
                });
            }
            // The shared buffer is flushed once, after all plans have run.
            action_buf.flush(&mut exec_state.clone());
        }

        for (plan, desc, _symbol_map) in rule_set.plans.values() {
            let reports = rule_reports.get_mut(desc).unwrap();
            let i = reports
                .iter()
                // HACK: Since the order of visiting queries is fixed and # matches need to be obtained
                // separately from rule execution, we first set all # matches to be usize::MAX and then fill
                // them in one by one.
                .position(|r| r.num_matches == usize::MAX)
                .unwrap();
            // NB: This requires each action ID correspond to only one query.
            // If an action is used by multiple queries, then we can't tell how many matches are
            // caused by individual queries.
            reports[i].num_matches = match_counter.read_matches(plan.actions());
        }
        let search_and_apply_time = search_and_apply_timer.elapsed();

        let merge_timer = Instant::now();
        let changed = self.merge_all();
        let merge_time = merge_timer.elapsed();

        RuleSetReport {
            changed,
            rule_reports,
            search_and_apply_time,
            merge_time,
        }
    }
}
637
/// State kept for an action; starts out with zeroed counters and an empty
/// `Bindings` created with `VAR_BATCH_SIZE` (see the `Default` impl).
// NOTE(review): the uses of these fields are outside this excerpt; the field
// notes below are inferred from names and types — confirm against the callers.
struct ActionState {
    // Presumably the number of times the action has been run.
    n_runs: usize,
    // Presumably the number of entries currently buffered in `bindings`.
    len: usize,
    // Buffered variable bindings for the action.
    bindings: Bindings,
}
643
644impl Default for ActionState {
645    fn default() -> Self {
646        Self {
647            n_runs: 0,
648            len: 0,
649            bindings: Bindings::new(VAR_BATCH_SIZE),
650        }
651    }
652}
653
/// Context for executing the join stages of a plan: the database being
/// queried, the execution state, and an allocation pool.
struct JoinState<'a> {
    /// The database whose tables the plans read.
    db: &'a Database,
    /// Execution state supplied when the `JoinState` is created.
    exec_state: ExecutionState<'a>,
    /// Cached thread-local pool for SortedOffsetVector allocations.
    /// Stored here to avoid a per-call `with_pool_set` TLS access in `get_index`.
    pool: Pool<SortedOffsetVector>,
}
661
/// Per-column indexes on a trie node's subset, lazily initialized on first access per column.
type ColumnIndexes = IdVec<ColumnId, OnceLock<Arc<ColumnIndex>>>;
/// Per-column caches of child trie nodes: for each column, a lock-protected map
/// from a value to the child node cached under that value. Filled lazily by
/// `TrieNode::get_cached_trie_node`, which sizes the vector to the table's arity.
// NOTE(review): this comment previously described a single boxed `(ColumnId, map)`
// pair; the type as declared is a per-column `IdVec` of `RwLock`-protected maps.
type ChildrenMaps = IdVec<ColumnId, RwLock<HashMap<Value, Arc<TrieNode>>>>;
668
/// Information about the current subset of an atom's relation that is being considered, along with
/// lazily-initialized, cached indexes on that subset.
///
/// This is the standard trie-node used in lazy implementations of GJ as in the original egglog
/// implementation and the FJ paper. It currently does not handle non-column indexes, but that
/// should be a fairly straightforward extension if we start generating plans that need those.
/// (Right now, most plans iterating over more than one column just do a scan anyway).
pub(crate) struct TrieNode {
    /// The actual subset of the corresponding atom.
    subset: Subset,
    /// Any cached indexes on this subset, one lazily-built slot per column.
    /// The vector itself is only allocated on first use.
    cached_subsets: OnceLock<Pooled<ColumnIndexes>>,
    /// Cached child trie nodes: one lock-protected map per column, keyed by
    /// value. The vector itself is only allocated on first use.
    // NOTE(review): this comment previously described a single `(col, map)` pair;
    // the declared type holds a map for every column.
    cached_children: OnceLock<Pooled<ChildrenMaps>>,
}
686
687impl std::fmt::Debug for TrieNode {
688    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
689        f.debug_struct("TrieNode")
690            .field("subset", &self.subset)
691            .finish()
692    }
693}
694
695impl TrieNode {
696    fn new(subset: Subset) -> Self {
697        Self {
698            subset,
699            cached_subsets: Default::default(),
700            cached_children: Default::default(),
701        }
702    }
703
704    fn size(&self) -> usize {
705        self.subset.size()
706    }
707    fn get_cached_index(&self, col: ColumnId, info: &TableInfo) -> Arc<ColumnIndex> {
708        self.cached_subsets.get_or_init(|| {
709            // Pre-size the vector so we do not need to borrow it mutably to initialize the index.
710            let mut vec: Pooled<ColumnIndexes> = with_pool_set(|ps| ps.get());
711            vec.resize_with(info.spec.arity(), OnceLock::new);
712            vec
713        })[col]
714            .get_or_init(|| {
715                let col_index = info.table.group_by_col(self.subset.as_ref(), col);
716                Arc::new(col_index)
717            })
718            .clone()
719    }
720
721    fn get_cached_trie_node(
722        &self,
723        col: ColumnId,
724        value: Value,
725        info: &TableInfo,
726        sub: impl FnOnce() -> Subset,
727    ) -> Arc<TrieNode> {
728        let map = &self.cached_children.get_or_init(|| {
729            let mut vec: Pooled<ChildrenMaps> = with_pool_set(|ps| ps.get());
730            vec.resize_with(info.spec.arity(), || RwLock::new(HashMap::default()));
731            vec
732        })[col];
733        // Optimistic read path: most calls are cache hits, so try shared lock first.
734        {
735            let guard = map.read().unwrap();
736            if let Some(node) = guard.get(&value) {
737                return node.clone();
738            }
739        }
740        // Cache miss: acquire write lock and insert.
741        let mut guard = map.write().unwrap();
742        // Double-check in case another thread inserted while we were waiting.
743        if let Some(node) = guard.get(&value) {
744            return node.clone();
745        }
746        let new_node = Arc::new(TrieNode::new(sub()));
747        guard.insert(value, new_node.clone());
748        new_node
749    }
750}
751
752impl FrameUpdates {
753    /// Refine `atom` to `subset`, using the dense fast path to avoid an
754    /// `Arc<TrieNode>` allocation when the subset is already a contiguous range.
755    fn refine_atom_subset(&mut self, atom: AtomId, subset: Subset) {
756        match subset {
757            Subset::Dense(range) => self.refine_atom_dense(atom, range),
758            sub => self.refine_atom(atom, Arc::new(TrieNode::new(sub))),
759        }
760    }
761}
762
/// A stack of factorized bindings: each entry pairs a list of variables with
/// a shared buffer of rows.
///
/// The leaf-scan path of `run_plan` pushes one entry holding the variables
/// bound by the stage and the rows projected for them, and pops it again once
/// the frame has been drained.
// NOTE(review): presumably the i-th value of each row binds the i-th
// variable of the entry; confirm against `push_bindings_factorized`.
type BindingSet = Vec<(SmallVec<[Variable; 4]>, Arc<TaggedRowBuffer<SmallValueVec>>)>;
764
/// Mutable per-query state threaded through plan execution.
#[derive(Default, Clone)]
struct BindingInfo {
    /// Values bound to individual variables so far; updated when a
    /// `PushBinding` update is drained.
    bindings: DenseIdMap<Variable, Value>,
    /// Factorized bindings (see [`BindingSet`]); handed to the action buffer
    /// together with `bindings` when a match is emitted.
    binding_sets: BindingSet,
    /// The current trie node (and hence row subset) for each atom. Probers
    /// temporarily move a node out of this map; `move_back` restores it.
    subsets: DenseIdMap<AtomId, Arc<TrieNode>>,
    /// Shared materialized indexes keyed by `MatId`.
    // NOTE(review): not read or written in the portion of the file reviewed
    // here; see the materialization stages of the plan for how it is used.
    materializations: DenseIdMap<MatId, Arc<IndexMap<Vec<Value>, RowBuffer>>>,
}
772
773impl BindingInfo {
774    /// Initializes the atom-related metadata in the [`BindingInfo`].    
775    fn insert_subset(&mut self, atom: AtomId, subset: Subset) {
776        if let Some(slot) = self.subsets.get_mut(atom)
777            && let Some(node) = Arc::get_mut(slot)
778        {
779            node.cached_subsets.take();
780            node.cached_children.take();
781            node.subset = subset;
782            return;
783        }
784        self.subsets.insert(atom, Arc::new(TrieNode::new(subset)));
785    }
786
787    fn insert_node(&mut self, atom: AtomId, node: Arc<TrieNode>) {
788        self.subsets.insert(atom, node);
789    }
790
791    /// Probers returned from [`JoinState::get_index`] will move atom-related state out of the
792    /// [`BindingInfo`]. Once the caller is done using a prober, this method moves it back.
793    fn move_back(&mut self, atom: AtomId, prober: Prober) {
794        self.subsets.insert(atom, prober.node);
795    }
796
797    fn move_back_node(&mut self, atom: AtomId, node: Arc<TrieNode>) {
798        self.subsets.insert(atom, node);
799    }
800
801    fn has_empty_subset(&self, atom: AtomId) -> bool {
802        self.subsets[atom].subset.is_empty()
803    }
804
805    fn unwrap_val(&mut self, atom: AtomId) -> Arc<TrieNode> {
806        self.subsets.unwrap_val(atom)
807    }
808}
809
810impl<'a> JoinState<'a> {
811    fn new(db: &'a Database, exec_state: ExecutionState<'a>) -> Self {
812        Self {
813            db,
814            exec_state,
815            pool: with_pool_set(|ps| ps.get_pool()),
816        }
817    }
818
819    fn get_index(
820        &self,
821        atoms: &Arc<DenseIdMap<AtomId, Atom>>,
822        atom: AtomId,
823        binding_info: &mut BindingInfo,
824        cols: impl Iterator<Item = ColumnId>,
825    ) -> Prober {
826        let cols = SmallVec::<[ColumnId; 4]>::from_iter(cols);
827        let trie_node = binding_info.subsets.unwrap_val(atom);
828        let subset = &trie_node.subset;
829
830        let table_id = atoms[atom].table;
831        let info = &self.db.tables[table_id];
832        let dyn_index = if subset.size() <= SMALL_RESIDUAL && cols.len() == 1 {
833            DynamicIndex::SparseColumn(SparseColumnIndex::new(
834                info.table.as_ref(),
835                subset.as_ref(),
836                cols[0],
837            ))
838        } else {
839            let all_cacheable = cols.iter().all(|col| {
840                !info
841                    .spec
842                    .uncacheable_columns
843                    .get(*col)
844                    .copied()
845                    .unwrap_or(false)
846            });
847            let whole_table = info.table.all();
848            if let Subset::Dense(range) = subset
849                && all_cacheable
850                && whole_table.size() / 2 < subset.size()
851            {
852                // Skip intersecting with the subset if we are just looking at the
853                // whole table.
854                let needs_intersect =
855                    !(whole_table.is_dense() && subset.bounds() == whole_table.bounds());
856                // When intersecting, store the Dense range directly so we can do a
857                // combined copy+filter without a runtime match on subset type later.
858                let intersect_outer = if needs_intersect { Some(*range) } else { None };
859                // heuristic: if the subset we are scanning is somewhat
860                // large _or_ it is most of the table, or we already have a cached
861                // index for it, then return it.
862                if cols.len() != 1 {
863                    DynamicIndex::Cached {
864                        intersect_outer,
865                        table: get_index_from_tableinfo(info, &cols),
866                    }
867                } else {
868                    DynamicIndex::CachedColumn {
869                        intersect_outer,
870                        table: get_column_index_from_tableinfo(info, cols[0]).clone(),
871                    }
872                }
873            } else if cols.len() != 1 {
874                // NB: we should have a caching strategy for non-column indexes.
875                DynamicIndex::Dynamic(info.table.group_by_key(subset.as_ref(), &cols))
876            } else {
877                DynamicIndex::DynamicColumn(trie_node.get_cached_index(cols[0], info))
878            }
879        };
880        Prober {
881            node: trie_node,
882            ix: dyn_index,
883        }
884    }
885    fn get_column_index(
886        &self,
887        atoms: &Arc<DenseIdMap<AtomId, Atom>>,
888        binding_info: &mut BindingInfo,
889        atom: AtomId,
890        col: ColumnId,
891    ) -> Prober {
892        self.get_index(atoms, atom, binding_info, iter::once(col))
893    }
894
895    /// Runs the free join plan, starting with the header.
896    ///
897    /// A bit about the `instr_order` parameter: This defines the order in which the [`JoinStage`]
898    /// instructions will run. We want to support cached [`SinglePlan`]s that may be based on stale
899    /// ordering information. `instr_order` allows us to specify a new ordering of the instructions
900    /// without mutating the plan itself: `run_plan` simply executes
901    /// `plan.stages.instrs[instr_order[i]]` at stage `i`.
902    ///
903    /// This is also a stepping stone towards supporting fully dynamic variable ordering.
904    fn run_join_stages<'buf, A: NumericId + 'buf, BUF: ActionBuffer<'buf, A>>(
905        &self,
906        stages: &'buf JoinStages,
907        atoms: &'buf Arc<DenseIdMap<AtomId, Atom>>,
908        action: A,
909        binding_info: &mut BindingInfo,
910        action_buf: &mut BUF,
911    ) where
912        'a: 'buf,
913    {
914        if log::log_enabled!(log::Level::Debug) {
915            log::debug!("Starting running query stages:\n{stages:#?}");
916        }
917        for (_, node) in binding_info.subsets.iter() {
918            if node.subset.is_empty() {
919                return;
920            }
921        }
922        let mut order = InstrOrder::from_iter(0..stages.instrs.len());
923        sort_plan_by_size(&mut order, 0, &stages.instrs, binding_info);
924        self.run_plan(
925            stages,
926            atoms,
927            action,
928            &mut order,
929            0,
930            binding_info,
931            action_buf,
932        );
933    }
934
935    /// The core method for executing a free join plan.
936    ///
937    /// This method takes the plan, mutable data-structures for variable binding and staging
938    /// actions, and two indexes: `cur` which is the current stage of the plan to run, and `level`
939    /// which is the current "fan-out" node we are in. The latter parameter is an experimental
940    /// index used to detect if we are at the "top" of a plan rather than the "bottom", and is
941    /// currently used as a heuristic to determine if we should increase parallelism more than the
942    /// default.
943    #[allow(clippy::too_many_arguments)]
944    fn run_plan<'buf, A: NumericId + 'buf, BUF: ActionBuffer<'buf, A>>(
945        &self,
946        stages: &'buf JoinStages,
947        atoms: &'buf Arc<DenseIdMap<AtomId, Atom>>,
948        action: A,
949        instr_order: &mut InstrOrder,
950        cur: usize,
951        binding_info: &mut BindingInfo,
952        action_buf: &mut BUF,
953    ) where
954        'a: 'buf,
955    {
956        if self.exec_state.should_stop() {
957            return;
958        }
959
960        if cur >= instr_order.len() {
961            action_buf.push_bindings_factorized(
962                action,
963                &mut binding_info.bindings,
964                &binding_info.binding_sets,
965                &self.exec_state,
966            );
967            return;
968        }
969        let chunk_size = action_buf.morsel_size(cur, instr_order.len());
970        let mut cur_size = estimate_size(&stages.instrs[instr_order.get(cur)], binding_info);
971        if cur_size > 32 && cur % 3 == 1 && cur < instr_order.len() - 1 {
972            // If we have a reasonable number of tuples to process, adjust the variable order every
973            // 3 rounds, but always make sure to readjust on the second round.
974            sort_plan_by_size(instr_order, cur, &stages.instrs, binding_info);
975            cur_size = estimate_size(&stages.instrs[instr_order.get(cur)], binding_info);
976        }
977
978        // Helper macro (not its own method to appease the borrow checker).
979        macro_rules! drain_updates {
980            ($updates:expr) => {
981                if self.exec_state.should_stop() {
982                    return;
983                }
984                // TODO: `supports_parallel_drain` is a hack because currently
985                // `drain_updates_parallel!` is a bit slower because of the additional ExecutionState clone.
986                if (cur == 0 || cur == 1) && action_buf.supports_parallel_drain() {
987                    drain_updates_parallel!($updates)
988                } else {
989                    $updates.drain(|update| match update {
990                        UpdateInstr::PushBinding(var, val) => {
991                            binding_info.bindings.insert(var, val);
992                        }
993                        UpdateInstr::RefineAtom(atom, subset) => {
994                            binding_info.insert_node(atom, subset);
995                        }
996                        UpdateInstr::RefineAtomDense(atom, range) => {
997                            binding_info.insert_subset(atom, Subset::Dense(range));
998                        }
999                        UpdateInstr::EndFrame => {
1000                            // Inline leaf-level: if cur+1 is the leaf (no more
1001                            // join stages), call push_bindings directly without
1002                            // a recursive run_plan call, avoiding function call
1003                            // overhead + an extra should_stop() check.
1004                            if cur + 1 >= instr_order.len() {
1005                                action_buf.push_bindings_factorized(
1006                                    action,
1007                                    &mut binding_info.bindings,
1008                                    &binding_info.binding_sets,
1009                                    &self.exec_state,
1010                                );
1011                            } else {
1012                                self.run_plan(
1013                                    stages,
1014                                    atoms,
1015                                    action,
1016                                    instr_order,
1017                                    cur + 1,
1018                                    binding_info,
1019                                    action_buf,
1020                                );
1021                            }
1022                        }
1023                    })
1024                }
1025            };
1026        }
1027        macro_rules! drain_updates_parallel {
1028            ($updates:expr) => {{
1029                if self.exec_state.should_stop() {
1030                    return;
1031                }
1032                let db = self.db;
1033                let exec_state_for_factory = self.exec_state.clone();
1034                let exec_state_for_work = self.exec_state.clone();
1035                action_buf.recur(
1036                    BorrowedLocalState {
1037                        binding_info,
1038                        instr_order,
1039                        updates: &mut $updates,
1040                    },
1041                    move || exec_state_for_factory.clone(),
1042                    move |BorrowedLocalState {
1043                              binding_info,
1044                              instr_order,
1045                              updates,
1046                          },
1047                          buf| {
1048                        updates.drain(|update| match update {
1049                            UpdateInstr::PushBinding(var, val) => {
1050                                binding_info.bindings.insert(var, val);
1051                            }
1052                            UpdateInstr::RefineAtom(atom, subset) => {
1053                                binding_info.insert_node(atom, subset);
1054                            }
1055                            UpdateInstr::RefineAtomDense(atom, range) => {
1056                                binding_info.insert_subset(atom, Subset::Dense(range));
1057                            }
1058                            UpdateInstr::EndFrame => {
1059                                JoinState {
1060                                    db,
1061                                    exec_state: exec_state_for_work.clone(),
1062                                    // Each rayon task uses its own thread-local pool.
1063                                    // This makes drain_updates_parallel slightly more expensive
1064                                    // than drain_updates even when both are run on a single thread.
1065                                    pool: with_pool_set(|ps| ps.get_pool()),
1066                                }
1067                                .run_plan(
1068                                    stages,
1069                                    atoms,
1070                                    action,
1071                                    instr_order,
1072                                    cur + 1,
1073                                    binding_info,
1074                                    buf,
1075                                );
1076                            }
1077                        })
1078                    },
1079                );
1080                $updates.clear();
1081            }};
1082        }
1083
1084        fn refine_subset(
1085            sub: PotentiallyStale<Subset>,
1086            constraints: &[Constraint],
1087            table: &WrappedTableRef,
1088            has_stale: bool,
1089        ) -> Subset {
1090            let sub = if sub.can_be_stale && has_stale {
1091                table.refine_live(sub.inner)
1092            } else {
1093                sub.inner
1094            };
1095            if constraints.is_empty() {
1096                sub
1097            } else {
1098                table.refine(sub, constraints)
1099            }
1100        }
1101
1102        let pool = &self.pool;
1103        match &stages.instrs[instr_order.get(cur)] {
1104            JoinStage::Intersect { var, scans } => match scans.as_slice() {
1105                [] => {}
1106                [a] => {
1107                    if binding_info.has_empty_subset(a.atom) {
1108                        return;
1109                    }
1110                    let prober = self.get_column_index(atoms, binding_info, a.atom, a.column);
1111                    let info = &self.db.tables[atoms[a.atom].table];
1112                    let table = info.table.as_ref();
1113                    let has_stale = table.has_stale_rows();
1114                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1115                    prober.for_each(|val, x| {
1116                        updates.push_binding(*var, val[0]);
1117                        if x.size() <= 16 {
1118                            let sub = refine_subset(x.to_owned(pool), &a.cs, &table, has_stale);
1119                            if sub.is_empty() {
1120                                updates.rollback();
1121                                return;
1122                            }
1123                            updates.refine_atom_subset(a.atom, sub);
1124                        } else {
1125                            let node =
1126                                prober
1127                                    .node
1128                                    .get_cached_trie_node(a.column, val[0], info, || {
1129                                        refine_subset(x.to_owned(pool), &a.cs, &table, has_stale)
1130                                    });
1131                            if node.subset.is_empty() {
1132                                updates.rollback();
1133                                return;
1134                            }
1135                            updates.refine_atom(a.atom, node);
1136                        }
1137                        updates.finish_frame();
1138                        if updates.frames() >= chunk_size {
1139                            drain_updates!(updates);
1140                        }
1141                    });
1142                    drain_updates!(updates);
1143                    binding_info.move_back(a.atom, prober);
1144                }
1145                [a, b] => {
1146                    let a_prober = self.get_column_index(atoms, binding_info, a.atom, a.column);
1147                    let b_prober = self.get_column_index(atoms, binding_info, b.atom, b.column);
1148
1149                    let ((smaller, smaller_scan), (larger, larger_scan)) =
1150                        if a_prober.len() < b_prober.len() {
1151                            ((&a_prober, a), (&b_prober, b))
1152                        } else {
1153                            ((&b_prober, b), (&a_prober, a))
1154                        };
1155
1156                    let smaller_atom = smaller_scan.atom;
1157                    let larger_atom = larger_scan.atom;
1158                    let large_info = &self.db.tables[atoms[larger_atom].table];
1159                    let large_table = large_info.table.as_ref();
1160                    let large_has_stale = large_table.has_stale_rows();
1161                    let small_info = &self.db.tables[atoms[smaller_atom].table];
1162                    let small_table = small_info.table.as_ref();
1163                    let small_has_stale = small_table.has_stale_rows();
1164                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1165                    smaller.for_each(|val, small_sub| {
1166                        if let Some(large_sub) = larger.get_subset(val) {
1167                            updates.push_binding(*var, val[0]);
1168                            if small_sub.size() <= 16 {
1169                                let small_sub = refine_subset(
1170                                    small_sub.to_owned(pool),
1171                                    &smaller_scan.cs,
1172                                    &small_table,
1173                                    small_has_stale,
1174                                );
1175                                if small_sub.is_empty() {
1176                                    updates.rollback();
1177                                    return;
1178                                }
1179                                updates.refine_atom_subset(smaller_atom, small_sub);
1180                            } else {
1181                                let smaller_node = smaller.node.get_cached_trie_node(
1182                                    smaller_scan.column,
1183                                    val[0],
1184                                    small_info,
1185                                    || {
1186                                        refine_subset(
1187                                            small_sub.to_owned(pool),
1188                                            &smaller_scan.cs,
1189                                            &small_table,
1190                                            small_has_stale,
1191                                        )
1192                                    },
1193                                );
1194                                if smaller_node.subset.is_empty() {
1195                                    updates.rollback();
1196                                    return;
1197                                }
1198                                updates.refine_atom(smaller_atom, smaller_node);
1199                            }
1200                            if large_sub.size() <= 16 {
1201                                let large_sub = refine_subset(
1202                                    large_sub.to_owned(pool),
1203                                    &larger_scan.cs,
1204                                    &large_table,
1205                                    large_has_stale,
1206                                );
1207                                if large_sub.is_empty() {
1208                                    updates.rollback();
1209                                    return;
1210                                }
1211                                updates.refine_atom_subset(larger_atom, large_sub);
1212                            } else {
1213                                let larger_node = larger.node.get_cached_trie_node(
1214                                    larger_scan.column,
1215                                    val[0],
1216                                    large_info,
1217                                    || {
1218                                        refine_subset(
1219                                            large_sub.to_owned(pool),
1220                                            &larger_scan.cs,
1221                                            &large_table,
1222                                            large_has_stale,
1223                                        )
1224                                    },
1225                                );
1226                                if larger_node.subset.is_empty() {
1227                                    updates.rollback();
1228                                    return;
1229                                }
1230                                updates.refine_atom(larger_atom, larger_node);
1231                            }
1232                            updates.finish_frame();
1233                            if updates.frames() >= chunk_size {
1234                                drain_updates!(updates);
1235                            }
1236                        }
1237                    });
1238                    drain_updates!(updates);
1239
1240                    binding_info.move_back(a.atom, a_prober);
1241                    binding_info.move_back(b.atom, b_prober);
1242                }
1243                rest => {
1244                    let mut smallest = 0;
1245                    let mut smallest_size = usize::MAX;
1246                    let mut probers = Vec::with_capacity(rest.len());
1247                    for (i, scan) in rest.iter().enumerate() {
1248                        let prober =
1249                            self.get_column_index(atoms, binding_info, scan.atom, scan.column);
1250                        let size = prober.len();
1251                        if size < smallest_size {
1252                            smallest = i;
1253                            smallest_size = size;
1254                        }
1255                        probers.push(prober);
1256                    }
1257
1258                    let main_spec = &rest[smallest];
1259                    let main_spec_info = &self.db.tables[atoms[main_spec.atom].table];
1260                    let main_spec_table = main_spec_info.table.as_ref();
1261                    let main_spec_has_stale = main_spec_table.has_stale_rows();
1262                    // Pre-compute has_stale for each scan to avoid vtable calls in the hot loop.
1263                    let rest_has_stale: SmallVec<[bool; 3]> = rest
1264                        .iter()
1265                        .map(|scan| {
1266                            self.db.tables[atoms[scan.atom].table]
1267                                .table
1268                                .as_ref()
1269                                .has_stale_rows()
1270                        })
1271                        .collect();
1272
1273                    if smallest_size != 0 {
1274                        // Smallest leads the scan
1275                        let mut updates =
1276                            FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1277                        probers[smallest].for_each(|key, sub| {
1278                            updates.push_binding(*var, key[0]);
1279                            for (i, scan) in rest.iter().enumerate() {
1280                                if i == smallest {
1281                                    continue;
1282                                }
1283                                if let Some(sub) = probers[i].get_subset(key) {
1284                                    let table =
1285                                        self.db.tables[atoms[rest[i].atom].table].table.as_ref();
1286                                    if sub.size() <= 16 {
1287                                        let sub = refine_subset(
1288                                            sub.to_owned(pool),
1289                                            &rest[i].cs,
1290                                            &table,
1291                                            rest_has_stale[i],
1292                                        );
1293                                        if sub.is_empty() {
1294                                            updates.rollback();
1295                                            return;
1296                                        }
1297                                        updates.refine_atom_subset(scan.atom, sub);
1298                                    } else {
1299                                        let node = probers[i].node.get_cached_trie_node(
1300                                            scan.column,
1301                                            key[0],
1302                                            &self.db.tables[atoms[scan.atom].table],
1303                                            || {
1304                                                refine_subset(
1305                                                    sub.to_owned(pool),
1306                                                    &rest[i].cs,
1307                                                    &table,
1308                                                    rest_has_stale[i],
1309                                                )
1310                                            },
1311                                        );
1312                                        if node.subset.is_empty() {
1313                                            updates.rollback();
1314                                            return;
1315                                        }
1316                                        updates.refine_atom(scan.atom, node);
1317                                    }
1318                                } else {
1319                                    updates.rollback();
1320                                    // Empty intersection.
1321                                    return;
1322                                }
1323                            }
1324                            if sub.size() <= 16 {
1325                                let main_sub = refine_subset(
1326                                    sub.to_owned(pool),
1327                                    &main_spec.cs,
1328                                    &main_spec_table,
1329                                    main_spec_has_stale,
1330                                );
1331                                if main_sub.is_empty() {
1332                                    updates.rollback();
1333                                    return;
1334                                }
1335                                updates.refine_atom_subset(main_spec.atom, main_sub);
1336                            } else {
1337                                let main_node = probers[smallest].node.get_cached_trie_node(
1338                                    main_spec.column,
1339                                    key[0],
1340                                    main_spec_info,
1341                                    || {
1342                                        let sub = sub.to_owned(pool);
1343                                        refine_subset(
1344                                            sub,
1345                                            &main_spec.cs,
1346                                            &main_spec_table,
1347                                            main_spec_has_stale,
1348                                        )
1349                                    },
1350                                );
1351                                if main_node.subset.is_empty() {
1352                                    updates.rollback();
1353                                    return;
1354                                }
1355                                updates.refine_atom(main_spec.atom, main_node);
1356                            }
1357                            updates.finish_frame();
1358                            if updates.frames() >= chunk_size {
1359                                drain_updates!(updates);
1360                            }
1361                        });
1362                        drain_updates!(updates);
1363                    }
1364                    for (spec, prober) in rest.iter().zip(probers.into_iter()) {
1365                        binding_info.move_back(spec.atom, prober);
1366                    }
1367                }
1368            },
1369            JoinStage::FusedIntersect {
1370                cover,
1371                bind,
1372                to_intersect,
1373                is_leaf_scan: true,
1374            } if to_intersect.is_empty() => {
1375                let cover_atom = cover.to_index.atom;
1376                if binding_info.has_empty_subset(cover_atom) {
1377                    return;
1378                }
1379                let table = self.db.tables[atoms[cover_atom].table].table.as_ref();
1380                let cover_node = binding_info.unwrap_val(cover_atom);
1381                let cover_subset = cover_node.subset.as_ref();
1382
1383                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1384                let vars = bind.iter().map(|(_, var)| *var).collect();
1385                let mut buf = TaggedRowBuffer::new_inline(bind.len());
1386                table.scan_project(
1387                    cover_subset,
1388                    &proj,
1389                    Offset::new(0),
1390                    usize::MAX,
1391                    &cover.constraints,
1392                    &mut buf,
1393                );
1394
1395                if buf.is_empty() {
1396                    return;
1397                }
1398
1399                binding_info.binding_sets.push((vars, Arc::new(buf)));
1400                let mut updates = FrameUpdates::with_capacity(1);
1401                updates.finish_frame();
1402                drain_updates!(updates);
1403                binding_info.binding_sets.pop();
1404                binding_info.move_back_node(cover_atom, cover_node);
1405            }
1406            JoinStage::FusedIntersect {
1407                cover,
1408                bind,
1409                to_intersect,
1410                is_leaf_scan: false,
1411            } if to_intersect.is_empty() => {
1412                let cover_atom = cover.to_index.atom;
1413                if binding_info.has_empty_subset(cover_atom) {
1414                    return;
1415                }
1416                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1417                let cover_node = binding_info.unwrap_val(cover_atom);
1418                let cover_subset = cover_node.subset.as_ref();
1419                let mut cur = Offset::new(0);
1420                let mut buffer = TaggedRowBuffer::new(bind.len());
1421                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1422                loop {
1423                    buffer.clear();
1424                    let table = &self.db.tables[atoms[cover_atom].table].table;
1425                    let next = table.scan_project(
1426                        cover_subset,
1427                        &proj,
1428                        cur,
1429                        chunk_size,
1430                        &cover.constraints,
1431                        &mut buffer,
1432                    );
1433                    for (row, key) in buffer.iter() {
1434                        updates.refine_atom_dense(cover_atom, OffsetRange::new(row, row.inc()));
1435                        // bind the values
1436                        for (i, (_, var)) in bind.iter().enumerate() {
1437                            updates.push_binding(*var, key[i]);
1438                        }
1439                        updates.finish_frame();
1440                        if updates.frames() >= chunk_size {
1441                            drain_updates!(updates);
1442                        }
1443                    }
1444                    if let Some(next) = next {
1445                        cur = next;
1446                        continue;
1447                    }
1448                    break;
1449                }
1450                drain_updates!(updates);
1451                // Restore the subsets we swapped out.
1452                binding_info.move_back_node(cover_atom, cover_node);
1453            }
1454            JoinStage::FusedIntersect {
1455                cover,
1456                bind,
1457                to_intersect,
1458                is_leaf_scan: _,
1459            } => {
1460                let cover_atom = cover.to_index.atom;
1461                if binding_info.has_empty_subset(cover_atom) {
1462                    return;
1463                }
1464                let index_probers = to_intersect
1465                    .iter()
1466                    .enumerate()
1467                    .map(|(i, (spec, _))| {
1468                        (
1469                            i,
1470                            spec.to_index.atom,
1471                            self.get_index(
1472                                atoms,
1473                                spec.to_index.atom,
1474                                binding_info,
1475                                spec.to_index.vars.iter().copied(),
1476                            ),
1477                        )
1478                    })
1479                    .collect::<SmallVec<[(usize, AtomId, Prober); 4]>>();
1480                // Pre-compute has_stale per prober to avoid vtable calls in the hot loop.
1481                let index_has_stale: SmallVec<[bool; 4]> = index_probers
1482                    .iter()
1483                    .map(|(_, atom, _)| {
1484                        self.db.tables[atoms[*atom].table]
1485                            .table
1486                            .as_ref()
1487                            .has_stale_rows()
1488                    })
1489                    .collect();
1490                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1491                let cover_node = binding_info.unwrap_val(cover_atom);
1492                let cover_subset = cover_node.subset.as_ref();
1493                let mut cur = Offset::new(0);
1494                let mut buffer = TaggedRowBuffer::new(bind.len());
1495                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1496                loop {
1497                    buffer.clear();
1498                    let table = &self.db.tables[atoms[cover_atom].table].table;
1499                    let next = table.scan_project(
1500                        cover_subset,
1501                        &proj,
1502                        cur,
1503                        chunk_size,
1504                        &cover.constraints,
1505                        &mut buffer,
1506                    );
1507                    'mid: for (row, key) in buffer.iter() {
1508                        updates.refine_atom_dense(cover_atom, OffsetRange::new(row, row.inc()));
1509                        // bind the values
1510                        for (i, (_, var)) in bind.iter().enumerate() {
1511                            updates.push_binding(*var, key[i]);
1512                        }
1513                        // now probe each of the remaining indexes
1514                        for (prober_idx, (i, atom, prober)) in index_probers.iter().enumerate() {
1515                            // create a key: to_intersect indexes into the key from the cover
1516                            let index_cols = &to_intersect[*i].1;
1517                            // Fast path for the common single-column case: avoid SmallVec collect.
1518                            let index_key_buf: SmallVec<[Value; 4]>;
1519                            let index_key: &[Value] = if let [col] = index_cols.as_slice() {
1520                                std::slice::from_ref(&key[col.index()])
1521                            } else {
1522                                index_key_buf =
1523                                    index_cols.iter().map(|col| key[col.index()]).collect();
1524                                &index_key_buf
1525                            };
1526                            let Some(subset) = prober.get_subset(index_key) else {
1527                                updates.rollback();
1528                                // There are no possible values for this subset
1529                                continue 'mid;
1530                            };
1531                            // apply any constraints needed in this scan.
1532                            let table_info = &self.db.tables[atoms[*atom].table];
1533                            let cs = &to_intersect[*i].0.constraints;
1534                            let subset = refine_subset(
1535                                subset.to_owned(pool),
1536                                cs,
1537                                &table_info.table.as_ref(),
1538                                index_has_stale[prober_idx],
1539                            );
1540                            if subset.is_empty() {
1541                                updates.rollback();
1542                                // There are no possible values for this subset
1543                                continue 'mid;
1544                            }
1545                            updates.refine_atom_subset(*atom, subset);
1546                        }
1547                        updates.finish_frame();
1548                        if updates.frames() >= chunk_size {
1549                            drain_updates!(updates);
1550                        }
1551                    }
1552                    if let Some(next) = next {
1553                        cur = next;
1554                        continue;
1555                    }
1556                    break;
1557                }
1558                // TODO: special-case the scenario when the cover doesn't need
1559                // deduping (and hence we can do a straight scan: e.g. when the
1560                // cover is binding a superset of the primary key for the
1561                // table).
1562                drain_updates!(updates);
1563                // Restore the subsets we swapped out.
1564                binding_info.move_back_node(cover_atom, cover_node);
1565                for (_, atom, prober) in index_probers {
1566                    binding_info.move_back(atom, prober);
1567                }
1568            }
1569            JoinStage::FusedIntersectMat {
1570                cover,
1571                mode,
1572                bind,
1573                to_intersect,
1574            } => {
1575                let cover_mat = binding_info.materializations[*cover].clone();
1576                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1577                let probers = to_intersect
1578                    .iter()
1579                    .map(|(spec, _)| {
1580                        self.get_index(
1581                            atoms,
1582                            spec.to_index.atom,
1583                            binding_info,
1584                            spec.to_index.vars.iter().copied(),
1585                        )
1586                    })
1587                    .collect::<SmallVec<[Prober; 4]>>();
1588                // Pre-compute has_stale per prober to avoid vtable calls in the hot loop.
1589                let probers_has_stale: SmallVec<[bool; 4]> = to_intersect
1590                    .iter()
1591                    .map(|(spec, _)| {
1592                        self.db.tables[atoms[spec.to_index.atom].table]
1593                            .table
1594                            .as_ref()
1595                            .has_stale_rows()
1596                    })
1597                    .collect();
1598
1599                let mut key = Vec::with_capacity(4);
1600                let mut prune_probers = |updates: &mut FrameUpdates,
1601                                         mat_key: Option<&[Value]>,
1602                                         mat_non_key: Option<&[Value]>|
1603                 -> bool {
1604                    for (j, ((spec, cols), prober)) in
1605                        to_intersect.iter().zip(probers.iter()).enumerate()
1606                    {
1607                        key.clear();
1608                        for col in cols.iter() {
1609                            let val = match mat_key {
1610                                Some(mat_key) => {
1611                                    if col.index() < mat_key.len() {
1612                                        mat_key[col.index()]
1613                                    } else {
1614                                        mat_non_key.unwrap()[col.index() - mat_key.len()]
1615                                    }
1616                                }
1617                                None => mat_non_key.unwrap()[col.index()],
1618                            };
1619                            key.push(val);
1620                        }
1621                        if let Some(subset) = prober.get_subset(&key) {
1622                            let subset = refine_subset(
1623                                subset.to_owned(pool),
1624                                &spec.constraints,
1625                                &self.db.tables[atoms[spec.to_index.atom].table]
1626                                    .table
1627                                    .as_ref(),
1628                                probers_has_stale[j],
1629                            );
1630                            if subset.is_empty() {
1631                                return false;
1632                            }
1633                            updates.refine_atom_subset(spec.to_index.atom, subset);
1634                        } else {
1635                            return false;
1636                        }
1637                    }
1638                    true
1639                };
1640
1641                match mode {
1642                    MatScanMode::Full | MatScanMode::KeyOnly => {
1643                        // enumerate keys
1644                        for group in cover_mat.iter() {
1645                            let group_key = group.0;
1646                            let group_val = group.1;
1647                            let group_key_len = group_key.len();
1648                            if mode == &MatScanMode::Full {
1649                                // enumerate non-keys
1650                                for non_keys in group_val.iter() {
1651                                    for (col, var) in bind.iter() {
1652                                        if col.index() < group_key_len {
1653                                            updates.push_binding(*var, group_key[col.index()]);
1654                                        }
1655                                    }
1656
1657                                    // TODO: optimization that guarantees all keys come before non-keys
1658                                    for (col, var) in bind.iter() {
1659                                        if col.index() >= group_key_len {
1660                                            updates.push_binding(
1661                                                *var,
1662                                                non_keys[col.index() - group_key_len],
1663                                            );
1664                                        }
1665                                    }
1666                                    if prune_probers(&mut updates, Some(group_key), Some(non_keys))
1667                                    {
1668                                        updates.finish_frame();
1669                                    } else {
1670                                        updates.rollback();
1671                                    }
1672                                }
1673                            } else if mode == &MatScanMode::KeyOnly {
1674                                for (col, var) in bind.iter() {
1675                                    debug_assert!(col.index() < group_key_len);
1676                                    updates.push_binding(*var, group_key[col.index()]);
1677                                }
1678                                if prune_probers(&mut updates, Some(group_key), None) {
1679                                    updates.finish_frame();
1680                                } else {
1681                                    updates.rollback();
1682                                }
1683                            }
1684                        }
1685                    }
1686                    MatScanMode::Value(index_vars) | MatScanMode::Lookup(index_vars) => {
1687                        let keys = index_vars
1688                            .iter()
1689                            .map(|var| binding_info.bindings[*var])
1690                            .collect::<Vec<Value>>();
1691                        // lookup keys
1692                        if let Some(group) = cover_mat.get(&keys) {
1693                            if matches!(mode, MatScanMode::Lookup(_)) {
1694                                debug_assert_eq!(to_intersect.len(), 0);
1695                                debug_assert_eq!(bind.len(), 0);
1696                                if group.len() > 0 {
1697                                    updates.finish_frame();
1698                                }
1699                                drain_updates!(updates);
1700                            } else {
1701                                // enumerate non-keys
1702                                // for vals in group.value().iter() {
1703                                for vals in group.iter() {
1704                                    debug_assert!(vals.len() == bind.len()); // TODO: not true for non-full query
1705                                    for (col, var) in bind.iter() {
1706                                        updates.push_binding(*var, vals[col.index()]);
1707                                    }
1708                                    if prune_probers(&mut updates, None, Some(vals)) {
1709                                        updates.finish_frame();
1710                                    } else {
1711                                        updates.rollback();
1712                                    }
1713                                    if updates.frames() >= chunk_size {
1714                                        drain_updates!(updates);
1715                                    }
1716                                }
1717                            }
1718                        }
1719                    }
1720                }
1721
1722                drain_updates!(updates);
1723                for (spec, prober) in to_intersect.iter().zip(probers) {
1724                    binding_info.move_back(spec.0.to_index.atom, prober);
1725                }
1726            }
1727        }
1728    }
1729}
1730
/// Number of binding rows an action buffer accumulates for a single action
/// before it runs that action's instructions on the accumulated batch.
const VAR_BATCH_SIZE: usize = 128;
1732
1733/// A trait used to abstract over different ways of buffering actions together
1734/// before running them.
1735///
1736/// This trait exists as a fairly ad-hoc wrapper over its two implementations.
1737/// It allows us to avoid duplicating the (somewhat monstrous) `run_plan` method
1738/// for serial and parallel modes.
1739trait ActionBuffer<'state, A: NumericId>: Send {
1740    type AsLocal<'a>: ActionBuffer<'state, A>
1741    where
1742        'state: 'a;
1743
1744    /// Expand the binding sets to individual bindings and
1745    /// call push_bindings
1746    fn push_bindings_factorized(
1747        &mut self,
1748        action: A,
1749        bindings: &mut DenseIdMap<Variable, Value>,
1750        binding_sets: &BindingSet,
1751        exec_state: &ExecutionState<'state>,
1752    ) {
1753        expand_binding_sets(self, action, bindings, binding_sets, 0, exec_state);
1754    }
1755
1756    /// Push the given bindings to be executed for the specified action. If this
1757    /// buffer has built up a sufficient batch size, it may execute
1758    /// `to_exec_state` and then execute the action.
1759    ///
1760    /// NB: `push_bindings` makes module-specific assumptions on what values are passed to
1761    /// `bindings` for a common `action`. This is not a general-purpose trait for that reason and
1762    /// it should not, in general, be used outside of this module.
1763    fn push_bindings(
1764        &mut self,
1765        action: A,
1766        bindings: &DenseIdMap<Variable, Value>,
1767        to_exec_state: impl FnMut() -> ExecutionState<'state>,
1768    );
1769
1770    /// Execute any remaining actions associated with this buffer.
1771    fn flush(&mut self, exec_state: &mut ExecutionState);
1772
1773    /// Execute `work`, potentially asynchronously, with a mutable reference to
1774    /// an action buffer, potentially handed off to a different thread.
1775    ///
1776    /// Callers [`BorrowedLocalState`] values that may be modified by work, or
1777    /// cloned first and then have a separate copy modified by `work`. Callers
1778    /// should assume that `local` _is_ modified synchronously.
1779    // NB: Earlier versions of this method had BorrowedLocalState be a generic instead, but this
1780    // ran into difficulties when we needed to pass multiple mutable references.
1781    fn recur<'local>(
1782        &mut self,
1783        local: BorrowedLocalState<'local>,
1784        to_exec_state: impl FnMut() -> ExecutionState<'state> + Send + 'state,
1785        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut Self::AsLocal<'a>) + Send + 'state,
1786    );
1787
1788    /// The unit at which you should batch updates passed to calls to `recur`,
1789    /// potentially depending on the current level of recursion.
1790    ///
1791    /// As of right now this is just a hard-coded value. We may change it in the
1792    /// future to fan out more at higher levels though.
1793    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
1794        256
1795    }
1796
1797    /// Whether this buffer supports parallel drain operations.
1798    ///
1799    /// When `false`, `drain_updates` will use the serial path even at `cur <= 1`,
1800    /// avoiding the per-frame `ExecutionState::clone()` overhead.
1801    fn supports_parallel_drain(&self) -> bool {
1802        true
1803    }
1804}
1805
1806/// The action buffer we use if we are executing in a single-threaded
1807/// environment. It builds up local batches and then flushes them inline.
1808struct InPlaceActionBuffer<'a> {
1809    rule_set: &'a RuleSet,
1810    match_counter: &'a MatchCounter,
1811    batches: DenseIdMap<ActionId, ActionState>,
1812}
1813
1814impl<'a, 'outer: 'a> ActionBuffer<'a, ActionId> for InPlaceActionBuffer<'outer> {
1815    type AsLocal<'b>
1816        = Self
1817    where
1818        'a: 'b;
1819
1820    fn push_bindings(
1821        &mut self,
1822        action: ActionId,
1823        bindings: &DenseIdMap<Variable, Value>,
1824        mut to_exec_state: impl FnMut() -> ExecutionState<'a>,
1825    ) {
1826        let action_state = self.batches.get_or_default(action);
1827        action_state.n_runs += 1;
1828        action_state.len += 1;
1829        let action_info = &self.rule_set.actions[action];
1830        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
1831        // `bindings` produced by the same join.
1832        unsafe {
1833            action_state.bindings.push(bindings, &action_info.used_vars);
1834        }
1835        if action_state.len >= VAR_BATCH_SIZE {
1836            let mut state = to_exec_state();
1837            let succeeded = state.run_instrs(&action_info.instrs, &mut action_state.bindings);
1838            action_state.bindings.clear();
1839            self.match_counter.inc_matches(action, succeeded);
1840            action_state.len = 0;
1841        }
1842    }
1843
1844    fn flush(&mut self, exec_state: &mut ExecutionState) {
1845        flush_action_states(
1846            exec_state,
1847            &mut self.batches,
1848            self.rule_set,
1849            self.match_counter,
1850        );
1851    }
1852
1853    fn recur<'local>(
1854        &mut self,
1855        local: BorrowedLocalState<'local>,
1856        _to_exec_state: impl FnMut() -> ExecutionState<'a> + Send + 'a,
1857        work: impl for<'b> FnOnce(BorrowedLocalState<'b>, &mut Self) + Send + 'a,
1858    ) {
1859        work(local, self)
1860    }
1861
1862    fn supports_parallel_drain(&self) -> bool {
1863        false
1864    }
1865}
1866
1867/// An Action buffer that hands off batches to of actions to rayon to execute.
1868struct ScopedActionBuffer<'inner, 'scope> {
1869    scope: &'inner rayon::Scope<'scope>,
1870    rule_set: &'scope RuleSet,
1871    match_counter: Arc<MatchCounter>,
1872    batches: DenseIdMap<ActionId, ActionState>,
1873    needs_flush: bool,
1874}
1875
1876impl<'inner, 'scope> ScopedActionBuffer<'inner, 'scope> {
1877    fn new(
1878        scope: &'inner rayon::Scope<'scope>,
1879        rule_set: &'scope RuleSet,
1880        match_counter: Arc<MatchCounter>,
1881    ) -> Self {
1882        Self {
1883            scope,
1884            rule_set,
1885            batches: Default::default(),
1886            match_counter,
1887            needs_flush: false,
1888        }
1889    }
1890}
1891
1892impl<'scope> ActionBuffer<'scope, ActionId> for ScopedActionBuffer<'_, 'scope> {
1893    type AsLocal<'a>
1894        = ScopedActionBuffer<'a, 'scope>
1895    where
1896        'scope: 'a;
1897    fn push_bindings(
1898        &mut self,
1899        action: ActionId,
1900        bindings: &DenseIdMap<Variable, Value>,
1901        mut to_exec_state: impl FnMut() -> ExecutionState<'scope>,
1902    ) {
1903        self.needs_flush = true;
1904        let action_state = self.batches.get_or_default(action);
1905        action_state.n_runs += 1;
1906        action_state.len += 1;
1907        let action_info = &self.rule_set.actions[action];
1908        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
1909        // `bindings` produced by the same join.
1910        unsafe {
1911            action_state.bindings.push(bindings, &action_info.used_vars);
1912        }
1913        if action_state.len >= VAR_BATCH_SIZE {
1914            let mut state = to_exec_state();
1915            let mut bindings =
1916                mem::replace(&mut action_state.bindings, Bindings::new(VAR_BATCH_SIZE));
1917            action_state.len = 0;
1918            let match_counter = self.match_counter.clone();
1919            self.scope.spawn(move |_| {
1920                let succeeded = state.run_instrs(&action_info.instrs, &mut bindings);
1921                match_counter.inc_matches(action, succeeded);
1922            });
1923        }
1924    }
1925
1926    fn flush(&mut self, exec_state: &mut ExecutionState) {
1927        flush_action_states(
1928            exec_state,
1929            &mut self.batches,
1930            self.rule_set,
1931            self.match_counter.as_ref(),
1932        );
1933        self.needs_flush = false;
1934    }
1935    fn recur<'local>(
1936        &mut self,
1937        mut local: BorrowedLocalState<'local>,
1938        mut to_exec_state: impl FnMut() -> ExecutionState<'scope> + Send + 'scope,
1939        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut ScopedActionBuffer<'a, 'scope>)
1940        + Send
1941        + 'scope,
1942    ) {
1943        let rule_set = self.rule_set;
1944        let match_counter = self.match_counter.clone();
1945        let mut inner = local.clone_state();
1946        self.scope.spawn(move |scope| {
1947            let mut buf: ScopedActionBuffer<'_, 'scope> = ScopedActionBuffer {
1948                scope,
1949                rule_set,
1950                match_counter,
1951                needs_flush: false,
1952                batches: Default::default(),
1953            };
1954            work(inner.borrow_mut(), &mut buf);
1955            if buf.needs_flush {
1956                flush_action_states(
1957                    &mut to_exec_state(),
1958                    &mut buf.batches,
1959                    buf.rule_set,
1960                    buf.match_counter.as_ref(),
1961                );
1962            }
1963        });
1964    }
1965
1966    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
1967        // Lower morsel size to increase parallelism.
1968        match _level {
1969            0 if _total > 2 => 32,
1970            _ => 256,
1971        }
1972    }
1973}
1974
1975fn expand_binding_sets<'state, A: NumericId, BUF: ActionBuffer<'state, A> + ?Sized>(
1976    action_buf: &mut BUF,
1977    action: A,
1978    bindings: &mut DenseIdMap<Variable, Value>,
1979    binding_sets: &BindingSet,
1980    idx: usize,
1981    exec_state: &ExecutionState<'state>,
1982) {
1983    if exec_state.should_stop() {
1984        return;
1985    }
1986    if idx >= binding_sets.len() {
1987        action_buf.push_bindings(action, bindings, || exec_state.clone());
1988        return;
1989    }
1990    if idx + 1 == binding_sets.len() {
1991        let (vars, buf) = &binding_sets[idx];
1992        for (_, row) in buf.iter() {
1993            if exec_state.should_stop() {
1994                return;
1995            }
1996            for (var, val) in vars.iter().zip(row.iter()) {
1997                bindings.insert(*var, *val);
1998            }
1999            action_buf.push_bindings(action, bindings, || exec_state.clone());
2000        }
2001        return;
2002    }
2003    let (vars, buf) = &binding_sets[idx];
2004    for (_, row) in buf.iter() {
2005        for (var, val) in vars.iter().zip(row.iter()) {
2006            bindings.insert(*var, *val);
2007        }
2008        expand_binding_sets(
2009            action_buf,
2010            action,
2011            bindings,
2012            binding_sets,
2013            idx + 1,
2014            exec_state,
2015        );
2016    }
2017}
2018
2019fn flush_action_states(
2020    exec_state: &mut ExecutionState,
2021    actions: &mut DenseIdMap<ActionId, ActionState>,
2022    rule_set: &RuleSet,
2023    match_counter: &MatchCounter,
2024) {
2025    for (action, ActionState { bindings, len, .. }) in actions.iter_mut() {
2026        if *len > 0 {
2027            let succeeded = exec_state.run_instrs(&rule_set.actions[action].instrs, bindings);
2028            bindings.clear();
2029            match_counter.inc_matches(action, succeeded);
2030            *len = 0;
2031        }
2032    }
2033}
2034
2035struct InPlaceMaterializer<'a> {
2036    specs: &'a DenseIdMap<MatId, MatSpec>,
2037    materializations: DenseIdMap<MatId, IndexMap<Vec<Value>, RowBuffer>>,
2038    scratch_key: Vec<Value>,
2039    scratch_val: Vec<Value>,
2040}
2041
2042impl<'a> ActionBuffer<'a, MatId> for InPlaceMaterializer<'a> {
2043    type AsLocal<'b>
2044        = Self
2045    where
2046        'a: 'b;
2047
2048    fn push_bindings(
2049        &mut self,
2050        mat_id: MatId,
2051        bindings: &DenseIdMap<Variable, Value>,
2052        _to_exec_state: impl FnMut() -> ExecutionState<'a>,
2053    ) {
2054        let mat = self
2055            .materializations
2056            .get_mut(mat_id)
2057            .expect("invalid mat id");
2058        let spec = self.specs.get(mat_id).expect("invalid mat id");
2059        self.scratch_key.clear();
2060        for key in spec.msg_vars.iter().map(|var| bindings[*var]) {
2061            self.scratch_key.push(key);
2062        }
2063        self.scratch_val.clear();
2064        for val in spec.val_vars.iter().map(|var| bindings[*var]) {
2065            self.scratch_val.push(val);
2066        }
2067        if self.scratch_val.is_empty() {
2068            self.scratch_val.push(Value::stale());
2069        }
2070        if let Some(buffer) = mat.get_mut(&self.scratch_key) {
2071            buffer.add_row(&self.scratch_val);
2072        } else {
2073            let mut buffer = RowBuffer::new(usize::max(spec.val_vars.len(), 1));
2074            buffer.add_row(&self.scratch_val);
2075            mat.insert(self.scratch_key.clone(), buffer);
2076        }
2077    }
2078
2079    fn flush(&mut self, _exec_state: &mut ExecutionState) {
2080        // No-op for in-place materializer.
2081    }
2082
2083    fn recur<'local>(
2084        &mut self,
2085        local: BorrowedLocalState<'local>,
2086        _to_exec_state: impl FnMut() -> ExecutionState<'a> + Send + 'a,
2087        work: impl for<'b> FnOnce(BorrowedLocalState<'b>, &mut Self) + Send + 'a,
2088    ) {
2089        work(local, self)
2090    }
2091
2092    fn supports_parallel_drain(&self) -> bool {
2093        false
2094    }
2095}
2096
2097struct ScopedMaterializer<'inner, 'scope> {
2098    scope: &'inner rayon::Scope<'scope>,
2099    specs: Arc<DenseIdMap<MatId, MatSpec>>,
2100    materializations: Arc<DenseIdMap<MatId, Arc<DashMap<Vec<Value>, RowBuffer>>>>,
2101    scratch_key: Vec<Value>,
2102    scratch_val: Vec<Value>,
2103}
2104impl<'scope> ActionBuffer<'scope, MatId> for ScopedMaterializer<'_, 'scope> {
2105    type AsLocal<'a>
2106        = ScopedMaterializer<'a, 'scope>
2107    where
2108        'scope: 'a;
2109
2110    fn push_bindings(
2111        &mut self,
2112        mat_id: MatId,
2113        bindings: &DenseIdMap<Variable, Value>,
2114        _to_exec_state: impl FnMut() -> ExecutionState<'scope>,
2115    ) {
2116        let mat = self.materializations.get(mat_id).expect("invalid mat id");
2117        let spec = self.specs.get(mat_id).expect("invalid mat id");
2118        self.scratch_key.clear();
2119        for key in spec.msg_vars.iter().map(|var| bindings[*var]) {
2120            self.scratch_key.push(key);
2121        }
2122        self.scratch_val.clear();
2123        for val in spec.val_vars.iter().map(|var| bindings[*var]) {
2124            self.scratch_val.push(val);
2125        }
2126        if self.scratch_val.is_empty() {
2127            self.scratch_val.push(Value::stale());
2128        }
2129        let key = self.scratch_key.clone();
2130        match mat.entry(key) {
2131            Entry::Occupied(mut occ) => {
2132                occ.get_mut().add_row(&self.scratch_val);
2133            }
2134            Entry::Vacant(vac) => {
2135                let mut buffer = RowBuffer::new(usize::max(spec.val_vars.len(), 1));
2136                buffer.add_row(&self.scratch_val);
2137                vac.insert(buffer);
2138            }
2139        }
2140    }
2141
2142    fn flush(&mut self, _exec_state: &mut ExecutionState) {
2143        // No-op for scoped materializer since we always write to the materialization in-place.
2144    }
2145
2146    fn recur<'local>(
2147        &mut self,
2148        mut local: BorrowedLocalState<'local>,
2149        _to_exec_state: impl FnMut() -> ExecutionState<'scope> + Send + 'scope,
2150        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut ScopedMaterializer<'a, 'scope>)
2151        + Send
2152        + 'scope,
2153    ) {
2154        let scope = self.scope;
2155        let specs = self.specs.clone();
2156        let materializations = self.materializations.clone();
2157        let mut inner = local.clone_state();
2158        scope.spawn(move |scope| {
2159            let mut buf: ScopedMaterializer<'_, 'scope> = ScopedMaterializer {
2160                scope,
2161                specs,
2162                materializations: materializations.clone(),
2163                scratch_key: Vec::new(),
2164                scratch_val: Vec::new(),
2165            };
2166            work(inner.borrow_mut(), &mut buf);
2167        });
2168    }
2169}
2170
2171struct MatchCounter {
2172    matches: IdVec<ActionId, CachePadded<AtomicUsize>>,
2173}
2174
2175impl MatchCounter {
2176    fn new(n_ids: usize) -> Self {
2177        let mut matches = IdVec::with_capacity(n_ids);
2178        matches.resize_with(n_ids, || CachePadded::new(AtomicUsize::new(0)));
2179        Self { matches }
2180    }
2181
2182    fn inc_matches(&self, action: ActionId, by: usize) {
2183        self.matches[action].fetch_add(by, std::sync::atomic::Ordering::Relaxed);
2184    }
2185    fn read_matches(&self, action: ActionId) -> usize {
2186        self.matches[action].load(std::sync::atomic::Ordering::Acquire)
2187    }
2188}
2189
2190fn estimate_size(join_stage: &JoinStage, binding_info: &BindingInfo) -> usize {
2191    match join_stage {
2192        JoinStage::Intersect { scans, .. } => scans
2193            .iter()
2194            .map(|scan| binding_info.subsets[scan.atom].size())
2195            .min()
2196            .unwrap_or(0),
2197        JoinStage::FusedIntersect { cover, .. } => binding_info.subsets[cover.to_index.atom].size(),
2198        JoinStage::FusedIntersectMat { cover, .. } => binding_info.materializations[*cover].len(), // TODO: len() might be expensive.
2199    }
2200}
2201
2202fn num_intersected_rels(join_stage: &JoinStage) -> i32 {
2203    match join_stage {
2204        JoinStage::Intersect { scans, .. } => scans.len() as i32,
2205        JoinStage::FusedIntersect { to_intersect, .. } => to_intersect.len() as i32 + 1,
2206        JoinStage::FusedIntersectMat { to_intersect, .. } => to_intersect.len() as i32,
2207    }
2208}
2209
2210fn sort_plan_by_size(
2211    order: &mut InstrOrder,
2212    start: usize,
2213    instrs: &[JoinStage],
2214    binding_info: &mut BindingInfo,
2215) {
2216    let mut last_pos = start;
2217    for i in start..instrs.len() {
2218        if matches!(
2219            &instrs[i],
2220            // These nodes don't commute
2221            JoinStage::FusedIntersectMat {
2222                mode: MatScanMode::Lookup(_) | MatScanMode::Value(_) | MatScanMode::Full,
2223                ..
2224            }
2225        ) {
2226            sort_plan_by_size_inner(order, last_pos..i, instrs, binding_info);
2227            last_pos = i + 1;
2228        }
2229    }
2230    sort_plan_by_size_inner(order, last_pos..instrs.len(), instrs, binding_info);
2231}
2232
2233fn sort_plan_by_size_inner(
2234    order: &mut InstrOrder,
2235    range: Range<usize>,
2236    instrs: &[JoinStage],
2237    binding_info: &mut BindingInfo,
2238) {
2239    // Nothing to sort if there's 0 or 1 element.
2240    if range.len() <= 1 {
2241        return;
2242    }
2243    // How many times an atom has been intersected/joined
2244    let mut times_refined = with_pool_set(|ps| ps.get::<DenseIdMap<AtomId, i64>>());
2245
2246    // Count how many times each atom has been refined so far.
2247    for ins in instrs[..range.start].iter() {
2248        match ins {
2249            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
2250                *times_refined.get_or_default(scan.atom) += 1;
2251            }),
2252            JoinStage::FusedIntersect {
2253                cover,
2254                to_intersect,
2255                ..
2256            } => {
2257                *times_refined.get_or_default(cover.to_index.atom) +=
2258                    cover.to_index.vars.len() as i64;
2259                to_intersect.iter().for_each(|(spec, _)| {
2260                    *times_refined.get_or_default(spec.to_index.atom) +=
2261                        spec.to_index.vars.len() as i64;
2262                });
2263            }
2264            JoinStage::FusedIntersectMat { to_intersect, .. } => {
2265                to_intersect.iter().for_each(|(spec, _)| {
2266                    *times_refined.get_or_default(spec.to_index.atom) +=
2267                        spec.to_index.vars.len() as i64;
2268                });
2269            }
2270        }
2271    }
2272
2273    // We prioritize variables by
2274    //
2275    //   (1) how many times an atom with this variable has been refined,
2276    //   (2) then by the cardinality of the variable to be enumerated (smaller → earlier)
2277    //   (3) then by how many relations join on this variable (more → earlier)
2278    //
2279    // Estimate size is second so that stages with very small cardinality (e.g. FunDep
2280    // consequents with exactly 1 value) are run before multi-relation stages that happen
2281    // to have a larger current estimate.
2282    let key_fn = |join_stage: &JoinStage,
2283                  binding_info: &BindingInfo,
2284                  times_refined: &DenseIdMap<AtomId, i64>| {
2285        let refine = match join_stage {
2286            JoinStage::Intersect { scans, .. } => scans
2287                .iter()
2288                .map(|scan| times_refined.get(scan.atom).copied().unwrap_or_default())
2289                .max()
2290                .unwrap(),
2291            JoinStage::FusedIntersect { cover, .. } => times_refined
2292                .get(cover.to_index.atom)
2293                .copied()
2294                .unwrap_or_default(),
2295            JoinStage::FusedIntersectMat { bind, .. } => bind.len() as _,
2296        };
2297        (
2298            -refine,
2299            estimate_size(join_stage, binding_info),
2300            -num_intersected_rels(join_stage),
2301        )
2302    };
2303
2304    for i in range.clone() {
2305        let mut key_i = key_fn(&instrs[order.get(i)], binding_info, &times_refined);
2306        for j in (i + 1)..range.end {
2307            let key_j = key_fn(&instrs[order.get(j)], binding_info, &times_refined);
2308            if key_j < key_i {
2309                order.data.swap(i, j);
2310                key_i = key_j;
2311            }
2312        }
2313        // Update the counts after a new instruction is selected.
2314        match &instrs[order.get(i)] {
2315            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
2316                *times_refined.get_or_default(scan.atom) += 1;
2317            }),
2318            JoinStage::FusedIntersect {
2319                cover,
2320                to_intersect,
2321                ..
2322            } => {
2323                *times_refined.get_or_default(cover.to_index.atom) +=
2324                    cover.to_index.vars.len() as i64;
2325
2326                to_intersect.iter().for_each(|(spec, _)| {
2327                    *times_refined.get_or_default(spec.to_index.atom) +=
2328                        spec.to_index.vars.len() as i64;
2329                });
2330            }
2331            JoinStage::FusedIntersectMat { to_intersect, .. } => {
2332                to_intersect.iter().for_each(|(spec, _)| {
2333                    *times_refined.get_or_default(spec.to_index.atom) +=
2334                        spec.to_index.vars.len() as i64;
2335                });
2336            }
2337        }
2338    }
2339}
2340
2341#[derive(Debug, Clone, PartialEq, Eq)]
2342struct InstrOrder {
2343    data: SmallVec<[u16; 8]>,
2344}
2345
2346impl InstrOrder {
2347    fn new() -> Self {
2348        InstrOrder {
2349            data: SmallVec::new(),
2350        }
2351    }
2352
2353    fn from_iter(range: impl Iterator<Item = usize>) -> InstrOrder {
2354        let mut res = InstrOrder::new();
2355        res.data
2356            .extend(range.map(|x| u16::try_from(x).expect("too many instructions")));
2357        res
2358    }
2359
2360    fn get(&self, idx: usize) -> usize {
2361        self.data[idx] as usize
2362    }
2363    fn len(&self) -> usize {
2364        self.data.len()
2365    }
2366}
2367
2368struct BorrowedLocalState<'a> {
2369    instr_order: &'a mut InstrOrder,
2370    binding_info: &'a mut BindingInfo,
2371    updates: &'a mut FrameUpdates,
2372}
2373
2374impl BorrowedLocalState<'_> {
2375    fn clone_state(&mut self) -> LocalState {
2376        LocalState {
2377            instr_order: self.instr_order.clone(),
2378            binding_info: self.binding_info.clone(),
2379            updates: std::mem::take(self.updates),
2380        }
2381    }
2382}
2383
2384struct LocalState {
2385    instr_order: InstrOrder,
2386    binding_info: BindingInfo,
2387    updates: FrameUpdates,
2388}
2389
2390impl LocalState {
2391    fn borrow_mut<'a>(&'a mut self) -> BorrowedLocalState<'a> {
2392        BorrowedLocalState {
2393            instr_order: &mut self.instr_order,
2394            binding_info: &mut self.binding_info,
2395            updates: &mut self.updates,
2396        }
2397    }
2398}