egglog_core_relations/free_join/
execute.rs

1//! Core free join execution.
2
3use std::{
4    cmp, iter, mem,
5    sync::{Arc, OnceLock, atomic::AtomicUsize},
6};
7
8use crate::{
9    common::HashMap,
10    numeric_id::{DenseIdMap, IdVec, NumericId},
11};
12use crossbeam::utils::CachePadded;
13use dashmap::mapref::one::RefMut;
14use egglog_reports::{ReportLevel, RuleReport, RuleSetReport};
15use smallvec::SmallVec;
16use web_time::Instant;
17
18use crate::{
19    Constraint, OffsetRange, Pool, SubsetRef,
20    action::{Bindings, ExecutionState},
21    common::{DashMap, Value},
22    free_join::{
23        frame_update::{FrameUpdates, UpdateInstr},
24        get_index_from_tableinfo,
25    },
26    hash_index::{ColumnIndex, IndexBase, TupleIndex},
27    offsets::{Offsets, SortedOffsetVector, Subset},
28    parallel_heuristics::parallelize_db_level_op,
29    pool::Pooled,
30    query::RuleSet,
31    row_buffer::TaggedRowBuffer,
32    table_spec::{ColumnId, Offset, WrappedTableRef},
33};
34
35use super::{
36    ActionId, AtomId, Database, HashColumnIndex, HashIndex, TableInfo, Variable,
37    get_column_index_from_tableinfo,
38    plan::{JoinHeader, JoinStage, Plan},
39    with_pool_set,
40};
41
/// The index used to drive one stage of a free join.
///
/// `Cached`/`CachedColumn` wrap table-wide indexes fetched from the
/// [`TableInfo`] caches (see `get_index`), while `Dynamic`/`DynamicColumn`
/// are built over just the current subset of an atom. The `*Column` variants
/// are specialized to single-column keys.
enum DynamicIndex {
    /// A cached, multi-column index over the whole table.
    Cached {
        /// Whether probe results must still be intersected with the atom's
        /// current subset (false when that subset covers the whole table).
        intersect_outer: bool,
        table: HashIndex,
    },
    /// A cached, single-column index over the whole table.
    CachedColumn {
        /// See [`DynamicIndex::Cached::intersect_outer`].
        intersect_outer: bool,
        table: HashColumnIndex,
    },
    /// A multi-column index built on the fly for the current subset.
    Dynamic(TupleIndex),
    /// A single-column index built (and cached on the trie node) for the
    /// current subset.
    DynamicColumn(Arc<ColumnIndex>),
}
54
/// A probe-able index over (a subset of) one atom's relation.
///
/// A `Prober` temporarily takes ownership of the atom's [`TrieNode`]; callers
/// return it to the [`BindingInfo`] via `move_back` when they are done.
struct Prober {
    /// The trie node (subset plus cached indexes) this prober was built from.
    node: TrieNode,
    /// Pool used to allocate owned subsets returned from probes.
    pool: Pool<SortedOffsetVector>,
    /// The underlying index implementation.
    ix: DynamicIndex,
}
60
61impl Prober {
62    fn get_subset(&self, key: &[Value]) -> Option<Subset> {
63        match &self.ix {
64            DynamicIndex::Cached {
65                intersect_outer,
66                table,
67            } => {
68                let mut sub = table.get().unwrap().get_subset(key)?.to_owned(&self.pool);
69                if *intersect_outer {
70                    sub.intersect(self.node.subset.as_ref(), &self.pool);
71                    if sub.is_empty() {
72                        return None;
73                    }
74                }
75                Some(sub)
76            }
77            DynamicIndex::CachedColumn {
78                intersect_outer,
79                table,
80            } => {
81                debug_assert_eq!(key.len(), 1);
82                let mut sub = table
83                    .get()
84                    .unwrap()
85                    .get_subset(&key[0])?
86                    .to_owned(&self.pool);
87                if *intersect_outer {
88                    sub.intersect(self.node.subset.as_ref(), &self.pool);
89                    if sub.is_empty() {
90                        return None;
91                    }
92                }
93                Some(sub)
94            }
95            DynamicIndex::Dynamic(tab) => tab.get_subset(key).map(|x| x.to_owned(&self.pool)),
96            DynamicIndex::DynamicColumn(tab) => {
97                tab.get_subset(&key[0]).map(|x| x.to_owned(&self.pool))
98            }
99        }
100    }
101    fn for_each(&self, mut f: impl FnMut(&[Value], SubsetRef)) {
102        match &self.ix {
103            DynamicIndex::Cached {
104                intersect_outer: true,
105                table,
106            } => table.get().unwrap().for_each(|k, v| {
107                let mut res = v.to_owned(&self.pool);
108                res.intersect(self.node.subset.as_ref(), &self.pool);
109                if !res.is_empty() {
110                    f(k, res.as_ref())
111                }
112            }),
113            DynamicIndex::Cached {
114                intersect_outer: false,
115                table,
116            } => table.get().unwrap().for_each(|k, v| f(k, v)),
117            DynamicIndex::CachedColumn {
118                intersect_outer: true,
119                table,
120            } => {
121                table.get().unwrap().for_each(|k, v| {
122                    let mut res = v.to_owned(&self.pool);
123                    res.intersect(self.node.subset.as_ref(), &self.pool);
124                    if !res.is_empty() {
125                        f(&[*k], res.as_ref())
126                    }
127                });
128            }
129            DynamicIndex::CachedColumn {
130                intersect_outer: false,
131                table,
132            } => {
133                table.get().unwrap().for_each(|k, v| f(&[*k], v));
134            }
135            DynamicIndex::Dynamic(tab) => {
136                tab.for_each(f);
137            }
138            DynamicIndex::DynamicColumn(tab) => tab.for_each(|k, v| {
139                f(&[*k], v);
140            }),
141        }
142    }
143
144    fn len(&self) -> usize {
145        match &self.ix {
146            DynamicIndex::Cached { table, .. } => table.get().unwrap().len(),
147            DynamicIndex::CachedColumn { table, .. } => table.get().unwrap().len(),
148            DynamicIndex::Dynamic(tab) => tab.len(),
149            DynamicIndex::DynamicColumn(tab) => tab.len(),
150        }
151    }
152}
153
impl Database {
    /// Runs every rule in `rule_set` against the current database and then
    /// merges staged updates back into the tables.
    ///
    /// Rules run either in parallel (one rayon task per plan) or serially,
    /// chosen by a size-based heuristic. Returns a [`RuleSetReport`] with
    /// per-rule timing and match counts, whether the database changed, and the
    /// aggregate search/apply and merge times.
    pub fn run_rule_set(&mut self, rule_set: &RuleSet, report_level: ReportLevel) -> RuleSetReport {
        if rule_set.plans.is_empty() {
            return RuleSetReport::default();
        }
        // Tracks matches per action id; reports are patched with these counts
        // after all rules have run (see the usize::MAX sentinel below).
        let match_counter = MatchCounter::new(rule_set.actions.n_ids());

        let search_and_apply_timer = Instant::now();
        let mut rule_reports: HashMap<Arc<str>, Vec<RuleReport>>;
        let exec_state = ExecutionState::new(self.read_only_view(), Default::default());
        if parallelize_db_level_op(self.total_size_estimate) {
            // Parallel path: each plan runs as its own task, pushing its report
            // into a concurrent map keyed by rule description.
            let dash_rule_reports: DashMap<Arc<str>, Vec<RuleReport>> = DashMap::default();
            rayon::in_place_scope(|scope| {
                for (plan, desc, symbol_map, _action) in rule_set.plans.values() {
                    // TODO: add stats
                    let report_plan = match report_level {
                        ReportLevel::TimeOnly => None,
                        ReportLevel::WithPlan | ReportLevel::StageInfo => {
                            Some(plan.to_report(symbol_map))
                        }
                    };
                    scope.spawn(|scope| {
                        let join_state = JoinState::new(self, exec_state.clone());
                        let mut action_buf =
                            ScopedActionBuffer::new(scope, rule_set, &match_counter);
                        let mut binding_info = BindingInfo::default();
                        // Every atom starts out matching its entire table.
                        for (id, info) in plan.atoms.iter() {
                            let table = join_state.db.get_table(info.table);
                            binding_info.insert_subset(id, table.all());
                        }

                        let search_and_apply_timer = Instant::now();
                        join_state.run_header_and_plan(plan, &mut binding_info, &mut action_buf);
                        let search_and_apply_time = search_and_apply_timer.elapsed();

                        if action_buf.needs_flush {
                            action_buf.flush(&mut exec_state.clone());
                        }
                        let mut rule_report: RefMut<'_, Arc<str>, Vec<RuleReport>> =
                            dash_rule_reports.entry(desc.clone()).or_default();
                        rule_report.value_mut().push(RuleReport {
                            plan: report_plan,
                            search_and_apply_time,
                            // Sentinel; filled in from `match_counter` below.
                            num_matches: usize::MAX,
                        });
                    });
                }
            });
            rule_reports = dash_rule_reports.into_iter().collect();
        } else {
            rule_reports = HashMap::default();
            let join_state = JoinState::new(self, exec_state.clone());
            // Just run all of the plans in order with a single in-place action
            // buffer.
            let mut action_buf = InPlaceActionBuffer {
                rule_set,
                match_counter: &match_counter,
                batches: Default::default(),
            };
            for (plan, desc, symbol_map, _action) in rule_set.plans.values() {
                let report_plan = match report_level {
                    ReportLevel::TimeOnly => None,
                    ReportLevel::WithPlan | ReportLevel::StageInfo => {
                        Some(plan.to_report(symbol_map))
                    }
                };
                let mut binding_info = BindingInfo::default();
                // Every atom starts out matching its entire table.
                for (id, info) in plan.atoms.iter() {
                    let table = join_state.db.get_table(info.table);
                    binding_info.insert_subset(id, table.all());
                }

                let search_and_apply_timer = Instant::now();
                join_state.run_header_and_plan(plan, &mut binding_info, &mut action_buf);
                let search_and_apply_time = search_and_apply_timer.elapsed();

                // TODO: unnecessary cloning in many cases
                let rule_report = rule_reports.entry(desc.clone()).or_default();
                rule_report.push(RuleReport {
                    plan: report_plan,
                    search_and_apply_time,
                    // Sentinel; filled in from `match_counter` below.
                    num_matches: usize::MAX,
                });
            }
            action_buf.flush(&mut exec_state.clone());
        }
        for (_plan, desc, _symbol_map, action) in rule_set.plans.values() {
            let reports = rule_reports.get_mut(desc).unwrap();
            let i = reports
                .iter()
                // HACK: Since the order of visiting queries is fixed and # matches need to be obtained
                // separately from rule execution, we first set all # matches to be usize::MAX and then fill
                // them in one by one.
                .position(|r| r.num_matches == usize::MAX)
                .unwrap();
            // NB: This requires each action ID correspond to only one query.
            // If an action is used by multiple queries, then we can't tell how many matches are
            // caused by individual queries.
            reports[i].num_matches = match_counter.read_matches(*action);
        }
        let search_and_apply_time = search_and_apply_timer.elapsed();

        let merge_timer = Instant::now();
        let changed = self.merge_all();
        let merge_time = merge_timer.elapsed();

        RuleSetReport {
            changed,
            rule_reports,
            search_and_apply_time,
            merge_time,
        }
    }
}
270
/// Buffered per-action state while a rule set runs.
struct ActionState {
    // NOTE(review): appears to count how many times this action's buffer has
    // been executed — confirm against the flush logic elsewhere in this file.
    n_runs: usize,
    // NOTE(review): presumably the number of currently buffered frames —
    // verify against the buffering code elsewhere in this file.
    len: usize,
    // Staged variable bindings; sized to VAR_BATCH_SIZE (see `Default` impl).
    bindings: Bindings,
}
276
277impl Default for ActionState {
278    fn default() -> Self {
279        Self {
280            n_runs: 0,
281            len: 0,
282            bindings: Bindings::new(VAR_BATCH_SIZE),
283        }
284    }
285}
286
/// A borrowed database handle plus the execution state used while running a
/// single join.
struct JoinState<'a> {
    db: &'a Database,
    exec_state: ExecutionState<'a>,
}

/// Lazily-initialized cache of per-column indexes, keyed by [`ColumnId`].
type ColumnIndexes = IdVec<ColumnId, OnceLock<Arc<ColumnIndex>>>;
293
/// Information about the current subset of an atom's relation that is being considered, along with
/// lazily-initialized, cached indexes on that subset.
///
/// This is the standard trie-node used in lazy implementations of GJ as in the original egglog
/// implementation and the FJ paper. It currently does not handle non-column indexes, but that
/// should be a fairly straightforward extension if we start generating plans that need those.
/// (Right now, most plans iterating over more than one column just do a scan anyway).
struct TrieNode {
    /// The actual subset of the corresponding atom.
    subset: Subset,
    /// Any cached indexes on this subset, keyed by column; built lazily on
    /// first use (see `get_cached_index`) and shared between clones.
    cached_subsets: OnceLock<Arc<Pooled<ColumnIndexes>>>,
}
307
308impl TrieNode {
309    fn size(&self) -> usize {
310        self.subset.size()
311    }
312    fn get_cached_index(&self, col: ColumnId, info: &TableInfo) -> Arc<ColumnIndex> {
313        self.cached_subsets.get_or_init(|| {
314            // Pre-size the vector so we do not need to borrow it mutably to initialize the index.
315            let mut vec: Pooled<ColumnIndexes> = with_pool_set(|ps| ps.get());
316            vec.resize_with(info.spec.arity(), OnceLock::new);
317            Arc::new(vec)
318        })[col]
319            .get_or_init(|| {
320                let col_index = info.table.group_by_col(self.subset.as_ref(), col);
321                Arc::new(col_index)
322            })
323            .clone()
324    }
325}
326
327impl Clone for TrieNode {
328    fn clone(&self) -> Self {
329        let cached_subsets = OnceLock::new();
330        if let Some(cached) = self.cached_subsets.get() {
331            cached_subsets.set(cached.clone()).ok().unwrap();
332        }
333        Self {
334            subset: self.subset.clone(),
335            cached_subsets,
336        }
337    }
338}
339
/// Mutable per-join state: the values bound to variables so far, plus the
/// current [`TrieNode`] (subset and cached indexes) for each atom.
#[derive(Default, Clone)]
struct BindingInfo {
    bindings: DenseIdMap<Variable, Value>,
    subsets: DenseIdMap<AtomId, TrieNode>,
}
345
346impl BindingInfo {
347    /// Initializes the atom-related metadata in the [`BindingInfo`].
348    fn insert_subset(&mut self, atom: AtomId, subset: Subset) {
349        let node = TrieNode {
350            subset,
351            cached_subsets: Default::default(),
352        };
353        self.subsets.insert(atom, node);
354    }
355
356    /// Probers returned from [`JoinState::get_index`] will move atom-related state out of the
357    /// [`BindingInfo`]. Once the caller is done using a prober, this method moves it back.
358    fn move_back(&mut self, atom: AtomId, prober: Prober) {
359        self.subsets.insert(atom, prober.node);
360    }
361
362    fn move_back_node(&mut self, atom: AtomId, node: TrieNode) {
363        self.subsets.insert(atom, node);
364    }
365
366    fn has_empty_subset(&self, atom: AtomId) -> bool {
367        self.subsets[atom].subset.is_empty()
368    }
369
370    fn unwrap_val(&mut self, atom: AtomId) -> TrieNode {
371        self.subsets.unwrap_val(atom)
372    }
373}
374
375impl<'a> JoinState<'a> {
376    fn new(db: &'a Database, exec_state: ExecutionState<'a>) -> Self {
377        Self { db, exec_state }
378    }
379
    /// Builds a [`Prober`] over `atom`'s current subset, keyed by `cols`.
    ///
    /// The atom's [`TrieNode`] is moved out of `binding_info` into the
    /// returned prober; callers must hand it back with
    /// [`BindingInfo::move_back`] when finished. Depending on how much of the
    /// table the subset covers (and whether the probed columns are cacheable),
    /// this either reuses a table-wide cached index or builds one for the
    /// subset on the fly.
    fn get_index(
        &self,
        plan: &Plan,
        atom: AtomId,
        binding_info: &mut BindingInfo,
        cols: impl Iterator<Item = ColumnId>,
    ) -> Prober {
        let cols = SmallVec::<[ColumnId; 4]>::from_iter(cols);
        let trie_node = binding_info.subsets.unwrap_val(atom);
        let subset = &trie_node.subset;

        let table_id = plan.atoms[atom].table;
        let info = &self.db.tables[table_id];
        // A table-wide cached index may only be used when none of the probed
        // columns are marked uncacheable in the table spec.
        let all_cacheable = cols.iter().all(|col| {
            !info
                .spec
                .uncacheable_columns
                .get(*col)
                .copied()
                .unwrap_or(false)
        });
        let whole_table = info.table.all();
        // Heuristic: if the subset is dense and covers more than half the
        // table, probe a cached whole-table index rather than building a
        // subset-local one.
        let dyn_index =
            if all_cacheable && subset.is_dense() && whole_table.size() / 2 < subset.size() {
                // Skip intersecting with the subset if we are just looking at the
                // whole table.
                let intersect_outer =
                    !(whole_table.is_dense() && subset.bounds() == whole_table.bounds());
                if cols.len() != 1 {
                    DynamicIndex::Cached {
                        intersect_outer,
                        table: get_index_from_tableinfo(info, &cols).clone(),
                    }
                } else {
                    DynamicIndex::CachedColumn {
                        intersect_outer,
                        table: get_column_index_from_tableinfo(info, cols[0]).clone(),
                    }
                }
            } else if cols.len() != 1 {
                // NB: we should have a caching strategy for non-column indexes.
                DynamicIndex::Dynamic(info.table.group_by_key(subset.as_ref(), &cols))
            } else {
                // Single column: use (and cache) a column index on the trie node.
                DynamicIndex::DynamicColumn(trie_node.get_cached_index(cols[0], info))
            };
        Prober {
            node: trie_node,
            pool: with_pool_set(|ps| ps.get_pool().clone()),
            ix: dyn_index,
        }
    }
434    fn get_column_index(
435        &self,
436        plan: &Plan,
437        binding_info: &mut BindingInfo,
438        atom: AtomId,
439        col: ColumnId,
440    ) -> Prober {
441        self.get_index(plan, atom, binding_info, iter::once(col))
442    }
443
    /// Runs the free join plan, starting with the header.
    ///
    /// The header narrows each listed atom's subset by intersecting it with
    /// the subset recorded in the plan; if any subset comes up empty the rule
    /// can produce no matches, and we return early.
    ///
    /// The remaining instructions are then executed via [`JoinState::run_plan`]
    /// through an `InstrOrder` indirection: `run_plan` executes
    /// `plan.stages.instrs[instr_order[i]]` at stage `i`. Computing the order
    /// here (sorted by estimated size) lets us reuse cached [`Plan`]s whose
    /// baked-in ordering may be stale, without mutating the plan itself. It is
    /// also a stepping stone towards fully dynamic variable ordering.
    fn run_header_and_plan<'buf, BUF: ActionBuffer<'buf>>(
        &self,
        plan: &'a Plan,
        binding_info: &mut BindingInfo,
        action_buf: &mut BUF,
    ) where
        'a: 'buf,
    {
        for JoinHeader { atom, subset, .. } in &plan.stages.header {
            if subset.is_empty() {
                return;
            }
            // Header atoms should not have cached indexes yet: we are about to
            // shrink their subsets in place.
            let mut cur = binding_info.unwrap_val(*atom);
            debug_assert!(cur.cached_subsets.get().is_none());
            cur.subset
                .intersect(subset.as_ref(), &with_pool_set(|ps| ps.get_pool()));
            binding_info.move_back_node(*atom, cur);
        }
        // Any empty atom subset makes the whole join trivially empty.
        for (_, node) in binding_info.subsets.iter() {
            if node.subset.is_empty() {
                return;
            }
        }
        let mut order = InstrOrder::from_iter(0..plan.stages.instrs.len());
        sort_plan_by_size(&mut order, 0, &plan.stages.instrs, binding_info);
        self.run_plan(plan, &mut order, 0, binding_info, action_buf);
    }
480
481    /// The core method for executing a free join plan.
482    ///
483    /// This method takes the plan, mutable data-structures for variable binding and staging
484    /// actions, and two indexes: `cur` which is the current stage of the plan to run, and `level`
485    /// which is the current "fan-out" node we are in. The latter parameter is an experimental
486    /// index used to detect if we are at the "top" of a plan rather than the "bottom", and is
487    /// currently used as a heuristic to determine if we should increase parallelism more than the
488    /// default.
489    fn run_plan<'buf, BUF: ActionBuffer<'buf>>(
490        &self,
491        plan: &'a Plan,
492        instr_order: &mut InstrOrder,
493        cur: usize,
494        binding_info: &mut BindingInfo,
495        action_buf: &mut BUF,
496    ) where
497        'a: 'buf,
498    {
499        if self.exec_state.should_stop() {
500            return;
501        }
502
503        if cur >= instr_order.len() {
504            action_buf.push_bindings(plan.stages.actions, &binding_info.bindings, || {
505                self.exec_state.clone()
506            });
507            return;
508        }
509        let chunk_size = action_buf.morsel_size(cur, instr_order.len());
510        let mut cur_size = estimate_size(&plan.stages.instrs[instr_order.get(cur)], binding_info);
511        if cur_size > 32 && cur % 3 == 1 && cur < instr_order.len() - 1 {
512            // If we have a reasonable number of tuples to process, adjust the variable order every
513            // 3 rounds, but always make sure to readjust on the second roung.
514            sort_plan_by_size(instr_order, cur, &plan.stages.instrs, binding_info);
515            cur_size = estimate_size(&plan.stages.instrs[instr_order.get(cur)], binding_info);
516        }
517
518        // Helper macro (not its own method to appease the borrow checker).
519        macro_rules! drain_updates {
520            ($updates:expr) => {
521                if self.exec_state.should_stop() {
522                    return;
523                }
524                if cur == 0 || cur == 1 {
525                    drain_updates_parallel!($updates)
526                } else {
527                    $updates.drain(|update| match update {
528                        UpdateInstr::PushBinding(var, val) => {
529                            binding_info.bindings.insert(var, val);
530                        }
531                        UpdateInstr::RefineAtom(atom, subset) => {
532                            binding_info.insert_subset(atom, subset);
533                        }
534                        UpdateInstr::EndFrame => {
535                            self.run_plan(plan, instr_order, cur + 1, binding_info, action_buf);
536                        }
537                    })
538                }
539            };
540        }
541        macro_rules! drain_updates_parallel {
542            ($updates:expr) => {{
543                if self.exec_state.should_stop() {
544                    return;
545                }
546                let db = self.db;
547                let exec_state_for_factory = self.exec_state.clone();
548                let exec_state_for_work = self.exec_state.clone();
549                action_buf.recur(
550                    BorrowedLocalState {
551                        binding_info,
552                        instr_order,
553                        updates: &mut $updates,
554                    },
555                    move || exec_state_for_factory.clone(),
556                    move |BorrowedLocalState {
557                              binding_info,
558                              instr_order,
559                              updates,
560                          },
561                          buf| {
562                        updates.drain(|update| match update {
563                            UpdateInstr::PushBinding(var, val) => {
564                                binding_info.bindings.insert(var, val);
565                            }
566                            UpdateInstr::RefineAtom(atom, subset) => {
567                                binding_info.insert_subset(atom, subset);
568                            }
569                            UpdateInstr::EndFrame => {
570                                JoinState {
571                                    db,
572                                    exec_state: exec_state_for_work.clone(),
573                                }
574                                .run_plan(
575                                    plan,
576                                    instr_order,
577                                    cur + 1,
578                                    binding_info,
579                                    buf,
580                                );
581                            }
582                        })
583                    },
584                );
585                $updates.clear();
586            }};
587        }
588
589        fn refine_subset(
590            sub: Subset,
591            constraints: &[Constraint],
592            table: &WrappedTableRef,
593        ) -> Subset {
594            let sub = table.refine_live(sub);
595            table.refine(sub, constraints)
596        }
597
598        match &plan.stages.instrs[instr_order.get(cur)] {
599            JoinStage::Intersect { var, scans } => match scans.as_slice() {
600                [] => {}
601                [a] if a.cs.is_empty() => {
602                    if binding_info.has_empty_subset(a.atom) {
603                        return;
604                    }
605                    let prober = self.get_column_index(plan, binding_info, a.atom, a.column);
606                    let table = self.db.tables[plan.atoms[a.atom].table].table.as_ref();
607                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
608                    with_pool_set(|ps| {
609                        prober.for_each(|val, x| {
610                            updates.push_binding(*var, val[0]);
611                            let sub = refine_subset(x.to_owned(&ps.get_pool()), &[], &table);
612                            if sub.is_empty() {
613                                updates.rollback();
614                                return;
615                            }
616                            updates.refine_atom(a.atom, sub);
617                            updates.finish_frame();
618                            if updates.frames() >= chunk_size {
619                                drain_updates!(updates);
620                            }
621                        })
622                    });
623                    drain_updates!(updates);
624                    binding_info.move_back(a.atom, prober);
625                }
626                [a] => {
627                    if binding_info.has_empty_subset(a.atom) {
628                        return;
629                    }
630                    let prober = self.get_column_index(plan, binding_info, a.atom, a.column);
631                    let table = self.db.tables[plan.atoms[a.atom].table].table.as_ref();
632                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
633                    with_pool_set(|ps| {
634                        prober.for_each(|val, x| {
635                            updates.push_binding(*var, val[0]);
636                            let sub = refine_subset(x.to_owned(&ps.get_pool()), &a.cs, &table);
637                            if sub.is_empty() {
638                                updates.rollback();
639                                return;
640                            }
641                            updates.refine_atom(a.atom, sub);
642                            updates.finish_frame();
643                            if updates.frames() >= chunk_size {
644                                drain_updates!(updates);
645                            }
646                        })
647                    });
648                    drain_updates!(updates);
649                    binding_info.move_back(a.atom, prober);
650                }
651                [a, b] => {
652                    let a_prober = self.get_column_index(plan, binding_info, a.atom, a.column);
653                    let b_prober = self.get_column_index(plan, binding_info, b.atom, b.column);
654
655                    let ((smaller, smaller_scan), (larger, larger_scan)) =
656                        if a_prober.len() < b_prober.len() {
657                            ((&a_prober, a), (&b_prober, b))
658                        } else {
659                            ((&b_prober, b), (&a_prober, a))
660                        };
661
662                    let smaller_atom = smaller_scan.atom;
663                    let larger_atom = larger_scan.atom;
664                    let large_table = self.db.tables[plan.atoms[larger_atom].table].table.as_ref();
665                    let small_table = self.db.tables[plan.atoms[smaller_atom].table]
666                        .table
667                        .as_ref();
668                    let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
669                    with_pool_set(|ps| {
670                        smaller.for_each(|val, small_sub| {
671                            if let Some(mut large_sub) = larger.get_subset(val) {
672                                large_sub = refine_subset(large_sub, &larger_scan.cs, &large_table);
673                                if large_sub.is_empty() {
674                                    updates.rollback();
675                                    return;
676                                }
677                                let small_sub = refine_subset(
678                                    small_sub.to_owned(&ps.get_pool()),
679                                    &smaller_scan.cs,
680                                    &small_table,
681                                );
682                                if small_sub.is_empty() {
683                                    updates.rollback();
684                                    return;
685                                }
686                                updates.push_binding(*var, val[0]);
687                                updates.refine_atom(smaller_atom, small_sub);
688                                updates.refine_atom(larger_atom, large_sub);
689                                updates.finish_frame();
690                                if updates.frames() >= chunk_size {
691                                    drain_updates_parallel!(updates);
692                                }
693                            }
694                        });
695                    });
696                    drain_updates!(updates);
697
698                    binding_info.move_back(a.atom, a_prober);
699                    binding_info.move_back(b.atom, b_prober);
700                }
701                rest => {
702                    let mut smallest = 0;
703                    let mut smallest_size = usize::MAX;
704                    let mut probers = Vec::with_capacity(rest.len());
705                    for (i, scan) in rest.iter().enumerate() {
706                        let prober =
707                            self.get_column_index(plan, binding_info, scan.atom, scan.column);
708                        let size = prober.len();
709                        if size < smallest_size {
710                            smallest = i;
711                            smallest_size = size;
712                        }
713                        probers.push(prober);
714                    }
715
716                    let main_spec = &rest[smallest];
717                    let main_spec_table = self.db.tables[plan.atoms[main_spec.atom].table]
718                        .table
719                        .as_ref();
720
721                    if smallest_size != 0 {
722                        // Smallest leads the scan
723                        let mut updates =
724                            FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
725                        probers[smallest].for_each(|key, sub| {
726                            with_pool_set(|ps| {
727                                updates.push_binding(*var, key[0]);
728                                for (i, scan) in rest.iter().enumerate() {
729                                    if i == smallest {
730                                        continue;
731                                    }
732                                    if let Some(mut sub) = probers[i].get_subset(key) {
733                                        let table = self.db.tables[plan.atoms[rest[i].atom].table]
734                                            .table
735                                            .as_ref();
736                                        sub = refine_subset(sub, &rest[i].cs, &table);
737                                        if sub.is_empty() {
738                                            updates.rollback();
739                                            return;
740                                        }
741                                        updates.refine_atom(scan.atom, sub)
742                                    } else {
743                                        updates.rollback();
744                                        // Empty intersection.
745                                        return;
746                                    }
747                                }
748                                let sub = sub.to_owned(&ps.get_pool());
749                                let sub = refine_subset(sub, &main_spec.cs, &main_spec_table);
750                                if sub.is_empty() {
751                                    updates.rollback();
752                                    return;
753                                }
754                                updates.refine_atom(main_spec.atom, sub);
755                                updates.finish_frame();
756                                if updates.frames() >= chunk_size {
757                                    drain_updates_parallel!(updates);
758                                }
759                            })
760                        });
761                        drain_updates!(updates);
762                    }
763                    for (spec, prober) in rest.iter().zip(probers.into_iter()) {
764                        binding_info.move_back(spec.atom, prober);
765                    }
766                }
767            },
768            JoinStage::FusedIntersect {
769                cover,
770                bind,
771                to_intersect,
772            } if to_intersect.is_empty() => {
773                let cover_atom = cover.to_index.atom;
774                if binding_info.has_empty_subset(cover_atom) {
775                    return;
776                }
777                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
778                let cover_node = binding_info.unwrap_val(cover_atom);
779                let cover_subset = cover_node.subset.as_ref();
780                let mut cur = Offset::new(0);
781                let mut buffer = TaggedRowBuffer::new(bind.len());
782                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
783                loop {
784                    buffer.clear();
785                    let table = &self.db.tables[plan.atoms[cover_atom].table].table;
786                    let next = table.scan_project(
787                        cover_subset,
788                        &proj,
789                        cur,
790                        chunk_size,
791                        &cover.constraints,
792                        &mut buffer,
793                    );
794                    for (row, key) in buffer.non_stale() {
795                        updates.refine_atom(
796                            cover_atom,
797                            Subset::Dense(OffsetRange::new(row, row.inc())),
798                        );
799                        // bind the values
800                        for (i, (_, var)) in bind.iter().enumerate() {
801                            updates.push_binding(*var, key[i]);
802                        }
803                        updates.finish_frame();
804                        if updates.frames() >= chunk_size {
805                            drain_updates_parallel!(updates);
806                        }
807                    }
808                    if let Some(next) = next {
809                        cur = next;
810                        continue;
811                    }
812                    break;
813                }
814                drain_updates!(updates);
815                // Restore the subsets we swapped out.
816                binding_info.move_back_node(cover_atom, cover_node);
817            }
818            JoinStage::FusedIntersect {
819                cover,
820                bind,
821                to_intersect,
822            } => {
823                let cover_atom = cover.to_index.atom;
824                if binding_info.has_empty_subset(cover_atom) {
825                    return;
826                }
827                let index_probers = to_intersect
828                    .iter()
829                    .enumerate()
830                    .map(|(i, (spec, _))| {
831                        (
832                            i,
833                            spec.to_index.atom,
834                            self.get_index(
835                                plan,
836                                spec.to_index.atom,
837                                binding_info,
838                                spec.to_index.vars.iter().copied(),
839                            ),
840                        )
841                    })
842                    .collect::<SmallVec<[(usize, AtomId, Prober); 4]>>();
843                let proj = SmallVec::<[ColumnId; 4]>::from_iter(bind.iter().map(|(col, _)| *col));
844                let cover_node = binding_info.unwrap_val(cover_atom);
845                let cover_subset = cover_node.subset.as_ref();
846                let mut cur = Offset::new(0);
847                let mut buffer = TaggedRowBuffer::new(bind.len());
848                let mut updates = FrameUpdates::with_capacity(cmp::min(chunk_size, cur_size));
849                loop {
850                    buffer.clear();
851                    let table = &self.db.tables[plan.atoms[cover_atom].table].table;
852                    let next = table.scan_project(
853                        cover_subset,
854                        &proj,
855                        cur,
856                        chunk_size,
857                        &cover.constraints,
858                        &mut buffer,
859                    );
860                    'mid: for (row, key) in buffer.non_stale() {
861                        updates.refine_atom(
862                            cover_atom,
863                            Subset::Dense(OffsetRange::new(row, row.inc())),
864                        );
865                        // bind the values
866                        for (i, (_, var)) in bind.iter().enumerate() {
867                            updates.push_binding(*var, key[i]);
868                        }
869                        // now probe each remaining indexes
870                        for (i, atom, prober) in &index_probers {
871                            // create a key: to_intersect indexes into the key from the cover
872                            let index_cols = &to_intersect[*i].1;
873                            let index_key = index_cols
874                                .iter()
875                                .map(|col| key[col.index()])
876                                .collect::<SmallVec<[Value; 4]>>();
877                            let Some(mut subset) = prober.get_subset(&index_key) else {
878                                updates.rollback();
879                                // There are no possible values for this subset
880                                continue 'mid;
881                            };
882                            // apply any constraints needed in this scan.
883                            let table_info = &self.db.tables[plan.atoms[*atom].table];
884                            let cs = &to_intersect[*i].0.constraints;
885                            subset = refine_subset(subset, cs, &table_info.table.as_ref());
886                            if subset.is_empty() {
887                                updates.rollback();
888                                // There are no possible values for this subset
889                                continue 'mid;
890                            }
891                            updates.refine_atom(*atom, subset);
892                        }
893                        updates.finish_frame();
894                        if updates.frames() >= chunk_size {
895                            drain_updates_parallel!(updates);
896                        }
897                    }
898                    if let Some(next) = next {
899                        cur = next;
900                        continue;
901                    }
902                    break;
903                }
904                // TODO: special-case the scenario when the cover doesn't need
905                // deduping (and hence we can do a straight scan: e.g. when the
906                // cover is binding a superset of the primary key for the
907                // table).
908                drain_updates!(updates);
909                // Restore the subsets we swapped out.
910                binding_info.move_back_node(cover_atom, cover_node);
911                for (_, atom, prober) in index_probers {
912                    binding_info.move_back(atom, prober);
913                }
914            }
915        }
916    }
917}
918
/// Number of binding frames buffered per action before `push_bindings`
/// implementations execute (or spawn) the batch.
const VAR_BATCH_SIZE: usize = 128;
920
/// A trait used to abstract over different ways of buffering actions together
/// before running them.
///
/// This trait exists as a fairly ad-hoc wrapper over its two implementations.
/// It allows us to avoid duplicating the (somewhat monstrous) `run_plan` method
/// for serial and parallel modes.
trait ActionBuffer<'state>: Send {
    /// The buffer type handed to the `work` closure in [`ActionBuffer::recur`].
    /// For the in-place implementation this is `Self`; for the scoped
    /// implementation it is a buffer bound to the inner rayon scope.
    type AsLocal<'a>: ActionBuffer<'state>
    where
        'state: 'a;
    /// Push the given bindings to be executed for the specified action. If this
    /// buffer has built up a sufficient batch size, it may execute
    /// `to_exec_state` and then execute the action.
    ///
    /// NB: `push_bindings` makes module-specific assumptions on what values are passed to
    /// `bindings` for a common `action`. This is not a general-purpose trait for that reason and
    /// it should not, in general, be used outside of this module.
    fn push_bindings(
        &mut self,
        action: ActionId,
        bindings: &DenseIdMap<Variable, Value>,
        to_exec_state: impl FnMut() -> ExecutionState<'state>,
    );

    /// Execute any remaining actions associated with this buffer.
    fn flush(&mut self, exec_state: &mut ExecutionState);

    /// Execute `work`, potentially asynchronously, with a mutable reference to
    /// an action buffer, potentially handed off to a different thread.
    ///
    /// Callers pass [`BorrowedLocalState`] values that may either be modified
    /// directly by `work`, or be cloned first with the separate copy modified
    /// by `work`. Callers should assume that `local` _is_ modified
    /// synchronously.
    // NB: Earlier versions of this method had BorrowedLocalState be a generic instead, but this
    // ran into difficulties when we needed to pass multiple mutable references.
    fn recur<'local>(
        &mut self,
        local: BorrowedLocalState<'local>,
        to_exec_state: impl FnMut() -> ExecutionState<'state> + Send + 'state,
        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut Self::AsLocal<'a>) + Send + 'state,
    );

    /// The unit at which you should batch updates passed to calls to `recur`,
    /// potentially depending on the current level of recursion.
    ///
    /// As of right now this is just a hard-coded value. We may change it in the
    /// future to fan out more at higher levels though.
    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
        256
    }
}
972
/// The action buffer we use if we are executing in a single-threaded
/// environment. It builds up local batches and then flushes them inline.
struct InPlaceActionBuffer<'a> {
    rule_set: &'a RuleSet,
    // Shared per-action match counters, bumped when a batch runs.
    match_counter: &'a MatchCounter,
    // Per-action staged bindings awaiting execution.
    batches: DenseIdMap<ActionId, ActionState>,
}
980
981impl<'a, 'outer: 'a> ActionBuffer<'a> for InPlaceActionBuffer<'outer> {
982    type AsLocal<'b>
983        = Self
984    where
985        'a: 'b;
986
987    fn push_bindings(
988        &mut self,
989        action: ActionId,
990        bindings: &DenseIdMap<Variable, Value>,
991        mut to_exec_state: impl FnMut() -> ExecutionState<'a>,
992    ) {
993        let action_state = self.batches.get_or_default(action);
994        action_state.n_runs += 1;
995        action_state.len += 1;
996        let action_info = &self.rule_set.actions[action];
997        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
998        // `bindings` produced by the same join.
999        unsafe {
1000            action_state.bindings.push(bindings, &action_info.used_vars);
1001        }
1002        if action_state.len >= VAR_BATCH_SIZE {
1003            let mut state = to_exec_state();
1004            let succeeded = state.run_instrs(&action_info.instrs, &mut action_state.bindings);
1005            action_state.bindings.clear();
1006            self.match_counter.inc_matches(action, succeeded);
1007            action_state.len = 0;
1008        }
1009    }
1010
1011    fn flush(&mut self, exec_state: &mut ExecutionState) {
1012        flush_action_states(
1013            exec_state,
1014            &mut self.batches,
1015            self.rule_set,
1016            self.match_counter,
1017        );
1018    }
1019
1020    fn recur<'local>(
1021        &mut self,
1022        local: BorrowedLocalState<'local>,
1023        _to_exec_state: impl FnMut() -> ExecutionState<'a> + Send + 'a,
1024        work: impl for<'b> FnOnce(BorrowedLocalState<'b>, &mut Self) + Send + 'a,
1025    ) {
1026        work(local, self)
1027    }
1028}
1029
/// An action buffer that hands off batches of actions to rayon to execute.
struct ScopedActionBuffer<'inner, 'scope> {
    scope: &'inner rayon::Scope<'scope>,
    rule_set: &'scope RuleSet,
    // Shared per-action match counters, bumped from spawned tasks.
    match_counter: &'scope MatchCounter,
    // Per-action staged bindings awaiting execution.
    batches: DenseIdMap<ActionId, ActionState>,
    // True once any bindings were pushed since the last `flush`; spawned
    // `recur` tasks check it to decide whether a final flush is needed.
    needs_flush: bool,
}
1038
1039impl<'inner, 'scope> ScopedActionBuffer<'inner, 'scope> {
1040    fn new(
1041        scope: &'inner rayon::Scope<'scope>,
1042        rule_set: &'scope RuleSet,
1043        match_counter: &'scope MatchCounter,
1044    ) -> Self {
1045        Self {
1046            scope,
1047            rule_set,
1048            batches: Default::default(),
1049            match_counter,
1050            needs_flush: false,
1051        }
1052    }
1053}
1054
1055impl<'scope> ActionBuffer<'scope> for ScopedActionBuffer<'_, 'scope> {
1056    type AsLocal<'a>
1057        = ScopedActionBuffer<'a, 'scope>
1058    where
1059        'scope: 'a;
1060    fn push_bindings(
1061        &mut self,
1062        action: ActionId,
1063        bindings: &DenseIdMap<Variable, Value>,
1064        mut to_exec_state: impl FnMut() -> ExecutionState<'scope>,
1065    ) {
1066        self.needs_flush = true;
1067        let action_state = self.batches.get_or_default(action);
1068        action_state.n_runs += 1;
1069        action_state.len += 1;
1070        let action_info = &self.rule_set.actions[action];
1071        // SAFETY: `used_vars` is a constant per-rule. This module only ever calls it with
1072        // `bindings` produced by the same join.
1073        unsafe {
1074            action_state.bindings.push(bindings, &action_info.used_vars);
1075        }
1076        if action_state.len >= VAR_BATCH_SIZE {
1077            let mut state = to_exec_state();
1078            let mut bindings =
1079                mem::replace(&mut action_state.bindings, Bindings::new(VAR_BATCH_SIZE));
1080            action_state.len = 0;
1081            let match_counter = self.match_counter;
1082            self.scope.spawn(move |_| {
1083                let succeeded = state.run_instrs(&action_info.instrs, &mut bindings);
1084                match_counter.inc_matches(action, succeeded);
1085            });
1086        }
1087    }
1088
1089    fn flush(&mut self, exec_state: &mut ExecutionState) {
1090        flush_action_states(
1091            exec_state,
1092            &mut self.batches,
1093            self.rule_set,
1094            self.match_counter,
1095        );
1096        self.needs_flush = false;
1097    }
1098    fn recur<'local>(
1099        &mut self,
1100        mut local: BorrowedLocalState<'local>,
1101        mut to_exec_state: impl FnMut() -> ExecutionState<'scope> + Send + 'scope,
1102        work: impl for<'a> FnOnce(BorrowedLocalState<'a>, &mut ScopedActionBuffer<'a, 'scope>)
1103        + Send
1104        + 'scope,
1105    ) {
1106        let rule_set = self.rule_set;
1107        let match_counter = self.match_counter;
1108        let mut inner = local.clone_state();
1109        self.scope.spawn(move |scope| {
1110            let mut buf: ScopedActionBuffer<'_, 'scope> = ScopedActionBuffer {
1111                scope,
1112                rule_set,
1113                match_counter,
1114                needs_flush: false,
1115                batches: Default::default(),
1116            };
1117            work(inner.borrow_mut(), &mut buf);
1118            if buf.needs_flush {
1119                flush_action_states(
1120                    &mut to_exec_state(),
1121                    &mut buf.batches,
1122                    buf.rule_set,
1123                    buf.match_counter,
1124                );
1125            }
1126        });
1127    }
1128
1129    fn morsel_size(&mut self, _level: usize, _total: usize) -> usize {
1130        // Lower morsel size to increase parallelism.
1131        match _level {
1132            0 if _total > 2 => 32,
1133            _ => 256,
1134        }
1135    }
1136}
1137
1138fn flush_action_states(
1139    exec_state: &mut ExecutionState,
1140    actions: &mut DenseIdMap<ActionId, ActionState>,
1141    rule_set: &RuleSet,
1142    match_counter: &MatchCounter,
1143) {
1144    for (action, ActionState { bindings, len, .. }) in actions.iter_mut() {
1145        if *len > 0 {
1146            let succeeded = exec_state.run_instrs(&rule_set.actions[action].instrs, bindings);
1147            bindings.clear();
1148            match_counter.inc_matches(action, succeeded);
1149            *len = 0;
1150        }
1151    }
1152}
/// Per-action match counters, updated concurrently while a rule set runs.
struct MatchCounter {
    // One cache-padded atomic per action to avoid false sharing between
    // threads bumping different actions' counters.
    matches: IdVec<ActionId, CachePadded<AtomicUsize>>,
}
1156
impl MatchCounter {
    /// Create counters for `n_ids` actions, all initialized to zero.
    fn new(n_ids: usize) -> Self {
        let mut matches = IdVec::with_capacity(n_ids);
        matches.resize_with(n_ids, || CachePadded::new(AtomicUsize::new(0)));
        Self { matches }
    }

    /// Add `by` to the match count for `action`.
    fn inc_matches(&self, action: ActionId, by: usize) {
        self.matches[action].fetch_add(by, std::sync::atomic::Ordering::Relaxed);
    }
    /// Read the current match count for `action`.
    // NOTE(review): writes use `Relaxed` but this load uses `Acquire`. An
    // `Acquire` load only synchronizes with `Release` writes, so it adds no
    // ordering here — presumably `Relaxed` would suffice; confirm intent.
    fn read_matches(&self, action: ActionId) -> usize {
        self.matches[action].load(std::sync::atomic::Ordering::Acquire)
    }
}
1171
1172fn estimate_size(join_stage: &JoinStage, binding_info: &BindingInfo) -> usize {
1173    match join_stage {
1174        JoinStage::Intersect { scans, .. } => scans
1175            .iter()
1176            .map(|scan| binding_info.subsets[scan.atom].size())
1177            .min()
1178            .unwrap_or(0),
1179        JoinStage::FusedIntersect { cover, .. } => binding_info.subsets[cover.to_index.atom].size(),
1180    }
1181}
1182
1183fn num_intersected_rels(join_stage: &JoinStage) -> i32 {
1184    match join_stage {
1185        JoinStage::Intersect { scans, .. } => scans.len() as i32,
1186        JoinStage::FusedIntersect { to_intersect, .. } => to_intersect.len() as i32 + 1,
1187    }
1188}
1189
/// Greedily reorder the not-yet-executed instructions (positions
/// `start..` of `order`, indexing into `instrs`) so that more-refined,
/// wider, smaller stages come first.
///
/// This is a hand-rolled selection sort rather than a plain `sort_by_key`
/// because the key is *stateful*: selecting an instruction updates
/// `times_refined`, which changes the keys of all remaining instructions.
fn sort_plan_by_size(
    order: &mut InstrOrder,
    start: usize,
    instrs: &[JoinStage],
    binding_info: &mut BindingInfo,
) {
    // How many times an atom has been intersected/joined
    let mut times_refined = with_pool_set(|ps| ps.get::<DenseIdMap<AtomId, i64>>());

    // Count how many times each atom has been refined so far.
    for ins in instrs[..start].iter() {
        match ins {
            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
                *times_refined.get_or_default(scan.atom) += 1;
            }),
            JoinStage::FusedIntersect { cover, .. } => {
                // A fused intersection refines the cover atom once per bound
                // variable.
                *times_refined.get_or_default(cover.to_index.atom) +=
                    cover.to_index.vars.len() as i64;
            }
        }
    }

    // We prioritize variables by
    //
    //   (1) how many times an atom with this variable has been refined,
    //   (2) then by how many relations joins on this variable
    //   (3) then by the cardinality of the variable to be enumerated
    //
    // Keys compare lexicographically; the first two components are negated so
    // that larger refinement counts / wider joins sort first.
    // NOTE(review): `binding_info` is only read here; `&BindingInfo` would do.
    let key_fn = |join_stage: &JoinStage,
                  binding_info: &BindingInfo,
                  times_refined: &DenseIdMap<AtomId, i64>| {
        let refine = match join_stage {
            JoinStage::Intersect { scans, .. } => scans
                .iter()
                .map(|scan| times_refined.get(scan.atom).copied().unwrap_or_default())
                .sum::<i64>(),
            JoinStage::FusedIntersect { cover, .. } => times_refined
                .get(cover.to_index.atom)
                .copied()
                .unwrap_or_default(),
        };
        (
            -refine,
            -num_intersected_rels(join_stage),
            estimate_size(join_stage, binding_info),
        )
    };

    // Selection sort: settle position `i`, then update `times_refined` as if
    // that instruction had run before choosing position `i + 1`.
    for i in start..order.len() {
        for j in i + 1..order.len() {
            // `key_i` is recomputed each iteration because a swap may have
            // placed a different instruction at position `i`.
            let key_i = key_fn(&instrs[order.get(i)], binding_info, &times_refined);
            let key_j = key_fn(&instrs[order.get(j)], binding_info, &times_refined);
            if key_j < key_i {
                order.data.swap(i, j);
            }
        }
        // Update the counts after a new instruction is selected.
        match &instrs[order.get(i)] {
            JoinStage::Intersect { scans, .. } => scans.iter().for_each(|scan| {
                *times_refined.get_or_default(scan.atom) += 1;
            }),
            JoinStage::FusedIntersect { cover, .. } => {
                *times_refined.get_or_default(cover.to_index.atom) +=
                    cover.to_index.vars.len() as i64;
            }
        }
    }
}
1257
/// An execution order over a plan's join-stage instructions, stored as `u16`
/// indices into the instruction slice (see `sort_plan_by_size`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct InstrOrder {
    // Inline storage for up to 8 indices before spilling to the heap.
    data: SmallVec<[u16; 8]>,
}
1262
1263impl InstrOrder {
1264    fn new() -> Self {
1265        InstrOrder {
1266            data: SmallVec::new(),
1267        }
1268    }
1269
1270    fn from_iter(range: impl Iterator<Item = usize>) -> InstrOrder {
1271        let mut res = InstrOrder::new();
1272        res.data
1273            .extend(range.map(|x| u16::try_from(x).expect("too many instructions")));
1274        res
1275    }
1276
1277    fn get(&self, idx: usize) -> usize {
1278        self.data[idx] as usize
1279    }
1280    fn len(&self) -> usize {
1281        self.data.len()
1282    }
1283}
1284
/// Mutable references to the per-join local state threaded through
/// [`ActionBuffer::recur`].
struct BorrowedLocalState<'a> {
    // Execution order of the join-stage instructions.
    instr_order: &'a mut InstrOrder,
    binding_info: &'a mut BindingInfo,
    // Pending frame updates not yet drained.
    updates: &'a mut FrameUpdates,
}
1290
1291impl BorrowedLocalState<'_> {
1292    fn clone_state(&mut self) -> LocalState {
1293        LocalState {
1294            instr_order: self.instr_order.clone(),
1295            binding_info: self.binding_info.clone(),
1296            updates: std::mem::take(self.updates),
1297        }
1298    }
1299}
1300
/// Owned per-join local state; the owned counterpart of
/// [`BorrowedLocalState`], used when state is handed to another thread.
struct LocalState {
    instr_order: InstrOrder,
    binding_info: BindingInfo,
    updates: FrameUpdates,
}
1306
1307impl LocalState {
1308    fn borrow_mut<'a>(&'a mut self) -> BorrowedLocalState<'a> {
1309        BorrowedLocalState {
1310            instr_order: &mut self.instr_order,
1311            binding_info: &mut self.binding_info,
1312            updates: &mut self.updates,
1313        }
1314    }
1315}