egglog_core_relations/free_join/
execute.rs

1//! Core free join execution.
2
3use std::{
4    cmp, iter, mem,
5    ops::Range,
6    sync::{Arc, OnceLock, RwLock, atomic::AtomicUsize},
7};
8
9use crate::{
10    common::{HashMap, IndexMap},
11    free_join::plan::{JoinStages, MatId, MatScanMode, MatSpec},
12    numeric_id::{DenseIdMap, IdVec, NumericId},
13    query::Atom,
14    row_buffer::{RowBuffer, SmallValueVec},
15};
16use crossbeam::utils::CachePadded;
17use dashmap::mapref::entry::Entry;
18use dashmap::mapref::one::RefMut;
19use egglog_reports::{ReportLevel, RuleReport, RuleSetReport};
20use smallvec::SmallVec;
21use web_time::Instant;
22
23use crate::{
24    Constraint, OffsetRange, Pool, SubsetRef,
25    action::{Bindings, ExecutionState},
26    common::{DashMap, Value},
27    free_join::{
28        frame_update::{FrameUpdates, UpdateInstr},
29        get_index_from_tableinfo,
30    },
31    hash_index::{ColumnIndex, IndexBase, TupleIndex},
32    offsets::{Offsets, RowId, SortedOffsetSlice, SortedOffsetVector, Subset},
33    parallel_heuristics::parallelize_db_level_op,
34    pool::Pooled,
35    query::RuleSet,
36    row_buffer::TaggedRowBuffer,
37    table_spec::{ColumnId, Offset, WrappedTableRef},
38};
39
40use super::{
41    ActionId, AtomId, Database, HashColumnIndex, HashIndex, TableInfo, Variable,
42    get_column_index_from_tableinfo,
43    plan::{JoinHeader, JoinStage, Plan},
44    with_pool_set,
45};
46
/// Capacity of the fixed-size inline arrays backing [`SparseColumnIndex`]: the maximum
/// number of rows (and therefore keys) such an index can hold.
const SMALL_RESIDUAL: usize = 8;
48
/// A small, fixed-capacity index over one column of a subset, stored entirely inline
/// (no heap allocation). Rows are grouped by column value; it holds at most
/// `SMALL_RESIDUAL` rows.
struct SparseColumnIndex {
    /// Number of distinct keys; only `keys[..n_keys]` and `offsets[..n_keys]` are meaningful.
    n_keys: usize,
    /// Number of rows stored; only `subset_ids[..n_subsets]` is meaningful.
    n_subsets: usize,
    /// Distinct column values, in ascending order (so they can be binary-searched).
    keys: [Value; SMALL_RESIDUAL],
    /// `offsets[i]` is the index in `subset_ids` where the rows for `keys[i]` begin; they
    /// end at `offsets[i + 1]`, or at `n_subsets` for the last key.
    offsets: [usize; SMALL_RESIDUAL],
    /// Row ids grouped by key, in ascending order within each key's group.
    subset_ids: [RowId; SMALL_RESIDUAL],
}
56
57/// Return a SubsetRef for the given range of rows in a SparseColumnIndex.
58/// Single-row ranges become Dense to skip pool allocation in to_owned.
59///
60/// # Safety
61/// `ids[range]` must be sorted in non-decreasing order. The wider `ids` slice
62/// need not be sorted as a whole; only the indicated sub-range. This is the
63/// invariant of `SortedOffsetSlice::new_unchecked`.
64#[inline]
65unsafe fn sparse_subset_ref(ids: &[RowId], range: Range<usize>) -> SubsetRef<'_> {
66    if range.len() == 1 {
67        let row = ids[range.start];
68        SubsetRef::Dense(OffsetRange::new(row, row.inc()))
69    } else {
70        // SAFETY: caller guarantees `ids[range]` is sorted.
71        SubsetRef::Sparse(unsafe { SortedOffsetSlice::new_unchecked(&ids[range]) })
72    }
73}
74
75impl SparseColumnIndex {
76    fn keys(&self) -> &[Value] {
77        &self.keys[..self.n_keys]
78    }
79
80    fn get_offset_for(&self, i: usize) -> Range<usize> {
81        let lo = self.offsets[i];
82        let hi = if i + 1 < self.n_keys {
83            self.offsets[i + 1]
84        } else {
85            self.n_subsets
86        };
87        lo..hi
88    }
89
90    fn new(table: WrappedTableRef<'_>, subset: SubsetRef<'_>, col: ColumnId) -> Self {
91        let mut rows = [(Value::new_const(0), RowId::new_const(0)); SMALL_RESIDUAL];
92        let mut pos = 0;
93        table.for_each_col(subset, col, &mut |row_id, val| {
94            rows[pos] = (val, row_id);
95            pos += 1;
96        });
97        let n_subsets = pos;
98
99        rows[..pos].sort_unstable();
100
101        let mut n_keys = 0;
102        let mut keys = [Value::new_const(0); SMALL_RESIDUAL];
103        let mut offsets = [0; SMALL_RESIDUAL];
104        let mut subset_ids = [RowId::new_const(0); SMALL_RESIDUAL];
105        offsets[0] = 0;
106
107        for (i, &(key, row_id)) in rows[..n_subsets].iter().enumerate() {
108            let is_new_key = n_keys == 0 || keys[n_keys - 1] != key;
109            if is_new_key {
110                offsets[n_keys] = i;
111                keys[n_keys] = key;
112                n_keys += 1;
113            }
114            subset_ids[i] = row_id;
115        }
116
117        SparseColumnIndex {
118            n_keys,
119            n_subsets,
120            keys,
121            offsets,
122            subset_ids,
123        }
124    }
125
126    fn get_subset(&self, key: Value) -> Option<SubsetRef<'_>> {
127        if self.n_keys == 0 {
128            return None;
129        }
130        let found = self.keys().binary_search(&key).ok()?;
131        let range = self.get_offset_for(found);
132        // SAFETY: `subset_ids` was populated from rows sorted by (Value, RowId),
133        // so RowIds within any single per-key range (as returned by
134        // `get_offset_for`) are in non-decreasing order.
135        Some(unsafe { sparse_subset_ref(&self.subset_ids, range) })
136    }
137
138    fn for_each(&self, mut f: impl FnMut(&[Value], SubsetRef)) {
139        if self.n_keys == 0 {
140            return;
141        }
142        for i in 0..self.n_keys {
143            let range = self.get_offset_for(i);
144            // SAFETY: see `get_subset` — each per-key range of `subset_ids` is sorted.
145            let subset = unsafe { sparse_subset_ref(&self.subset_ids, range) };
146            f(&self.keys[i..i + 1], subset);
147        }
148    }
149
150    fn len(&self) -> usize {
151        self.n_keys
152    }
153}
154
/// The index a [`Prober`] uses to map key values to subsets of rows.
///
/// `Prober::get_subset` and `Prober::for_each` mark subsets from the two `Cached*`
/// variants as potentially stale and subsets from the remaining variants as not stale.
enum DynamicIndex {
    /// A multi-column index handle of type `HashIndex`, probed with the full key slice.
    Cached {
        /// When Some(range), intersect each subset from the index with this dense range.
        /// The range is the Dense outer subset known at Prober construction time.
        intersect_outer: Option<OffsetRange>,
        table: HashIndex,
    },
    /// A single-column index handle of type `HashColumnIndex`, probed with a
    /// one-element key.
    CachedColumn {
        /// When Some(range), intersect each subset from the index with this dense range.
        /// The range is the Dense outer subset known at Prober construction time.
        intersect_outer: Option<OffsetRange>,
        table: HashColumnIndex,
    },
    /// A multi-column `TupleIndex` owned by this prober.
    Dynamic(TupleIndex),
    /// A shared single-column `ColumnIndex`.
    DynamicColumn(Arc<ColumnIndex>),
    /// A small inline single-column index (see `SparseColumnIndex`).
    SparseColumn(SparseColumnIndex),
}
172
/// Wrapper that records whether a subset may contain stale entries.
/// Whether a subset can be stale depends on the type of index it came from.
/// Indices that come from a table may contain stale entries, while
/// those that are built on the fly will not.
struct PotentiallyStale<T> {
    inner: T,
    can_be_stale: bool,
}

impl<T> PotentiallyStale<T> {
    /// Shared constructor behind the two named entry points below.
    fn tagged(inner: T, can_be_stale: bool) -> Self {
        Self {
            inner,
            can_be_stale,
        }
    }

    /// Wrap a value that may contain stale entries.
    fn maybe_stale(inner: T) -> Self {
        Self::tagged(inner, true)
    }

    /// Wrap a value that is known not to contain stale entries.
    fn not_stale(inner: T) -> Self {
        Self::tagged(inner, false)
    }
}
197
198impl PotentiallyStale<SubsetRef<'_>> {
199    fn size(&self) -> usize {
200        self.inner.size()
201    }
202
203    fn to_owned(&self, pool: &Pool<SortedOffsetVector>) -> PotentiallyStale<Subset> {
204        PotentiallyStale {
205            inner: self.inner.to_owned(pool),
206            can_be_stale: self.can_be_stale,
207        }
208    }
209}
210
211/// Intersect a `SubsetRef` with a dense `OffsetRange` and return the result as a
212/// borrowed `SubsetRef`, or `None` if the intersection is empty.
213///
214/// This function never allocates — it borrows into
215/// the source data via `subslice`. Use this in `for_each` paths where the result
216/// may be discarded (e.g., empty after refinement), to avoid pool allocations.
217#[inline]
218fn intersect_with_dense_ref<'a>(v: SubsetRef<'a>, range: OffsetRange) -> Option<SubsetRef<'a>> {
219    match v {
220        SubsetRef::Dense(r) => {
221            let resl = cmp::max(r.start, range.start);
222            let resr = cmp::min(r.end, range.end);
223            if resl >= resr {
224                None
225            } else {
226                Some(SubsetRef::Dense(OffsetRange::new(resl, resr)))
227            }
228        }
229        SubsetRef::Sparse(s) => {
230            let l = s.binary_search_by_id(range.start);
231            let r = s.binary_search_by_id(range.end);
232            if l >= r {
233                None
234            } else {
235                Some(SubsetRef::Sparse(s.subslice(l, r)))
236            }
237        }
238    }
239}
240
/// Pairs a trie node with the index used to look up (or enumerate) subsets of its rows
/// by key.
struct Prober {
    /// The trie node this prober is associated with.
    // NOTE(review): none of the `Prober` methods visible here read `node`; presumably it
    // is used by callers or keeps the node alive — confirm.
    node: Arc<TrieNode>,
    /// The index that `get_subset`, `for_each` and `len` dispatch on.
    ix: DynamicIndex,
}
245
246impl Prober {
247    fn get_subset<'a>(&'a self, key: &'a [Value]) -> Option<PotentiallyStale<SubsetRef<'a>>> {
248        match &self.ix {
249            DynamicIndex::Cached {
250                intersect_outer,
251                table,
252            } => {
253                let subset_ref = table.get().unwrap().get_subset(key)?;
254                let subset = if let Some(range) = intersect_outer {
255                    intersect_with_dense_ref(subset_ref, *range)?
256                } else {
257                    subset_ref
258                };
259                Some(PotentiallyStale::maybe_stale(subset))
260            }
261            DynamicIndex::CachedColumn {
262                intersect_outer,
263                table,
264            } => {
265                debug_assert_eq!(key.len(), 1);
266                let subset_ref = table.get().unwrap().get_subset(&key[0])?;
267                let subset = if let Some(range) = intersect_outer {
268                    intersect_with_dense_ref(subset_ref, *range)?
269                } else {
270                    subset_ref
271                };
272                Some(PotentiallyStale::maybe_stale(subset))
273            }
274            DynamicIndex::Dynamic(tab) => tab.get_subset(key).map(PotentiallyStale::not_stale),
275            DynamicIndex::DynamicColumn(tab) => {
276                tab.get_subset(&key[0]).map(PotentiallyStale::not_stale)
277            }
278            DynamicIndex::SparseColumn(tab) => {
279                debug_assert_eq!(key.len(), 1);
280                tab.get_subset(key[0]).map(PotentiallyStale::not_stale)
281            }
282        }
283    }
284    fn for_each(&self, mut f: impl FnMut(&[Value], PotentiallyStale<SubsetRef>)) {
285        match &self.ix {
286            DynamicIndex::Cached {
287                intersect_outer: Some(range),
288                table,
289            } => {
290                let range = *range;
291                table.get().unwrap().for_each(|k, v| {
292                    if let Some(res) = intersect_with_dense_ref(v, range) {
293                        f(k, PotentiallyStale::maybe_stale(res))
294                    }
295                });
296            }
297            DynamicIndex::Cached {
298                intersect_outer: None,
299                table,
300            } => table
301                .get()
302                .unwrap()
303                .for_each(|k, v| f(k, PotentiallyStale::maybe_stale(v))),
304            DynamicIndex::CachedColumn {
305                intersect_outer: Some(range),
306                table,
307            } => {
308                let range = *range;
309                table.get().unwrap().for_each(|k, v| {
310                    if let Some(res) = intersect_with_dense_ref(v, range) {
311                        f(&[*k], PotentiallyStale::maybe_stale(res))
312                    }
313                });
314            }
315            DynamicIndex::CachedColumn {
316                intersect_outer: None,
317                table,
318            } => {
319                table
320                    .get()
321                    .unwrap()
322                    .for_each(|k, v| f(&[*k], PotentiallyStale::maybe_stale(v)));
323            }
324            DynamicIndex::Dynamic(tab) => {
325                tab.for_each(|k, v| f(k, PotentiallyStale::not_stale(v)));
326            }
327            DynamicIndex::DynamicColumn(tab) => tab.for_each(|k, v| {
328                f(&[*k], PotentiallyStale::not_stale(v));
329            }),
330            DynamicIndex::SparseColumn(tab) => {
331                tab.for_each(|k, v| f(k, PotentiallyStale::not_stale(v)));
332            }
333        }
334    }
335
336    fn len(&self) -> usize {
337        match &self.ix {
338            DynamicIndex::Cached { table, .. } => table.get().unwrap().len(),
339            DynamicIndex::CachedColumn { table, .. } => table.get().unwrap().len(),
340            DynamicIndex::Dynamic(tab) => tab.len(),
341            DynamicIndex::DynamicColumn(tab) => tab.len(),
342            DynamicIndex::SparseColumn(tab) => tab.len(),
343        }
344    }
345}
346
impl Database {
    /// Run every plan in `rule_set` against the database, then merge the resulting
    /// updates with `merge_all`.
    ///
    /// Depending on `parallelize_db_level_op(self.total_size_estimate)`, plans are either
    /// each spawned as a rayon task with their own scoped action buffer, or run one
    /// after another on the current thread with a single in-place action buffer that is
    /// flushed once at the end.
    ///
    /// For each plan, every atom starts with its table's full subset, the plan header's
    /// per-atom subsets are intersected in, and then the join stages run. A decomposed
    /// plan first runs each materialization block in order and finally its result block.
    /// Evaluation of a plan stops early as soon as a header subset, an intersected atom
    /// subset, or a materialization turns out to be empty.
    ///
    /// Returns a [`RuleSetReport`] with per-rule reports (including the plan itself
    /// unless `report_level` is `TimeOnly`), the total search-and-apply time, the merge
    /// time, and whether the merge changed the database. An empty rule set yields the
    /// default report without doing any work.
    pub fn run_rule_set(&mut self, rule_set: &RuleSet, report_level: ReportLevel) -> RuleSetReport {
        if rule_set.plans.is_empty() {
            return RuleSetReport::default();
        }
        // Shared match counter, sized by the number of action ids in the rule set.
        let match_counter = Arc::new(MatchCounter::new(rule_set.actions.n_ids()));

        let search_and_apply_timer = Instant::now();
        // Assigned exactly once in each branch below.
        let mut rule_reports: HashMap<Arc<str>, Vec<RuleReport>>;
        let exec_state = ExecutionState::new(self.read_only_view(), Default::default());
        if parallelize_db_level_op(self.total_size_estimate) {
            // Parallel path: reports are collected in a concurrent map keyed by rule
            // description and copied into `rule_reports` once all tasks have finished.
            let dash_rule_reports: Arc<DashMap<Arc<str>, Vec<RuleReport>>> =
                Arc::new(DashMap::default());
            let db: &Database = self;
            rayon::in_place_scope(|scope| {
                for (plan, desc, symbol_map) in rule_set.plans.values() {
                    // TODO: add stats
                    let report_plan = match report_level {
                        ReportLevel::TimeOnly => None,
                        ReportLevel::WithPlan | ReportLevel::StageInfo => {
                            Some(plan.to_report(symbol_map))
                        }
                    };

                    // Per-task handles moved into the spawned closure.
                    let dash_rule_reports = dash_rule_reports.clone();
                    let desc = desc.clone();
                    let exec_state = exec_state.clone();
                    let match_counter = match_counter.clone();
                    scope.spawn(move |rule_scope| {
                        let join_state = JoinState::new(db, exec_state.clone());
                        let mut binding_info = BindingInfo::default();
                        // Every atom starts out bound to its table's full subset.
                        for (id, info) in plan.atoms().iter() {
                            let table = join_state.db.get_table(info.table);
                            binding_info.insert_subset(id, table.all());
                        }
                        let mut action_buf =
                            ScopedActionBuffer::new(rule_scope, rule_set, match_counter.clone());
                        let search_and_apply_timer = Instant::now();

                        'eval: {
                            // Apply the plan header: narrow each listed atom to the
                            // header's subset, bailing out if anything becomes empty.
                            for JoinHeader { atom, subset, .. } in plan.header() {
                                if subset.is_empty() {
                                    break 'eval;
                                }
                                // The node is not shared yet, so `try_unwrap` must succeed.
                                let mut cur =
                                    Arc::try_unwrap(binding_info.unwrap_val(*atom)).unwrap();
                                debug_assert!(cur.cached_subsets.get().is_none());
                                cur.subset.intersect(subset.as_ref(), &join_state.pool);
                                if cur.subset.is_empty() {
                                    break 'eval;
                                }
                                binding_info.move_back_node(*atom, Arc::new(cur));
                            }

                            match plan {
                                Plan::SinglePlan(plan) => {
                                    join_state.run_join_stages(
                                        &plan.stages,
                                        &plan.atoms,
                                        plan.actions,
                                        &mut binding_info,
                                        &mut action_buf,
                                    );
                                }
                                Plan::DecomposedPlan(plan) => {
                                    // One (initially empty) concurrent map per
                                    // materialization block.
                                    let mut materializations: DenseIdMap<
                                        MatId,
                                        Arc<DashMap<Vec<Value>, RowBuffer>>,
                                    > = DenseIdMap::with_capacity(plan.stages.blocks.len());
                                    for i in 0..plan.stages.blocks.len() {
                                        materializations.insert(
                                            MatId::from_usize(i),
                                            Arc::new(Default::default()),
                                        );
                                    }
                                    let specs: Arc<DenseIdMap<MatId, MatSpec>> = Arc::new(
                                        plan.stages
                                            .blocks
                                            .iter()
                                            .enumerate()
                                            .map(|(i, block)| {
                                                (MatId::from_usize(i), block.1.clone())
                                            })
                                            .collect(),
                                    );
                                    let mut materializations = Arc::new(materializations);

                                    for (mat_id, stage_block) in
                                        plan.stages.blocks.iter().enumerate()
                                    {
                                        let mat_id = MatId::from_usize(mat_id);
                                        // Run this block in its own scope so that all of
                                        // its tasks finish before the result is read.
                                        rayon::in_place_scope(|stage_scope| {
                                            let mut materializer = ScopedMaterializer {
                                                scope: stage_scope,
                                                specs: specs.clone(),
                                                materializations: materializations.clone(),
                                                scratch_key: Default::default(),
                                                scratch_val: Default::default(),
                                            };
                                            join_state.run_join_stages(
                                                &stage_block.0,
                                                &plan.atoms,
                                                mat_id,
                                                &mut binding_info,
                                                &mut materializer,
                                            );
                                        });
                                        if materializations[mat_id].is_empty() {
                                            break 'eval;
                                        }
                                        // All clones handed to the stage scope are gone by
                                        // now, so the map can be taken apart: move this
                                        // block's rows into an `IndexMap` and publish it
                                        // through `binding_info` for later stages.
                                        assert_eq!(Arc::strong_count(&materializations), 1);
                                        let mut materializations_dearc =
                                            Arc::unwrap_or_clone(materializations);
                                        let materialization = mem::take(
                                            Arc::get_mut(&mut materializations_dearc[mat_id])
                                                .unwrap(),
                                        )
                                        .into_iter()
                                        .collect::<IndexMap<_, _>>();
                                        binding_info
                                            .materializations
                                            .insert(mat_id, Arc::new(materialization));
                                        materializations = Arc::new(materializations_dearc);
                                    }
                                    join_state.run_join_stages(
                                        &plan.result_block,
                                        &plan.atoms,
                                        plan.actions,
                                        &mut binding_info,
                                        &mut action_buf,
                                    );
                                }
                            }
                        }
                        // The per-rule time is sampled before the final flush below.
                        let search_and_apply_time = search_and_apply_timer.elapsed();
                        if action_buf.needs_flush {
                            action_buf.flush(&mut exec_state.clone());
                        }
                        let mut rule_report: RefMut<'_, Arc<str>, Vec<RuleReport>> =
                            dash_rule_reports.entry(desc).or_default();
                        // `usize::MAX` is a placeholder that is filled in after all
                        // plans have run (see the loop over `rule_set.plans` below).
                        rule_report.value_mut().push(RuleReport {
                            plan: report_plan,
                            search_and_apply_time,
                            num_matches: usize::MAX,
                        });
                    });
                }
            });
            rule_reports = dash_rule_reports
                .iter()
                .map(|entry| (entry.key().clone(), entry.value().clone()))
                .collect();
        } else {
            rule_reports = HashMap::default();
            let join_state = JoinState::new(self, exec_state.clone());
            // Just run all of the plans in order with a single in-place action
            // buffer.
            let mut action_buf = InPlaceActionBuffer {
                rule_set,
                match_counter: match_counter.as_ref(),
                batches: Default::default(),
            };
            for (plan, desc, symbol_map) in rule_set.plans.values() {
                let report_plan = match report_level {
                    ReportLevel::TimeOnly => None,
                    ReportLevel::WithPlan | ReportLevel::StageInfo => {
                        Some(plan.to_report(symbol_map))
                    }
                };
                let mut binding_info = BindingInfo::default();

                // Every atom starts out bound to its table's full subset.
                for (id, info) in plan.atoms().iter() {
                    let table = join_state.db.get_table(info.table);
                    binding_info.insert_subset(id, table.all());
                }

                let search_and_apply_timer = Instant::now();
                'eval: {
                    for JoinHeader { atom, subset, .. } in plan.header() {
                        if subset.is_empty() {
                            break 'eval;
                        }
                        // Before query execution, Arc<TrieNode> is owned exclusively by the binding info.
                        // Trie nodes are only shared once we start running the query.
                        let mut cur = Arc::try_unwrap(binding_info.unwrap_val(*atom)).unwrap();
                        debug_assert!(cur.cached_subsets.get().is_none());
                        cur.subset.intersect(subset.as_ref(), &join_state.pool);
                        if cur.subset.is_empty() {
                            break 'eval;
                        }
                        binding_info.move_back_node(*atom, Arc::new(cur));
                    }
                    match plan {
                        Plan::SinglePlan(plan) => {
                            join_state.run_join_stages(
                                &plan.stages,
                                &plan.atoms,
                                plan.actions,
                                &mut binding_info,
                                &mut action_buf,
                            );
                        }
                        Plan::DecomposedPlan(plan) => {
                            // One (initially empty) map per materialization block.
                            let mut materializations =
                                DenseIdMap::with_capacity(plan.stages.blocks.len());
                            for i in 0..plan.stages.blocks.len() {
                                materializations.insert(MatId::from_usize(i), Default::default());
                            }
                            let mut materializer = InPlaceMaterializer {
                                specs: &plan
                                    .stages
                                    .blocks
                                    .iter()
                                    .enumerate()
                                    .map(|(i, block)| (MatId::from_usize(i), block.1.clone()))
                                    .collect(),
                                materializations,
                                scratch_key: Default::default(),
                                scratch_val: Default::default(),
                            };

                            for (mat_id, stage_block) in plan.stages.blocks.iter().enumerate() {
                                let mat_id = MatId::from_usize(mat_id);
                                join_state.run_join_stages(
                                    &stage_block.0,
                                    &plan.atoms,
                                    mat_id,
                                    &mut binding_info,
                                    &mut materializer,
                                );
                                if materializer.materializations[mat_id].is_empty() {
                                    break 'eval;
                                }
                                // Publish the finished materialization for later stages.
                                binding_info.materializations.insert(
                                    mat_id,
                                    Arc::new(materializer.materializations.take(mat_id).unwrap()),
                                );
                            }
                            join_state.run_join_stages(
                                &plan.result_block,
                                &plan.atoms,
                                plan.actions,
                                &mut binding_info,
                                &mut action_buf,
                            );
                        }
                    }
                }
                let search_and_apply_time = search_and_apply_timer.elapsed();

                // TODO: unnecessary cloning in many cases
                let rule_report = rule_reports.entry(desc.clone()).or_default();
                // `usize::MAX` is a placeholder that is filled in after all plans have
                // run (see the loop over `rule_set.plans` below).
                rule_report.push(RuleReport {
                    plan: report_plan,
                    search_and_apply_time,
                    num_matches: usize::MAX,
                });
            }
            action_buf.flush(&mut exec_state.clone());
        }

        // Fill in the match counts that were left as placeholders above.
        // NOTE(review): in the parallel branch, reports for the same `desc` are pushed in
        // task-completion order, so if several plans share a description the placeholder
        // found here may belong to a different plan than `plan` — confirm whether
        // descriptions are unique per plan.
        for (plan, desc, _symbol_map) in rule_set.plans.values() {
            let reports = rule_reports.get_mut(desc).unwrap();
            let i = reports
                .iter()
                // HACK: Since the order of visiting queries is fixed and # matches need to be obtained
                // separately from rule execution, we first set all # matches to be usize::MAX and then fill
                // them in one by one.
                .position(|r| r.num_matches == usize::MAX)
                .unwrap();
            // NB: This requires each action ID correspond to only one query.
            // If an action is used by multiple queries, then we can't tell how many matches are
            // caused by individual queries.
            reports[i].num_matches = match_counter.read_matches(plan.actions());
        }
        let search_and_apply_time = search_and_apply_timer.elapsed();

        let merge_timer = Instant::now();
        let changed = self.merge_all();
        let merge_time = merge_timer.elapsed();

        RuleSetReport {
            changed,
            rule_reports,
            search_and_apply_time,
            merge_time,
        }
    }
}
637
/// Bookkeeping for one action's pending batch of variable bindings.
// NOTE(review): the field meanings below are inferred from their names and from the
// `Default` impl only; confirm against the code that fills and flushes these batches.
struct ActionState {
    /// Presumably the number of times this action's batch has been run; starts at 0.
    n_runs: usize,
    /// Presumably the number of binding rows currently buffered; starts at 0.
    len: usize,
    /// The buffered bindings; created with `Bindings::new(VAR_BATCH_SIZE)` by default.
    bindings: Bindings,
}
643
644impl Default for ActionState {
645    fn default() -> Self {
646        Self {
647            n_runs: 0,
648            len: 0,
649            bindings: Bindings::new(VAR_BATCH_SIZE),
650        }
651    }
652}
653
/// Context shared by the join stages of one plan evaluation: the database being
/// queried, the execution state handed to it at construction, and a cached
/// allocation pool.
struct JoinState<'a> {
    /// The database whose tables are being joined.
    db: &'a Database,
    /// Execution state supplied via `JoinState::new`.
    exec_state: ExecutionState<'a>,
    /// Cached thread-local pool for SortedOffsetVector allocations.
    /// Stored here to avoid a per-call `with_pool_set` TLS access in `get_index`.
    pool: Pool<SortedOffsetVector>,
}
661
/// Per-column indexes on a trie node's subset, lazily initialized on first access per column.
type ColumnIndexes = IdVec<ColumnId, OnceLock<Arc<ColumnIndex>>>;
// Per-column caches of child trie nodes: for each column of the table, one
// `RwLock`-guarded map from a column value to the child node for that value. The
// vector is sized to the table's arity when a `TrieNode` first caches a child.
// NOTE(review): this comment previously described a single boxed `(ColumnId, map)`
// pair; the alias below is a per-column `IdVec` of `RwLock`s, so it has been reworded
// to match the type.
type ChildrenMaps = IdVec<ColumnId, RwLock<HashMap<Value, Arc<TrieNode>>>>;
668
/// Information about the current subset of an atom's relation that is being considered, along with
/// lazily-initialized, cached indexes on that subset.
///
/// This is the standard trie-node used in lazy implementations of GJ as in the original egglog
/// implementation and the FJ paper. It currently does not handle non-column indexes, but that
/// should be a fairly straightforward extension if we start generating plans that need those.
/// (Right now, most plans iterating over more than one column just do a scan anyway).
pub(crate) struct TrieNode {
    /// The actual subset of the corresponding atom.
    subset: Subset,
    /// Any cached indexes on this subset. The per-column vector is created on first
    /// use, and each column's index is then built on first access to that column.
    cached_subsets: OnceLock<Pooled<ColumnIndexes>>,
    /// Cached child trie nodes, keyed by column and then by value. The per-column
    /// vector of maps is created the first time a child is cached.
    // NOTE(review): this doc previously said a single `(col, map)` pair is stored;
    // `ChildrenMaps` is an `IdVec` with one map per column.
    cached_children: OnceLock<Pooled<ChildrenMaps>>,
}
686
687impl std::fmt::Debug for TrieNode {
688    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
689        f.debug_struct("TrieNode")
690            .field("subset", &self.subset)
691            .finish()
692    }
693}
694
695impl TrieNode {
696    fn new(subset: Subset) -> Self {
697        Self {
698            subset,
699            cached_subsets: Default::default(),
700            cached_children: Default::default(),
701        }
702    }
703
704    fn size(&self) -> usize {
705        self.subset.size()
706    }
707    fn get_cached_index(&self, col: ColumnId, info: &TableInfo) -> Arc<ColumnIndex> {
708        self.cached_subsets.get_or_init(|| {
709            // Pre-size the vector so we do not need to borrow it mutably to initialize the index.
710            let mut vec: Pooled<ColumnIndexes> = with_pool_set(|ps| ps.get());
711            vec.resize_with(info.spec.arity(), OnceLock::new);
712            vec
713        })[col]
714            .get_or_init(|| {
715                let col_index = info.table.group_by_col(self.subset.as_ref(), col);
716                Arc::new(col_index)
717            })
718            .clone()
719    }
720
721    fn get_cached_trie_node(
722        &self,
723        col: ColumnId,
724        value: Value,
725        info: &TableInfo,
726        sub: impl FnOnce() -> Subset,
727    ) -> Arc<TrieNode> {
728        let map = &self.cached_children.get_or_init(|| {
729            let mut vec: Pooled<ChildrenMaps> = with_pool_set(|ps| ps.get());
730            vec.resize_with(info.spec.arity(), || RwLock::new(HashMap::default()));
731            vec
732        })[col];
733        // Optimistic read path: most calls are cache hits, so try shared lock first.
734        {
735            let guard = map.read().unwrap();
736            if let Some(node) = guard.get(&value) {
737                return node.clone();
738            }
739        }
740        // Cache miss: acquire write lock and insert.
741        let mut guard = map.write().unwrap();
742        // Double-check in case another thread inserted while we were waiting.
743        if let Some(node) = guard.get(&value) {
744            return node.clone();
745        }
746        let new_node = Arc::new(TrieNode::new(sub()));
747        guard.insert(value, new_node.clone());
748        new_node
749    }
750}
751
752impl FrameUpdates {
753    /// Refine `atom` to `subset`, using the dense fast path to avoid an
754    /// `Arc<TrieNode>` allocation when the subset is already a contiguous range.
755    fn refine_atom_subset(&mut self, atom: AtomId, subset: Subset) {
756        match subset {
757            Subset::Dense(range) => self.refine_atom_dense(atom, range),
758            sub => self.refine_atom(atom, Arc::new(TrieNode::new(sub))),
759        }
760    }
761}
762
/// A list of row batches, each pairing a list of variables with a shared
/// buffer of rows.
///
/// The leaf-scan path of `run_plan` pushes an entry holding the variables bound
/// by the scan and the rows projected for them, and pops it again once the
/// resulting updates have been drained. The whole list is handed to
/// `push_bindings_factorized` alongside the scalar bindings.
type BindingSet = Vec<(SmallVec<[Variable; 4]>, Arc<TaggedRowBuffer<SmallValueVec>>)>;
764
/// The mutable state threaded through plan execution: variable bindings,
/// batched binding sets, and the current subset of each atom.
#[derive(Default, Clone)]
struct BindingInfo {
    /// The value currently bound to each variable; written when an
    /// `UpdateInstr::PushBinding` is applied.
    bindings: DenseIdMap<Variable, Value>,
    /// Batches of (variables, rows) pushed by leaf scans and passed to
    /// `push_bindings_factorized` together with `bindings`.
    binding_sets: BindingSet,
    /// The trie node, and hence the row subset, currently associated with each
    /// atom. Entries are temporarily moved out while a prober for the atom is
    /// in use and are moved back afterwards.
    subsets: DenseIdMap<AtomId, Arc<TrieNode>>,
    /// Shared maps from key tuples to row buffers, indexed by `MatId`.
    /// NOTE(review): not read or written in the code near this definition;
    /// confirm its exact role against the materialization stages.
    materializations: DenseIdMap<MatId, Arc<IndexMap<Vec<Value>, RowBuffer>>>,
}
772
773impl BindingInfo {
774    /// Initializes the atom-related metadata in the [`BindingInfo`].    
775    fn insert_subset(&mut self, atom: AtomId, subset: Subset) {
776        if let Some(slot) = self.subsets.get_mut(atom)
777            && let Some(node) = Arc::get_mut(slot)
778        {
779            node.cached_subsets.take();
780            node.cached_children.take();
781            node.subset = subset;
782            return;
783        }
784        self.subsets.insert(atom, Arc::new(TrieNode::new(subset)));
785    }
786
787    fn insert_node(&mut self, atom: AtomId, node: Arc<TrieNode>) {
788        self.subsets.insert(atom, node);
789    }
790
791    /// Probers returned from [`JoinState::get_index`] will move atom-related state out of the
792    /// [`BindingInfo`]. Once the caller is done using a prober, this method moves it back.
793    fn move_back(&mut self, atom: AtomId, prober: Prober) {
794        self.subsets.insert(atom, prober.node);
795    }
796
797    fn move_back_node(&mut self, atom: AtomId, node: Arc<TrieNode>) {
798        self.subsets.insert(atom, node);
799    }
800
801    fn has_empty_subset(&self, atom: AtomId) -> bool {
802        self.subsets[atom].subset.is_empty()
803    }
804
805    fn unwrap_val(&mut self, atom: AtomId) -> Arc<TrieNode> {
806        self.subsets.unwrap_val(atom)
807    }
808}
809
810impl<'a> JoinState<'a> {
811    fn new(db: &'a Database, exec_state: ExecutionState<'a>) -> Self {
812        Self {
813            db,
814            exec_state,
815            pool: with_pool_set(|ps| ps.get_pool()),
816        }
817    }
818
819    fn get_index(
820        &self,
821        atoms: &Arc<DenseIdMap<AtomId, Atom>>,
822        atom: AtomId,
823        binding_info: &mut BindingInfo,
824        cols: impl Iterator<Item = ColumnId>,
825    ) -> Prober {
826        let cols = SmallVec::<[ColumnId; 4]>::from_iter(cols);
827        let trie_node = binding_info.subsets.unwrap_val(atom);
828        let subset = &trie_node.subset;
829
830        let table_id = atoms[atom].table;
831        let info = &self.db.tables[table_id];
832        let dyn_index = if subset.size() <= SMALL_RESIDUAL && cols.len() == 1 {
833            DynamicIndex::SparseColumn(SparseColumnIndex::new(
834                info.table.as_ref(),
835                subset.as_ref(),
836                cols[0],
837            ))
838        } else {
839            let all_cacheable = cols.iter().all(|col| {
840                !info
841                    .spec
842                    .uncacheable_columns
843                    .get(*col)
844                    .copied()
845                    .unwrap_or(false)
846            });
847            let whole_table = info.table.all();
848            if let Subset::Dense(range) = subset
849                && all_cacheable
850                && whole_table.size() / 2 < subset.size()
851            {
852                // Skip intersecting with the subset if we are just looking at the
853                // whole table.
854                let needs_intersect =
855                    !(whole_table.is_dense() && subset.bounds() == whole_table.bounds());
856                // When intersecting, store the Dense range directly so we can do a
857                // combined copy+filter without a runtime match on subset type later.
858                let intersect_outer = if needs_intersect { Some(*range) } else { None };
859                // heuristic: if the subset we are scanning is somewhat
860                // large _or_ it is most of the table, or we already have a cached
861                // index for it, then return it.
862                if cols.len() != 1 {
863                    DynamicIndex::Cached {
864                        intersect_outer,
865                        table: get_index_from_tableinfo(info, &cols),
866                    }
867                } else {
868                    DynamicIndex::CachedColumn {
869                        intersect_outer,
870                        table: get_column_index_from_tableinfo(info, cols[0]).clone(),
871                    }
872                }
873            } else if cols.len() != 1 {
874                // NB: we should have a caching strategy for non-column indexes.
875                DynamicIndex::Dynamic(info.table.group_by_key(subset.as_ref(), &cols))
876            } else {
877                DynamicIndex::DynamicColumn(trie_node.get_cached_index(cols[0], info))
878            }
879        };
880        Prober {
881            node: trie_node,
882            ix: dyn_index,
883        }
884    }
885    fn get_column_index(
886        &self,
887        atoms: &Arc<DenseIdMap<AtomId, Atom>>,
888        binding_info: &mut BindingInfo,
889        atom: AtomId,
890        col: ColumnId,
891    ) -> Prober {
892        self.get_index(atoms, atom, binding_info, iter::once(col))
893    }
894
895    /// Runs the free join plan, starting with the header.
896    ///
897    /// A bit about the `instr_order` parameter: This defines the order in which the [`JoinStage`]
898    /// instructions will run. We want to support cached [`SinglePlan`]s that may be based on stale
899    /// ordering information. `instr_order` allows us to specify a new ordering of the instructions
900    /// without mutating the plan itself: `run_plan` simply executes
901    /// `plan.stages.instrs[instr_order[i]]` at stage `i`.
902    ///
903    /// This is also a stepping stone towards supporting fully dynamic variable ordering.
904    fn run_join_stages<'buf, A: NumericId + 'buf, BUF: ActionBuffer<'buf, A>>(
905        &self,
906        stages: &'buf JoinStages,
907        atoms: &'buf Arc<DenseIdMap<AtomId, Atom>>,
908        action: A,
909        binding_info: &mut BindingInfo,
910        action_buf: &mut BUF,
911    ) where
912        'a: 'buf,
913    {
914        if log::log_enabled!(log::Level::Trace) {
915            log::trace!("Starting running query stages:\n{stages:#?}");
916        }
917        for (_, node) in binding_info.subsets.iter() {
918            if node.subset.is_empty() {
919                return;
920            }
921        }
922        let mut order = InstrOrder::from_iter(0..stages.instrs.len());
923        let mut leaf_scans: LeafScans = smallvec::smallvec![false; stages.instrs.len()];
924        sort_plan_by_size(&mut order, &mut leaf_scans, 0, &stages.instrs, binding_info);
925        self.run_plan(
926            stages,
927            atoms,
928            action,
929            &mut order,
930            &mut leaf_scans,
931            0,
932            binding_info,
933            action_buf,
934        );
935    }
936
    /// The core method for executing a free join plan.
    ///
    /// This method takes the plan, mutable data-structures for variable binding and staging
    /// actions, the current instruction order (`instr_order` and `leaf_scans`), and `cur`, the
    /// index of the current stage of the plan to run. `cur` also serves as a heuristic for
    /// detecting whether we are at the "top" of a plan rather than the "bottom": updates produced
    /// by the first two stages (`cur == 0` or `cur == 1`) are drained in parallel when the action
    /// buffer supports it, increasing parallelism beyond the default.
945    #[allow(clippy::too_many_arguments)]
946    fn run_plan<'buf, A: NumericId + 'buf, BUF: ActionBuffer<'buf, A>>(
947        &self,
948        stages: &'buf JoinStages,
949        atoms: &'buf Arc<DenseIdMap<AtomId, Atom>>,
950        action: A,
951        instr_order: &mut InstrOrder,
952        leaf_scans: &mut LeafScans,
953        cur: usize,
954        binding_info: &mut BindingInfo,
955        action_buf: &mut BUF,
956    ) where
957        'a: 'buf,
958    {
959        if self.exec_state.should_stop() {
960            return;
961        }
962
963        if cur >= instr_order.len() {
964            action_buf.push_bindings_factorized(
965                action,
966                &mut binding_info.bindings,
967                &binding_info.binding_sets,
968                &self.exec_state,
969            );
970            return;
971        }
972        let chunk_size = action_buf.morsel_size(cur, instr_order.len());
973        let mut cur_size = estimate_size(&stages.instrs[instr_order.get(cur)], binding_info);
974        if cur_size > 32 && cur % 3 == 1 && cur < instr_order.len() - 1 {
975            // If we have a reasonable number of tuples to process, adjust the variable order every
            // 3 rounds, but always make sure to readjust on the second round.
977            sort_plan_by_size(instr_order, leaf_scans, cur, &stages.instrs, binding_info);
978            cur_size = estimate_size(&stages.instrs[instr_order.get(cur)], binding_info);
979        }
980
981        // Helper macro (not its own method to appease the borrow checker).
982        macro_rules! drain_updates {
983            ($updates:expr) => {
984                if self.exec_state.should_stop() {
985                    return;
986                }
                // TODO: `supports_parallel_drain` is a hack because currently
                // `drain_updates_parallel!` is a bit slower because of the additional ExecutionState clone.
989                if (cur == 0 || cur == 1) && action_buf.supports_parallel_drain() {
990                    drain_updates_parallel!($updates)
991                } else {
992                    $updates.drain(|update| match update {
993                        UpdateInstr::PushBinding(var, val) => {
994                            binding_info.bindings.insert(var, val);
995                        }
996                        UpdateInstr::RefineAtom(atom, subset) => {
997                            binding_info.insert_node(atom, subset);
998                        }
999                        UpdateInstr::RefineAtomDense(atom, range) => {
1000                            binding_info.insert_subset(atom, Subset::Dense(range));
1001                        }
1002                        UpdateInstr::EndFrame => {
1003                            // Inline leaf-level: if cur+1 is the leaf (no more
1004                            // join stages), call push_bindings directly without
1005                            // a recursive run_plan call, avoiding function call
1006                            // overhead + an extra should_stop() check.
1007                            if cur + 1 >= instr_order.len() {
1008                                action_buf.push_bindings_factorized(
1009                                    action,
1010                                    &mut binding_info.bindings,
1011                                    &binding_info.binding_sets,
1012                                    &self.exec_state,
1013                                );
1014                            } else {
1015                                self.run_plan(
1016                                    stages,
1017                                    atoms,
1018                                    action,
1019                                    instr_order,
1020                                    leaf_scans,
1021                                    cur + 1,
1022                                    binding_info,
1023                                    action_buf,
1024                                );
1025                            }
1026                        }
1027                    })
1028                }
1029            };
1030        }
1031        macro_rules! drain_updates_parallel {
1032            ($updates:expr) => {{
1033                if self.exec_state.should_stop() {
1034                    return;
1035                }
1036                let db = self.db;
1037                let exec_state_for_factory = self.exec_state.clone();
1038                let exec_state_for_work = self.exec_state.clone();
1039                action_buf.recur(
1040                    BorrowedLocalState {
1041                        binding_info,
1042                        instr_order,
1043                        leaf_scans,
1044                        updates: &mut $updates,
1045                    },
1046                    move || exec_state_for_factory.clone(),
1047                    move |BorrowedLocalState {
1048                              binding_info,
1049                              instr_order,
1050                              leaf_scans,
1051                              updates,
1052                          },
1053                          buf| {
1054                        updates.drain(|update| match update {
1055                            UpdateInstr::PushBinding(var, val) => {
1056                                binding_info.bindings.insert(var, val);
1057                            }
1058                            UpdateInstr::RefineAtom(atom, subset) => {
1059                                binding_info.insert_node(atom, subset);
1060                            }
1061                            UpdateInstr::RefineAtomDense(atom, range) => {
1062                                binding_info.insert_subset(atom, Subset::Dense(range));
1063                            }
1064                            UpdateInstr::EndFrame => {
1065                                JoinState {
1066                                    db,
1067                                    exec_state: exec_state_for_work.clone(),
1068                                    // Each rayon task uses its own thread-local pool.
                                    // This makes drain_updates_parallel slightly more expensive
                                    // than drain_updates even when both are run in a single thread.
1071                                    pool: with_pool_set(|ps| ps.get_pool()),
1072                                }
1073                                .run_plan(
1074                                    stages,
1075                                    atoms,
1076                                    action,
1077                                    instr_order,
1078                                    leaf_scans,
1079                                    cur + 1,
1080                                    binding_info,
1081                                    buf,
1082                                );
1083                            }
1084                        })
1085                    },
1086                );
1087                $updates.clear();
1088            }};
1089        }
1090
1091        fn refine_subset(
1092            sub: PotentiallyStale<Subset>,
1093            constraints: &[Constraint],
1094            table: &WrappedTableRef,
1095            has_stale: bool,
1096        ) -> Subset {
1097            let sub = if sub.can_be_stale && has_stale {
1098                table.refine_live(sub.inner)
1099            } else {
1100                sub.inner
1101            };
1102            if constraints.is_empty() {
1103                sub
1104            } else {
1105                table.refine(sub, constraints)
1106            }
1107        }
1108
1109        let pool = &self.pool;
1110        match &stages.instrs[instr_order.get(cur)] {
1111            JoinStage::Intersect { var, scans } => match scans.as_slice() {
1112                [] => {}
1113                [a] => {
1114                    if binding_info.has_empty_subset(a.atom) {
1115                        return;
1116                    }
1117                    let prober = self.get_column_index(atoms, binding_info, a.atom, a.column);
1118                    let info = &self.db.tables[atoms[a.atom].table];
1119                    let table = info.table.as_ref();
1120                    let has_stale = table.has_stale_rows();
1121                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1122                    prober.for_each(|val, x| {
1123                        updates.push_binding(*var, val[0]);
1124                        if x.size() <= 16 {
1125                            let sub = refine_subset(x.to_owned(pool), &a.cs, &table, has_stale);
1126                            if sub.is_empty() {
1127                                updates.rollback();
1128                                return;
1129                            }
1130                            updates.refine_atom_subset(a.atom, sub);
1131                        } else {
1132                            let node =
1133                                prober
1134                                    .node
1135                                    .get_cached_trie_node(a.column, val[0], info, || {
1136                                        refine_subset(x.to_owned(pool), &a.cs, &table, has_stale)
1137                                    });
1138                            if node.subset.is_empty() {
1139                                updates.rollback();
1140                                return;
1141                            }
1142                            updates.refine_atom(a.atom, node);
1143                        }
1144                        updates.finish_frame();
1145                        if updates.frames() >= chunk_size {
1146                            drain_updates!(updates);
1147                        }
1148                    });
1149                    drain_updates!(updates);
1150                    binding_info.move_back(a.atom, prober);
1151                }
1152                [a, b] => {
1153                    let a_prober = self.get_column_index(atoms, binding_info, a.atom, a.column);
1154                    let b_prober = self.get_column_index(atoms, binding_info, b.atom, b.column);
1155
1156                    let ((smaller, smaller_scan), (larger, larger_scan)) =
1157                        if a_prober.len() < b_prober.len() {
1158                            ((&a_prober, a), (&b_prober, b))
1159                        } else {
1160                            ((&b_prober, b), (&a_prober, a))
1161                        };
1162
1163                    let smaller_atom = smaller_scan.atom;
1164                    let larger_atom = larger_scan.atom;
1165                    let large_info = &self.db.tables[atoms[larger_atom].table];
1166                    let large_table = large_info.table.as_ref();
1167                    let large_has_stale = large_table.has_stale_rows();
1168                    let small_info = &self.db.tables[atoms[smaller_atom].table];
1169                    let small_table = small_info.table.as_ref();
1170                    let small_has_stale = small_table.has_stale_rows();
1171                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1172                    smaller.for_each(|val, small_sub| {
1173                        if let Some(large_sub) = larger.get_subset(val) {
1174                            updates.push_binding(*var, val[0]);
1175                            if small_sub.size() <= 16 {
1176                                let small_sub = refine_subset(
1177                                    small_sub.to_owned(pool),
1178                                    &smaller_scan.cs,
1179                                    &small_table,
1180                                    small_has_stale,
1181                                );
1182                                if small_sub.is_empty() {
1183                                    updates.rollback();
1184                                    return;
1185                                }
1186                                updates.refine_atom_subset(smaller_atom, small_sub);
1187                            } else {
1188                                let smaller_node = smaller.node.get_cached_trie_node(
1189                                    smaller_scan.column,
1190                                    val[0],
1191                                    small_info,
1192                                    || {
1193                                        refine_subset(
1194                                            small_sub.to_owned(pool),
1195                                            &smaller_scan.cs,
1196                                            &small_table,
1197                                            small_has_stale,
1198                                        )
1199                                    },
1200                                );
1201                                if smaller_node.subset.is_empty() {
1202                                    updates.rollback();
1203                                    return;
1204                                }
1205                                updates.refine_atom(smaller_atom, smaller_node);
1206                            }
1207                            if large_sub.size() <= 16 {
1208                                let large_sub = refine_subset(
1209                                    large_sub.to_owned(pool),
1210                                    &larger_scan.cs,
1211                                    &large_table,
1212                                    large_has_stale,
1213                                );
1214                                if large_sub.is_empty() {
1215                                    updates.rollback();
1216                                    return;
1217                                }
1218                                updates.refine_atom_subset(larger_atom, large_sub);
1219                            } else {
1220                                let larger_node = larger.node.get_cached_trie_node(
1221                                    larger_scan.column,
1222                                    val[0],
1223                                    large_info,
1224                                    || {
1225                                        refine_subset(
1226                                            large_sub.to_owned(pool),
1227                                            &larger_scan.cs,
1228                                            &large_table,
1229                                            large_has_stale,
1230                                        )
1231                                    },
1232                                );
1233                                if larger_node.subset.is_empty() {
1234                                    updates.rollback();
1235                                    return;
1236                                }
1237                                updates.refine_atom(larger_atom, larger_node);
1238                            }
1239                            updates.finish_frame();
1240                            if updates.frames() >= chunk_size {
1241                                drain_updates!(updates);
1242                            }
1243                        }
1244                    });
1245                    drain_updates!(updates);
1246
1247                    binding_info.move_back(a.atom, a_prober);
1248                    binding_info.move_back(b.atom, b_prober);
1249                }
1250                rest => {
1251                    let mut smallest = 0;
1252                    let mut smallest_size = usize::MAX;
1253                    let mut probers = Vec::with_capacity(rest.len());
1254                    for (i, scan) in rest.iter().enumerate() {
1255                        let prober =
1256                            self.get_column_index(atoms, binding_info, scan.atom, scan.column);
1257                        let size = prober.len();
1258                        if size < smallest_size {
1259                            smallest = i;
1260                            smallest_size = size;
1261                        }
1262                        probers.push(prober);
1263                    }
1264
1265                    let main_spec = &rest[smallest];
1266                    let main_spec_info = &self.db.tables[atoms[main_spec.atom].table];
1267                    let main_spec_table = main_spec_info.table.as_ref();
1268                    let main_spec_has_stale = main_spec_table.has_stale_rows();
1269                    // Pre-compute has_stale for each scan to avoid vtable calls in the hot loop.
1270                    let rest_has_stale: SmallVec<[bool; 3]> = rest
1271                        .iter()
1272                        .map(|scan| {
1273                            self.db.tables[atoms[scan.atom].table]
1274                                .table
1275                                .as_ref()
1276                                .has_stale_rows()
1277                        })
1278                        .collect();
1279
1280                    if smallest_size != 0 {
1281                        // Smallest leads the scan
1282                        let mut updates =
1283                            FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1284                        probers[smallest].for_each(|key, sub| {
1285                            updates.push_binding(*var, key[0]);
1286                            for (i, scan) in rest.iter().enumerate() {
1287                                if i == smallest {
1288                                    continue;
1289                                }
1290                                if let Some(sub) = probers[i].get_subset(key) {
1291                                    let table =
1292                                        self.db.tables[atoms[rest[i].atom].table].table.as_ref();
1293                                    if sub.size() <= 16 {
1294                                        let sub = refine_subset(
1295                                            sub.to_owned(pool),
1296                                            &rest[i].cs,
1297                                            &table,
1298                                            rest_has_stale[i],
1299                                        );
1300                                        if sub.is_empty() {
1301                                            updates.rollback();
1302                                            return;
1303                                        }
1304                                        updates.refine_atom_subset(scan.atom, sub);
1305                                    } else {
1306                                        let node = probers[i].node.get_cached_trie_node(
1307                                            scan.column,
1308                                            key[0],
1309                                            &self.db.tables[atoms[scan.atom].table],
1310                                            || {
1311                                                refine_subset(
1312                                                    sub.to_owned(pool),
1313                                                    &rest[i].cs,
1314                                                    &table,
1315                                                    rest_has_stale[i],
1316                                                )
1317                                            },
1318                                        );
1319                                        if node.subset.is_empty() {
1320                                            updates.rollback();
1321                                            return;
1322                                        }
1323                                        updates.refine_atom(scan.atom, node);
1324                                    }
1325                                } else {
1326                                    updates.rollback();
1327                                    // Empty intersection.
1328                                    return;
1329                                }
1330                            }
1331                            if sub.size() <= 16 {
1332                                let main_sub = refine_subset(
1333                                    sub.to_owned(pool),
1334                                    &main_spec.cs,
1335                                    &main_spec_table,
1336                                    main_spec_has_stale,
1337                                );
1338                                if main_sub.is_empty() {
1339                                    updates.rollback();
1340                                    return;
1341                                }
1342                                updates.refine_atom_subset(main_spec.atom, main_sub);
1343                            } else {
1344                                let main_node = probers[smallest].node.get_cached_trie_node(
1345                                    main_spec.column,
1346                                    key[0],
1347                                    main_spec_info,
1348                                    || {
1349                                        let sub = sub.to_owned(pool);
1350                                        refine_subset(
1351                                            sub,
1352                                            &main_spec.cs,
1353                                            &main_spec_table,
1354                                            main_spec_has_stale,
1355                                        )
1356                                    },
1357                                );
1358                                if main_node.subset.is_empty() {
1359                                    updates.rollback();
1360                                    return;
1361                                }
1362                                updates.refine_atom(main_spec.atom, main_node);
1363                            }
1364                            updates.finish_frame();
1365                            if updates.frames() >= chunk_size {
1366                                drain_updates!(updates);
1367                            }
1368                        });
1369                        drain_updates!(updates);
1370                    }
1371                    for (spec, prober) in rest.iter().zip(probers.into_iter()) {
1372                        binding_info.move_back(spec.atom, prober);
1373                    }
1374                }
1375            },
1376            JoinStage::FusedIntersect {
1377                cover,
1378                bind,
1379                to_intersect,
1380            } if to_intersect.is_empty() => {
1381                let is_leaf_scan = leaf_scans[cur];
1382                let cover_atom = cover.to_index.atom;
1383                if binding_info.has_empty_subset(cover_atom) {
1384                    return;
1385                }
1386                if is_leaf_scan {
1387                    let table = self.db.tables[atoms[cover_atom].table].table.as_ref();
1388                    let cover_node = binding_info.unwrap_val(cover_atom);
1389                    let cover_subset = cover_node.subset.as_ref();
1390
1391                    let proj =
1392                        SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1393                    let vars = bind.iter().map(|(_, var)| *var).collect();
1394                    let mut buf = TaggedRowBuffer::new_inline(bind.len());
1395                    table.scan_project(
1396                        cover_subset,
1397                        &proj,
1398                        Offset::new(0),
1399                        usize::MAX,
1400                        &cover.constraints,
1401                        &mut buf,
1402                    );
1403
1404                    if buf.is_empty() {
1405                        binding_info.move_back_node(cover_atom, cover_node);
1406                        return;
1407                    }
1408
1409                    binding_info.binding_sets.push((vars, Arc::new(buf)));
1410                    let mut updates = FrameUpdates::with_capacity(1);
1411                    updates.finish_frame();
1412                    drain_updates!(updates);
1413                    binding_info.binding_sets.pop();
1414                    binding_info.move_back_node(cover_atom, cover_node);
1415                } else {
1416                    let proj =
1417                        SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1418                    let cover_node = binding_info.unwrap_val(cover_atom);
1419                    let cover_subset = cover_node.subset.as_ref();
1420                    let mut offset = Offset::new(0);
1421                    let mut buffer = TaggedRowBuffer::new(bind.len());
1422                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1423                    loop {
1424                        buffer.clear();
1425                        let table = &self.db.tables[atoms[cover_atom].table].table;
1426                        let next = table.scan_project(
1427                            cover_subset,
1428                            &proj,
1429                            offset,
1430                            chunk_size,
1431                            &cover.constraints,
1432                            &mut buffer,
1433                        );
1434                        for (row, key) in buffer.iter() {
1435                            updates.refine_atom_dense(cover_atom, OffsetRange::new(row, row.inc()));
1436                            // bind the values
1437                            for (i, (_, var)) in bind.iter().enumerate() {
1438                                updates.push_binding(*var, key[i]);
1439                            }
1440                            updates.finish_frame();
1441                            if updates.frames() >= chunk_size {
1442                                drain_updates!(updates);
1443                            }
1444                        }
1445                        if let Some(next) = next {
1446                            offset = next;
1447                            continue;
1448                        }
1449                        break;
1450                    }
1451                    drain_updates!(updates);
1452                    // Restore the subsets we swapped out.
1453                    binding_info.move_back_node(cover_atom, cover_node);
1454                }
1455            }
1456            JoinStage::FusedIntersect {
1457                cover,
1458                bind,
1459                to_intersect,
1460            } => {
1461                let cover_atom = cover.to_index.atom;
1462                if binding_info.has_empty_subset(cover_atom) {
1463                    return;
1464                }
1465                let index_probers = to_intersect
1466                    .iter()
1467                    .enumerate()
1468                    .map(|(i, (spec, _))| {
1469                        (
1470                            i,
1471                            spec.to_index.atom,
1472                            self.get_index(
1473                                atoms,
1474                                spec.to_index.atom,
1475                                binding_info,
1476                                spec.to_index.vars.iter().copied(),
1477                            ),
1478                        )
1479                    })
1480                    .collect::<SmallVec<[(usize, AtomId, Prober); 4]>>();
1481                // Pre-compute has_stale per prober to avoid vtable calls in the hot loop.
1482                let index_has_stale: SmallVec<[bool; 4]> = index_probers
1483                    .iter()
1484                    .map(|(_, atom, _)| {
1485                        self.db.tables[atoms[*atom].table]
1486                            .table
1487                            .as_ref()
1488                            .has_stale_rows()
1489                    })
1490                    .collect();
1491                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
1492                let cover_node = binding_info.unwrap_val(cover_atom);
1493                let cover_subset = cover_node.subset.as_ref();
1494                let mut cur = Offset::new(0);
1495                let mut buffer = TaggedRowBuffer::new(bind.len());
1496                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1497                loop {
1498                    buffer.clear();
1499                    let table = &self.db.tables[atoms[cover_atom].table].table;
1500                    let next = table.scan_project(
1501                        cover_subset,
1502                        &proj,
1503                        cur,
1504                        chunk_size,
1505                        &cover.constraints,
1506                        &mut buffer,
1507                    );
1508                    'mid: for (row, key) in buffer.iter() {
1509                        updates.refine_atom_dense(cover_atom, OffsetRange::new(row, row.inc()));
1510                        // bind the values
1511                        for (i, (_, var)) in bind.iter().enumerate() {
1512                            updates.push_binding(*var, key[i]);
1513                        }
1514                        // now probe each remaining indexes
1515                        for (prober_idx, (i, atom, prober)) in index_probers.iter().enumerate() {
1516                            // create a key: to_intersect indexes into the key from the cover
1517                            let index_cols = &to_intersect[*i].1;
1518                            // Fast path for the common single-column case: avoid SmallVec collect.
1519                            let index_key_buf: SmallVec<[Value; 4]>;
1520                            let index_key: &[Value] = if let [col] = index_cols.as_slice() {
1521                                std::slice::from_ref(&key[col.index()])
1522                            } else {
1523                                index_key_buf =
1524                                    index_cols.iter().map(|col| key[col.index()]).collect();
1525                                &index_key_buf
1526                            };
1527                            let Some(subset) = prober.get_subset(index_key) else {
1528                                updates.rollback();
1529                                // There are no possible values for this subset
1530                                continue 'mid;
1531                            };
1532                            // apply any constraints needed in this scan.
1533                            let table_info = &self.db.tables[atoms[*atom].table];
1534                            let cs = &to_intersect[*i].0.constraints;
1535                            let subset = refine_subset(
1536                                subset.to_owned(pool),
1537                                cs,
1538                                &table_info.table.as_ref(),
1539                                index_has_stale[prober_idx],
1540                            );
1541                            if subset.is_empty() {
1542                                updates.rollback();
1543                                // There are no possible values for this subset
1544                                continue 'mid;
1545                            }
1546                            updates.refine_atom_subset(*atom, subset);
1547                        }
1548                        updates.finish_frame();
1549                        if updates.frames() >= chunk_size {
1550                            drain_updates!(updates);
1551                        }
1552                    }
1553                    if let Some(next) = next {
1554                        cur = next;
1555                        continue;
1556                    }
1557                    break;
1558                }
1559                // TODO: special-case the scenario when the cover doesn't need
1560                // deduping (and hence we can do a straight scan: e.g. when the
1561                // cover is binding a superset of the primary key for the
1562                // table).
1563                drain_updates!(updates);
1564                // Restore the subsets we swapped out.
1565                binding_info.move_back_node(cover_atom, cover_node);
1566                for (_, atom, prober) in index_probers {
1567                    binding_info.move_back(atom, prober);
1568                }
1569            }
1570            JoinStage::FusedIntersectMat {
1571                cover,
1572                mode,
1573                bind,
1574                to_intersect,
1575            } if leaf_scans[cur]
1576                && to_intersect.is_empty()
1577                && matches!(
1578                    mode,
1579                    MatScanMode::Full | MatScanMode::KeyOnly | MatScanMode::Value(_)
1580                ) =>
1581            {
1582                // Leaf-scan factorization for FusedIntersectMat: flatten the materialization into
1583                // one `TaggedRowBuffer`, push it onto `binding_sets`, and recurse to the leaf once.
1584                let cover_mat = binding_info.materializations[*cover].clone();
1585                let vars: SmallVec<[Variable; 4]> = bind.iter().map(|(_, v)| *v).collect();
1586                let mut buf = TaggedRowBuffer::new_inline(bind.len());
1587                let mut row_scratch: SmallVec<[Value; 8]> = SmallVec::new();
1588                match mode {
1589                    MatScanMode::Full => {
1590                        for group in cover_mat.iter() {
1591                            let group_key = group.0;
1592                            let group_key_len = group_key.len();
1593                            for non_keys in group.1.iter() {
1594                                row_scratch.clear();
1595                                for (col, _) in bind.iter() {
1596                                    let val = if col.index() < group_key_len {
1597                                        group_key[col.index()]
1598                                    } else {
1599                                        non_keys[col.index() - group_key_len]
1600                                    };
1601                                    row_scratch.push(val);
1602                                }
1603                                buf.add_row(RowId::new(0), &row_scratch);
1604                            }
1605                        }
1606                    }
1607                    MatScanMode::KeyOnly => {
1608                        for group in cover_mat.iter() {
1609                            let group_key = group.0;
1610                            row_scratch.clear();
1611                            for (col, _) in bind.iter() {
1612                                debug_assert!(col.index() < group_key.len());
1613                                row_scratch.push(group_key[col.index()]);
1614                            }
1615                            buf.add_row(RowId::new(0), &row_scratch);
1616                        }
1617                    }
1618                    MatScanMode::Value(index_vars) => {
1619                        let keys: Vec<Value> = index_vars
1620                            .iter()
1621                            .map(|var| binding_info.bindings[*var])
1622                            .collect();
1623                        if let Some(group) = cover_mat.get(&keys) {
1624                            for vals in group.iter() {
1625                                debug_assert!(vals.len() == bind.len());
1626                                row_scratch.clear();
1627                                for (col, _) in bind.iter() {
1628                                    row_scratch.push(vals[col.index()]);
1629                                }
1630                                buf.add_row(RowId::new(0), &row_scratch);
1631                            }
1632                        }
1633                    }
1634                    MatScanMode::Lookup(_) => unreachable!("guarded above"),
1635                }
1636                if buf.is_empty() {
1637                    return;
1638                }
1639                binding_info.binding_sets.push((vars, Arc::new(buf)));
1640                let mut updates = FrameUpdates::with_capacity(1);
1641                updates.finish_frame();
1642                drain_updates!(updates);
1643                binding_info.binding_sets.pop();
1644            }
1645            JoinStage::FusedIntersectMat {
1646                cover,
1647                mode,
1648                bind,
1649                to_intersect,
1650            } => {
1651                let cover_mat = binding_info.materializations[*cover].clone();
1652                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
1653                let probers = to_intersect
1654                    .iter()
1655                    .map(|(spec, _)| {
1656                        self.get_index(
1657                            atoms,
1658                            spec.to_index.atom,
1659                            binding_info,
1660                            spec.to_index.vars.iter().copied(),
1661                        )
1662                    })
1663                    .collect::<SmallVec<[Prober; 4]>>();
1664                // Pre-compute has_stale per prober to avoid vtable calls in the hot loop.
1665                let probers_has_stale: SmallVec<[bool; 4]> = to_intersect
1666                    .iter()
1667                    .map(|(spec, _)| {
1668                        self.db.tables[atoms[spec.to_index.atom].table]
1669                            .table
1670                            .as_ref()
1671                            .has_stale_rows()
1672                    })
1673                    .collect();
1674
1675                let mut key = Vec::with_capacity(4);
1676                let mut prune_probers = |updates: &mut FrameUpdates,
1677                                         mat_key: Option<&[Value]>,
1678                                         mat_non_key: Option<&[Value]>|
1679                 -> bool {
1680                    for (j, ((spec, cols), prober)) in
1681                        to_intersect.iter().zip(probers.iter()).enumerate()
1682                    {
1683                        key.clear();
1684                        for col in cols.iter() {
1685                            let val = match mat_key {
1686                                Some(mat_key) => {
1687                                    if col.index() < mat_key.len() {
1688                                        mat_key[col.index()]
1689                                    } else {
1690                                        mat_non_key.unwrap()[col.index() - mat_key.len()]
1691                                    }
1692                                }
1693                                None => mat_non_key.unwrap()[col.index()],
1694                            };
1695                            key.push(val);
1696                        }
1697                        if let Some(subset) = prober.get_subset(&key) {
1698                            let subset = refine_subset(
1699                                subset.to_owned(pool),
1700                                &spec.constraints,
1701                                &self.db.tables[atoms[spec.to_index.atom].table]
1702                                    .table
1703                                    .as_ref(),
1704                                probers_has_stale[j],
1705                            );
1706                            if subset.is_empty() {
1707                                return false;
1708                            }
1709                            updates.refine_atom_subset(spec.to_index.atom, subset);
1710                        } else {
1711                            return false;
1712                        }
1713                    }
1714                    true
1715                };
1716
1717                match mode {
1718                    MatScanMode::Full | MatScanMode::KeyOnly => {
1719                        // enumerate keys
1720                        for group in cover_mat.iter() {
1721                            let group_key = group.0;
1722                            let group_val = group.1;
1723                            let group_key_len = group_key.len();
1724                            if mode == &MatScanMode::Full {
1725                                // enumerate non-keys
1726                                for non_keys in group_val.iter() {
1727                                    for (col, var) in bind.iter() {
1728                                        if col.index() < group_key_len {
1729                                            updates.push_binding(*var, group_key[col.index()]);
1730                                        }
1731                                    }
1732
                                    // TODO: optimization that guarantees all keys come before non-keys
1734                                    for (col, var) in bind.iter() {
1735                                        if col.index() >= group_key_len {
1736                                            updates.push_binding(
1737                                                *var,
1738                                                non_keys[col.index() - group_key_len],
1739                                            );
1740                                        }
1741                                    }
1742                                    if prune_probers(&mut updates, Some(group_key), Some(non_keys))
1743                                    {
1744                                        updates.finish_frame();
1745                                    } else {
1746                                        updates.rollback();
1747                                    }
1748                                }
1749                            } else if mode == &MatScanMode::KeyOnly {
1750                                for (col, var) in bind.iter() {
1751                                    debug_assert!(col.index() < group_key_len);
1752                                    updates.push_binding(*var, group_key[col.index()]);
1753                                }
1754                                if prune_probers(&mut updates, Some(group_key), None) {
1755                                    updates.finish_frame();
1756                                } else {
1757                                    updates.rollback();
1758                                }
1759                            }
1760                        }
1761                    }
1762                    MatScanMode::Value(index_vars) | MatScanMode::Lookup(index_vars) => {
1763                        let keys = index_vars
1764                            .iter()
1765                            .map(|var| binding_info.bindings[*var])
1766                            .collect::<Vec<Value>>();
1767                        // lookup keys
1768                        if let Some(group) = cover_mat.get(&keys) {
1769                            if matches!(mode, MatScanMode::Lookup(_)) {
1770                                debug_assert_eq!(to_intersect.len(), 0);
1771                                debug_assert_eq!(bind.len(), 0);
1772                                if group.len() > 0 {
1773                                    updates.finish_frame();
1774                                }
1775                                drain_updates!(updates);
1776                            } else {
1777                                // enumerate non-keys
1778                                // for vals in group.value().iter() {
1779                                for vals in group.iter() {
1780                                    debug_assert!(vals.len() == bind.len()); // TODO: not true for non-full query
1781                                    for (col, var) in bind.iter() {
1782                                        updates.push_binding(*var, vals[col.index()]);
1783                                    }
1784                                    if prune_probers(&mut updates, None, Some(vals)) {
1785                                        updates.finish_frame();
1786                                    } else {
1787                                        updates.rollback();
1788                                    }
1789                                    if updates.frames() >= chunk_size {
1790                                        drain_updates!(updates);
1791                                    }
1792                                }
1793                            }
1794                        }
1795                    }
1796                }
1797
1798                drain_updates!(updates);
1799                for (spec, prober) in to_intersect.iter().zip(probers) {
1800                    binding_info.move_back(spec.0.to_index.atom, prober);
1801                }
1802            }
1803        }
1804    }
1805}
1806
/// Number of binding rows buffered for a single action before that batch is
/// executed. It is also the argument given to `Bindings::new` when
/// `ScopedActionBuffer` swaps in a fresh buffer after handing a full batch off.
const VAR_BATCH_SIZE: usize = 128;
1808
/// A trait used to abstract over different ways of buffering actions together
/// before running them.
///
/// This trait exists as a fairly ad-hoc wrapper over its two implementations.
/// It allows us to avoid duplicating the (somewhat monstrous) `run_plan` method
/// for serial and parallel modes.
trait ActionBuffer<'state, A: NumericId>: Send {
    /// The buffer type handed to the `work` callback of [`ActionBuffer::recur`].
    type AsLocal<'a>: ActionBuffer<'state, A>
    where
        'state: 'a;

    /// Expand the factorized `binding_sets` into individual bindings, calling
    /// [`ActionBuffer::push_bindings`] once for each expanded combination.
    ///
    /// The expansion is delegated to `expand_binding_sets`, starting at the
    /// first binding set; `bindings` is overwritten in place as rows are
    /// enumerated.
    fn push_bindings_factorized(
        &mut self,
        action: A,
        bindings: &mut DenseIdMap<Variable, Value>,
        binding_sets: &BindingSet,
        exec_state: &ExecutionState<'state>,
    ) {
        expand_binding_sets(self, action, bindings, binding_sets, 0, exec_state);
    }

    /// Push the given bindings to be executed for the specified action. If this
    /// buffer has built up a sufficient batch size, it may execute
    /// `to_exec_state` and then execute the action.
    ///
    /// NB: `push_bindings` makes module-specific assumptions on what values are passed to
    /// `bindings` for a common `action`. This is not a general-purpose trait for that reason and
    /// it should not, in general, be used outside of this module.
    fn push_bindings(
        &mut self,
        action: A,
        bindings: &DenseIdMap<Variable, Value>,
        to_exec_state: impl FnMut() -> ExecutionState<'state>,
    );

    /// Execute any remaining actions associated with this buffer.
    fn flush(&mut self, exec_state: &mut ExecutionState);

    /// Execute `work`, potentially asynchronously, with a mutable reference to
    /// an action buffer, potentially handed off to a different thread.
    ///
    /// Callers pass a [`BorrowedLocalState`] whose values may be modified
    /// directly by `work`, or cloned first so that a separate copy is the one
    /// modified by `work`. Callers should assume that `local` _is_ modified
    /// synchronously.
    // NB: Earlier versions of this method had BorrowedLocalState be a generic instead, but this
    // ran into difficulties when we needed to pass multiple mutable references.
    fn recur<'local>(
        &mut self,
        local: BorrowedLocalState<'local>,
        to_exec_state: impl FnMut() -> ExecutionState<'state> + Send + 'state,
        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut Self::AsLocal<'a>) + Send + 'state,
    );

    /// The unit at which you should batch updates passed to calls to `recur`,
    /// potentially depending on the current level of recursion.
    ///
    /// As of right now this is just a hard-coded value. We may change it in the
    /// future to fan out more at higher levels though.
    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
        256
    }

    /// Whether this buffer supports parallel drain operations.
    ///
    /// When `false`, `drain_updates` will use the serial path even at `cur <= 1`,
    /// avoiding the per-frame `ExecutionState::clone()` overhead.
    fn supports_parallel_drain(&self) -> bool {
        true
    }
}
1881
/// The action buffer we use if we are executing in a single-threaded
/// environment. It builds up local batches and then flushes them inline.
struct InPlaceActionBuffer<'a> {
    /// Rule set whose `actions` table supplies the used variables and the
    /// instructions for each action.
    rule_set: &'a RuleSet,
    /// Counter notified through `inc_matches` each time a batch of an action
    /// has been run.
    match_counter: &'a MatchCounter,
    /// Pending bindings and bookkeeping counters, tracked separately per action.
    batches: DenseIdMap<ActionId, ActionState>,
}
1889
1890impl<'a, 'outer: 'a> ActionBuffer<'a, ActionId> for InPlaceActionBuffer<'outer> {
1891    type AsLocal<'b>
1892        = Self
1893    where
1894        'a: 'b;
1895
1896    fn push_bindings(
1897        &mut self,
1898        action: ActionId,
1899        bindings: &DenseIdMap<Variable, Value>,
1900        mut to_exec_state: impl FnMut() -> ExecutionState<'a>,
1901    ) {
1902        let action_state = self.batches.get_or_default(action);
1903        action_state.n_runs += 1;
1904        action_state.len += 1;
1905        let action_info = &self.rule_set.actions[action];
1906        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
1907        // `bindings` produced by the same join.
1908        unsafe {
1909            action_state.bindings.push(bindings, &action_info.used_vars);
1910        }
1911        if action_state.len >= VAR_BATCH_SIZE {
1912            let mut state = to_exec_state();
1913            let succeeded = state.run_instrs(&action_info.instrs, &mut action_state.bindings);
1914            action_state.bindings.clear();
1915            self.match_counter.inc_matches(action, succeeded);
1916            action_state.len = 0;
1917        }
1918    }
1919
1920    fn flush(&mut self, exec_state: &mut ExecutionState) {
1921        flush_action_states(
1922            exec_state,
1923            &mut self.batches,
1924            self.rule_set,
1925            self.match_counter,
1926        );
1927    }
1928
1929    fn recur<'local>(
1930        &mut self,
1931        local: BorrowedLocalState<'local>,
1932        _to_exec_state: impl FnMut() -> ExecutionState<'a> + Send + 'a,
1933        work: impl for<'b> FnOnce(BorrowedLocalState<'b>, &mut Self) + Send + 'a,
1934    ) {
1935        work(local, self)
1936    }
1937
1938    fn supports_parallel_drain(&self) -> bool {
1939        false
1940    }
1941}
1942
/// An action buffer that hands off batches of actions to rayon to execute.
struct ScopedActionBuffer<'inner, 'scope> {
    /// Rayon scope on which full batches and recursive work are spawned.
    scope: &'inner rayon::Scope<'scope>,
    /// Rule set whose `actions` table supplies the used variables and the
    /// instructions for each action.
    rule_set: &'scope RuleSet,
    /// Shared counter; a clone of the `Arc` is moved into each spawned task.
    match_counter: Arc<MatchCounter>,
    /// Pending bindings and bookkeeping counters, tracked separately per action.
    batches: DenseIdMap<ActionId, ActionState>,
    /// Set once any bindings have been pushed; cleared again by `flush`.
    needs_flush: bool,
}
1951
1952impl<'inner, 'scope> ScopedActionBuffer<'inner, 'scope> {
1953    fn new(
1954        scope: &'inner rayon::Scope<'scope>,
1955        rule_set: &'scope RuleSet,
1956        match_counter: Arc<MatchCounter>,
1957    ) -> Self {
1958        Self {
1959            scope,
1960            rule_set,
1961            batches: Default::default(),
1962            match_counter,
1963            needs_flush: false,
1964        }
1965    }
1966}
1967
1968impl<'scope> ActionBuffer<'scope, ActionId> for ScopedActionBuffer<'_, 'scope> {
1969    type AsLocal<'a>
1970        = ScopedActionBuffer<'a, 'scope>
1971    where
1972        'scope: 'a;
1973    fn push_bindings(
1974        &mut self,
1975        action: ActionId,
1976        bindings: &DenseIdMap<Variable, Value>,
1977        mut to_exec_state: impl FnMut() -> ExecutionState<'scope>,
1978    ) {
1979        self.needs_flush = true;
1980        let action_state = self.batches.get_or_default(action);
1981        action_state.n_runs += 1;
1982        action_state.len += 1;
1983        let action_info = &self.rule_set.actions[action];
1984        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
1985        // `bindings` produced by the same join.
1986        unsafe {
1987            action_state.bindings.push(bindings, &action_info.used_vars);
1988        }
1989        if action_state.len >= VAR_BATCH_SIZE {
1990            let mut state = to_exec_state();
1991            let mut bindings =
1992                mem::replace(&mut action_state.bindings, Bindings::new(VAR_BATCH_SIZE));
1993            action_state.len = 0;
1994            let match_counter = self.match_counter.clone();
1995            self.scope.spawn(move |_| {
1996                let succeeded = state.run_instrs(&action_info.instrs, &mut bindings);
1997                match_counter.inc_matches(action, succeeded);
1998            });
1999        }
2000    }
2001
2002    fn flush(&mut self, exec_state: &mut ExecutionState) {
2003        flush_action_states(
2004            exec_state,
2005            &mut self.batches,
2006            self.rule_set,
2007            self.match_counter.as_ref(),
2008        );
2009        self.needs_flush = false;
2010    }
2011    fn recur<'local>(
2012        &mut self,
2013        mut local: BorrowedLocalState<'local>,
2014        mut to_exec_state: impl FnMut() -> ExecutionState<'scope> + Send + 'scope,
2015        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut ScopedActionBuffer<'a, 'scope>)
2016        + Send
2017        + 'scope,
2018    ) {
2019        let rule_set = self.rule_set;
2020        let match_counter = self.match_counter.clone();
2021        let mut inner = local.clone_state();
2022        self.scope.spawn(move |scope| {
2023            let mut buf: ScopedActionBuffer<'_, 'scope> = ScopedActionBuffer {
2024                scope,
2025                rule_set,
2026                match_counter,
2027                needs_flush: false,
2028                batches: Default::default(),
2029            };
2030            work(inner.borrow_mut(), &mut buf);
2031            if buf.needs_flush {
2032                flush_action_states(
2033                    &mut to_exec_state(),
2034                    &mut buf.batches,
2035                    buf.rule_set,
2036                    buf.match_counter.as_ref(),
2037                );
2038            }
2039        });
2040    }
2041
2042    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
2043        // Lower morsel size to increase parallelism.
2044        match _level {
2045            0 if _total > 2 => 32,
2046            _ => 256,
2047        }
2048    }
2049}
2050
2051fn expand_binding_sets<'state, A: NumericId, BUF: ActionBuffer<'state, A> + ?Sized>(
2052    action_buf: &mut BUF,
2053    action: A,
2054    bindings: &mut DenseIdMap<Variable, Value>,
2055    binding_sets: &BindingSet,
2056    idx: usize,
2057    exec_state: &ExecutionState<'state>,
2058) {
2059    if exec_state.should_stop() {
2060        return;
2061    }
2062    if idx >= binding_sets.len() {
2063        action_buf.push_bindings(action, bindings, || exec_state.clone());
2064        return;
2065    }
2066    if idx + 1 == binding_sets.len() {
2067        let (vars, buf) = &binding_sets[idx];
2068        for (_, row) in buf.iter() {
2069            if exec_state.should_stop() {
2070                return;
2071            }
2072            for (var, val) in vars.iter().zip(row.iter()) {
2073                bindings.insert(*var, *val);
2074            }
2075            action_buf.push_bindings(action, bindings, || exec_state.clone());
2076        }
2077        return;
2078    }
2079    let (vars, buf) = &binding_sets[idx];
2080    for (_, row) in buf.iter() {
2081        for (var, val) in vars.iter().zip(row.iter()) {
2082            bindings.insert(*var, *val);
2083        }
2084        expand_binding_sets(
2085            action_buf,
2086            action,
2087            bindings,
2088            binding_sets,
2089            idx + 1,
2090            exec_state,
2091        );
2092    }
2093}
2094
2095fn flush_action_states(
2096    exec_state: &mut ExecutionState,
2097    actions: &mut DenseIdMap<ActionId, ActionState>,
2098    rule_set: &RuleSet,
2099    match_counter: &MatchCounter,
2100) {
2101    for (action, ActionState { bindings, len, .. }) in actions.iter_mut() {
2102        if *len > 0 {
2103            let succeeded = exec_state.run_instrs(&rule_set.actions[action].instrs, bindings);
2104            bindings.clear();
2105            match_counter.inc_matches(action, succeeded);
2106            *len = 0;
2107        }
2108    }
2109}
2110
2111struct InPlaceMaterializer<'a> {
2112    specs: &'a DenseIdMap<MatId, MatSpec>,
2113    materializations: DenseIdMap<MatId, IndexMap<Vec<Value>, RowBuffer>>,
2114    scratch_key: Vec<Value>,
2115    scratch_val: Vec<Value>,
2116}
2117
2118impl<'a> ActionBuffer<'a, MatId> for InPlaceMaterializer<'a> {
2119    type AsLocal<'b>
2120        = Self
2121    where
2122        'a: 'b;
2123
2124    fn push_bindings(
2125        &mut self,
2126        mat_id: MatId,
2127        bindings: &DenseIdMap<Variable, Value>,
2128        _to_exec_state: impl FnMut() -> ExecutionState<'a>,
2129    ) {
2130        let mat = self
2131            .materializations
2132            .get_mut(mat_id)
2133            .expect("invalid mat id");
2134        let spec = self.specs.get(mat_id).expect("invalid mat id");
2135        self.scratch_key.clear();
2136        for key in spec.msg_vars.iter().map(|var| bindings[*var]) {
2137            self.scratch_key.push(key);
2138        }
2139        self.scratch_val.clear();
2140        for val in spec.val_vars.iter().map(|var| bindings[*var]) {
2141            self.scratch_val.push(val);
2142        }
2143        if self.scratch_val.is_empty() {
2144            self.scratch_val.push(Value::stale());
2145        }
2146        if let Some(buffer) = mat.get_mut(&self.scratch_key) {
2147            buffer.add_row(&self.scratch_val);
2148        } else {
2149            let mut buffer = RowBuffer::new(usize::max(spec.val_vars.len(), 1));
2150            buffer.add_row(&self.scratch_val);
2151            mat.insert(self.scratch_key.clone(), buffer);
2152        }
2153    }
2154
2155    fn flush(&mut self, _exec_state: &mut ExecutionState) {
2156        // No-op for in-place materializer.
2157    }
2158
2159    fn recur<'local>(
2160        &mut self,
2161        local: BorrowedLocalState<'local>,
2162        _to_exec_state: impl FnMut() -> ExecutionState<'a> + Send + 'a,
2163        work: impl for<'b> FnOnce(BorrowedLocalState<'b>, &mut Self) + Send + 'a,
2164    ) {
2165        work(local, self)
2166    }
2167
2168    fn supports_parallel_drain(&self) -> bool {
2169        false
2170    }
2171}
2172
2173struct ScopedMaterializer<'inner, 'scope> {
2174    scope: &'inner rayon::Scope<'scope>,
2175    specs: Arc<DenseIdMap<MatId, MatSpec>>,
2176    materializations: Arc<DenseIdMap<MatId, Arc<DashMap<Vec<Value>, RowBuffer>>>>,
2177    scratch_key: Vec<Value>,
2178    scratch_val: Vec<Value>,
2179}
2180impl<'scope> ActionBuffer<'scope, MatId> for ScopedMaterializer<'_, 'scope> {
2181    type AsLocal<'a>
2182        = ScopedMaterializer<'a, 'scope>
2183    where
2184        'scope: 'a;
2185
2186    fn push_bindings(
2187        &mut self,
2188        mat_id: MatId,
2189        bindings: &DenseIdMap<Variable, Value>,
2190        _to_exec_state: impl FnMut() -> ExecutionState<'scope>,
2191    ) {
2192        let mat = self.materializations.get(mat_id).expect("invalid mat id");
2193        let spec = self.specs.get(mat_id).expect("invalid mat id");
2194        self.scratch_key.clear();
2195        for key in spec.msg_vars.iter().map(|var| bindings[*var]) {
2196            self.scratch_key.push(key);
2197        }
2198        self.scratch_val.clear();
2199        for val in spec.val_vars.iter().map(|var| bindings[*var]) {
2200            self.scratch_val.push(val);
2201        }
2202        if self.scratch_val.is_empty() {
2203            self.scratch_val.push(Value::stale());
2204        }
2205        let key = self.scratch_key.clone();
2206        match mat.entry(key) {
2207            Entry::Occupied(mut occ) => {
2208                occ.get_mut().add_row(&self.scratch_val);
2209            }
2210            Entry::Vacant(vac) => {
2211                let mut buffer = RowBuffer::new(usize::max(spec.val_vars.len(), 1));
2212                buffer.add_row(&self.scratch_val);
2213                vac.insert(buffer);
2214            }
2215        }
2216    }
2217
2218    fn flush(&mut self, _exec_state: &mut ExecutionState) {
2219        // No-op for scoped materializer since we always write to the materialization in-place.
2220    }
2221
2222    fn recur<'local>(
2223        &mut self,
2224        mut local: BorrowedLocalState<'local>,
2225        _to_exec_state: impl FnMut() -> ExecutionState<'scope> + Send + 'scope,
2226        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut ScopedMaterializer<'a, 'scope>)
2227        + Send
2228        + 'scope,
2229    ) {
2230        let scope = self.scope;
2231        let specs = self.specs.clone();
2232        let materializations = self.materializations.clone();
2233        let mut inner = local.clone_state();
2234        scope.spawn(move |scope| {
2235            let mut buf: ScopedMaterializer<'_, 'scope> = ScopedMaterializer {
2236                scope,
2237                specs,
2238                materializations: materializations.clone(),
2239                scratch_key: Vec::new(),
2240                scratch_val: Vec::new(),
2241            };
2242            work(inner.borrow_mut(), &mut buf);
2243        });
2244    }
2245}
2246
2247struct MatchCounter {
2248    matches: IdVec<ActionId, CachePadded<AtomicUsize>>,
2249}
2250
2251impl MatchCounter {
2252    fn new(n_ids: usize) -> Self {
2253        let mut matches = IdVec::with_capacity(n_ids);
2254        matches.resize_with(n_ids, || CachePadded::new(AtomicUsize::new(0)));
2255        Self { matches }
2256    }
2257
2258    fn inc_matches(&self, action: ActionId, by: usize) {
2259        self.matches[action].fetch_add(by, std::sync::atomic::Ordering::Relaxed);
2260    }
2261    fn read_matches(&self, action: ActionId) -> usize {
2262        self.matches[action].load(std::sync::atomic::Ordering::Acquire)
2263    }
2264}
2265
2266fn estimate_size(join_stage: &JoinStage, binding_info: &BindingInfo) -> usize {
2267    match join_stage {
2268        JoinStage::Intersect { scans, .. } => scans
2269            .iter()
2270            .map(|scan| binding_info.subsets[scan.atom].size())
2271            .min()
2272            .unwrap_or(0),
2273        JoinStage::FusedIntersect { cover, .. } => binding_info.subsets[cover.to_index.atom].size(),
2274        JoinStage::FusedIntersectMat { cover, .. } => binding_info.materializations[*cover].len(), // TODO: len() might be expensive.
2275    }
2276}
2277
2278fn num_intersected_rels(join_stage: &JoinStage) -> i32 {
2279    match join_stage {
2280        JoinStage::Intersect { scans, .. } => scans.len() as i32,
2281        JoinStage::FusedIntersect { to_intersect, .. } => to_intersect.len() as i32 + 1,
2282        JoinStage::FusedIntersectMat { to_intersect, .. } => to_intersect.len() as i32,
2283    }
2284}
2285
2286fn sort_plan_by_size(
2287    order: &mut InstrOrder,
2288    leaf_scans: &mut LeafScans,
2289    start: usize,
2290    instrs: &[JoinStage],
2291    binding_info: &mut BindingInfo,
2292) {
2293    let mut last_pos = start;
2294    for i in start..instrs.len() {
2295        if matches!(
2296            &instrs[i],
2297            // These nodes don't commute
2298            JoinStage::FusedIntersectMat {
2299                mode: MatScanMode::Lookup(_) | MatScanMode::Value(_) | MatScanMode::Full,
2300                ..
2301            }
2302        ) {
2303            sort_plan_by_size_inner(order, last_pos..i, instrs, binding_info);
2304            last_pos = i + 1;
2305        }
2306    }
2307    sort_plan_by_size_inner(order, last_pos..instrs.len(), instrs, binding_info);
2308    recompute_leaf_scans(order, leaf_scans, instrs, start);
2309}
2310
2311/// Recompute `leaf_scans[i]` for every position `i` in `[start, order.len())` against the
2312/// current order. A position is a leaf scan iff its stage is either a `FusedIntersect` or a
2313/// `FusedIntersectMat { mode: Full | KeyOnly | Value }`, both with empty `to_intersect`, AND no
2314/// later stage either (a) for `FusedIntersect`, references the same cover atom, or (b) reads
2315/// any of the bound variables as a scalar via `FusedIntersectMat { mode: Value | Lookup }`.
2316/// `FusedIntersectMat::Lookup` itself binds nothing, so it is never marked a leaf scan.
2317fn recompute_leaf_scans(
2318    order: &InstrOrder,
2319    leaf_scans: &mut LeafScans,
2320    instrs: &[JoinStage],
2321    start: usize,
2322) {
2323    for i in start..order.len() {
2324        let stage_idx = order.get(i);
2325        let (cover_atom, bind_vars) = match &instrs[stage_idx] {
2326            JoinStage::FusedIntersect {
2327                cover,
2328                bind,
2329                to_intersect,
2330            } if to_intersect.is_empty() => {
2331                let vars: SmallVec<[Variable; 4]> = bind.iter().map(|(_, v)| *v).collect();
2332                (Some(cover.to_index.atom), vars)
2333            }
2334            JoinStage::FusedIntersectMat {
2335                mode,
2336                bind,
2337                to_intersect,
2338                ..
2339            } if to_intersect.is_empty()
2340                && matches!(
2341                    mode,
2342                    MatScanMode::Full | MatScanMode::KeyOnly | MatScanMode::Value(_)
2343                ) =>
2344            {
2345                let vars: SmallVec<[Variable; 4]> = bind.iter().map(|(_, v)| *v).collect();
2346                (None, vars)
2347            }
2348            _ => {
2349                leaf_scans[i] = false;
2350                continue;
2351            }
2352        };
2353        let mut blocked = false;
2354        for j in (i + 1)..order.len() {
2355            match &instrs[order.get(j)] {
2356                JoinStage::Intersect { scans, .. } => {
2357                    if let Some(ca) = cover_atom
2358                        && scans.iter().any(|scan| scan.atom == ca)
2359                    {
2360                        blocked = true;
2361                        break;
2362                    }
2363                }
2364                JoinStage::FusedIntersect {
2365                    cover,
2366                    to_intersect,
2367                    ..
2368                } => {
2369                    if let Some(ca) = cover_atom
2370                        && (cover.to_index.atom == ca
2371                            || to_intersect.iter().any(|(s, _)| s.to_index.atom == ca))
2372                    {
2373                        blocked = true;
2374                        break;
2375                    }
2376                }
2377                JoinStage::FusedIntersectMat {
2378                    mode, to_intersect, ..
2379                } => {
2380                    if let Some(ca) = cover_atom
2381                        && to_intersect.iter().any(|(s, _)| s.to_index.atom == ca)
2382                    {
2383                        blocked = true;
2384                        break;
2385                    }
2386                    if let MatScanMode::Value(vars) | MatScanMode::Lookup(vars) = mode
2387                        && vars.iter().any(|v| bind_vars.contains(v))
2388                    {
2389                        blocked = true;
2390                        break;
2391                    }
2392                }
2393            }
2394        }
2395        leaf_scans[i] = !blocked;
2396    }
2397}
2398
2399fn sort_plan_by_size_inner(
2400    order: &mut InstrOrder,
2401    range: Range<usize>,
2402    instrs: &[JoinStage],
2403    binding_info: &mut BindingInfo,
2404) {
2405    // Nothing to sort if there's 0 or 1 element.
2406    if range.len() <= 1 {
2407        return;
2408    }
2409    // How many times an atom has been intersected/joined
2410    let mut times_refined = with_pool_set(|ps| ps.get::<DenseIdMap<AtomId, i64>>());
2411
2412    // Count how many times each atom has been refined so far.
2413    for ins in instrs[..range.start].iter() {
2414        match ins {
2415            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
2416                *times_refined.get_or_default(scan.atom) += 1;
2417            }),
2418            JoinStage::FusedIntersect {
2419                cover,
2420                to_intersect,
2421                ..
2422            } => {
2423                *times_refined.get_or_default(cover.to_index.atom) +=
2424                    cover.to_index.vars.len() as i64;
2425                to_intersect.iter().for_each(|(spec, _)| {
2426                    *times_refined.get_or_default(spec.to_index.atom) +=
2427                        spec.to_index.vars.len() as i64;
2428                });
2429            }
2430            JoinStage::FusedIntersectMat { to_intersect, .. } => {
2431                to_intersect.iter().for_each(|(spec, _)| {
2432                    *times_refined.get_or_default(spec.to_index.atom) +=
2433                        spec.to_index.vars.len() as i64;
2434                });
2435            }
2436        }
2437    }
2438
2439    // We prioritize variables by
2440    //
2441    //   (1) how many times an atom with this variable has been refined,
2442    //   (2) then by the cardinality of the variable to be enumerated (smaller → earlier)
2443    //   (3) then by how many relations join on this variable (more → earlier)
2444    //
2445    // Estimate size is second so that stages with very small cardinality (e.g. FunDep
2446    // consequents with exactly 1 value) are run before multi-relation stages that happen
2447    // to have a larger current estimate.
2448    let key_fn = |join_stage: &JoinStage,
2449                  binding_info: &BindingInfo,
2450                  times_refined: &DenseIdMap<AtomId, i64>| {
2451        let refine = match join_stage {
2452            JoinStage::Intersect { scans, .. } => scans
2453                .iter()
2454                .map(|scan| times_refined.get(scan.atom).copied().unwrap_or_default())
2455                .max()
2456                .unwrap(),
2457            JoinStage::FusedIntersect { cover, .. } => times_refined
2458                .get(cover.to_index.atom)
2459                .copied()
2460                .unwrap_or_default(),
2461            JoinStage::FusedIntersectMat { bind, .. } => bind.len() as _,
2462        };
2463        (
2464            -refine,
2465            estimate_size(join_stage, binding_info),
2466            -num_intersected_rels(join_stage),
2467        )
2468    };
2469
2470    for i in range.clone() {
2471        let mut key_i = key_fn(&instrs[order.get(i)], binding_info, &times_refined);
2472        for j in (i + 1)..range.end {
2473            let key_j = key_fn(&instrs[order.get(j)], binding_info, &times_refined);
2474            if key_j < key_i {
2475                order.data.swap(i, j);
2476                key_i = key_j;
2477            }
2478        }
2479        // Update the counts after a new instruction is selected.
2480        match &instrs[order.get(i)] {
2481            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
2482                *times_refined.get_or_default(scan.atom) += 1;
2483            }),
2484            JoinStage::FusedIntersect {
2485                cover,
2486                to_intersect,
2487                ..
2488            } => {
2489                *times_refined.get_or_default(cover.to_index.atom) +=
2490                    cover.to_index.vars.len() as i64;
2491
2492                to_intersect.iter().for_each(|(spec, _)| {
2493                    *times_refined.get_or_default(spec.to_index.atom) +=
2494                        spec.to_index.vars.len() as i64;
2495                });
2496            }
2497            JoinStage::FusedIntersectMat { to_intersect, .. } => {
2498                to_intersect.iter().for_each(|(spec, _)| {
2499                    *times_refined.get_or_default(spec.to_index.atom) +=
2500                        spec.to_index.vars.len() as i64;
2501                });
2502            }
2503        }
2504    }
2505}
2506
2507#[derive(Debug, Clone, PartialEq, Eq)]
2508struct InstrOrder {
2509    data: SmallVec<[u16; 8]>,
2510}
2511
2512impl InstrOrder {
2513    fn new() -> Self {
2514        InstrOrder {
2515            data: SmallVec::new(),
2516        }
2517    }
2518
2519    fn from_iter(range: impl Iterator<Item = usize>) -> InstrOrder {
2520        let mut res = InstrOrder::new();
2521        res.data
2522            .extend(range.map(|x| u16::try_from(x).expect("too many instructions")));
2523        res
2524    }
2525
2526    fn get(&self, idx: usize) -> usize {
2527        self.data[idx] as usize
2528    }
2529    fn len(&self) -> usize {
2530        self.data.len()
2531    }
2532}
2533
2534/// Per-position leaf-scan flags. `leaf_scans[i] == true` means the stage currently scheduled at
2535/// position `i` (i.e. `instrs[instr_order.get(i)]`) can take the factorized-binding fast path.
2536/// Recomputed by [`sort_plan_by_size`] whenever the order changes.
2537type LeafScans = SmallVec<[bool; 8]>;
2538
2539struct BorrowedLocalState<'a> {
2540    instr_order: &'a mut InstrOrder,
2541    leaf_scans: &'a mut LeafScans,
2542    binding_info: &'a mut BindingInfo,
2543    updates: &'a mut FrameUpdates,
2544}
2545
2546impl BorrowedLocalState<'_> {
2547    fn clone_state(&mut self) -> LocalState {
2548        LocalState {
2549            instr_order: self.instr_order.clone(),
2550            leaf_scans: self.leaf_scans.clone(),
2551            binding_info: self.binding_info.clone(),
2552            updates: std::mem::take(self.updates),
2553        }
2554    }
2555}
2556
2557struct LocalState {
2558    instr_order: InstrOrder,
2559    leaf_scans: LeafScans,
2560    binding_info: BindingInfo,
2561    updates: FrameUpdates,
2562}
2563
2564impl LocalState {
2565    fn borrow_mut<'a>(&'a mut self) -> BorrowedLocalState<'a> {
2566        BorrowedLocalState {
2567            instr_order: &mut self.instr_order,
2568            leaf_scans: &mut self.leaf_scans,
2569            binding_info: &mut self.binding_info,
2570            updates: &mut self.updates,
2571        }
2572    }
2573}