//! This module defines query optimization for egglog. The main entry point is `plan_query`, which takes a `Query` and produces a `Plan`.
//!
//! At a high level, the query planner has two phases: **(hyper)tree decomposition** and **join planning for each bag**.
//! Both phases are very subtle, and heuristics are heavily used for good performance.
//!
//! # (Hyper)tree Decomposition
//!
//! A conjunctive query can be viewed as a hypergraph where variables are vertices and atoms (relations) are hyperedges.
//! The idea of tree decomposition is to break this hypergraph into a tree of overlapping subqueries called *bags*,
//! each of which is cheaper to evaluate independently. This is the classical idea behind tree decomposition and the
//! Yannakakis algorithm.
//!
//! The decomposition proceeds via *variable elimination*: we iteratively pick a variable `v` and eliminate the neighborhood
//! `N(v)` (which also includes `v`) from the hypergraph, and add back a hyperedge consisting of `N(v) - {v}`, until
//! there are no variables left. Each elimination step gives us a bag. A min-fill heuristic
//! (`next_var_to_eliminate`) guides the order of elimination to keep bags small. After all variables are eliminated,
//! redundant bags are pruned: bags subsumed by another (all their variables are covered) are merged, and "ears"
//! are merged into their parent.
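/*!
The elimination loop above can be sketched on a toy hypergraph. This is a standalone illustration under simplifying
assumptions, not the planner's code. Atoms are plain `BTreeSet<usize>` variable sets. The hypothetical helper
`eliminate_all` scores a variable by how many atoms its neighborhood touches and ignores functional dependencies.
The hyperedge added back is the part of `N(v) - {v}` that still occurs in the remaining atoms. On the rectangle
query `R(x, y), S(y, z), T(z, w), U(w, x)` (with `x, y, z, w` numbered 0 to 3), it produces the two bags
`{x, y, w}` and `{y, z, w}`.

```rust
use std::collections::BTreeSet;

/// Hypothetical helper: eliminates every variable of a hypergraph whose hyperedges
/// (atoms) are sets of variable ids, returning the bag `N(v)` produced at each step.
fn eliminate_all(mut atoms: Vec<BTreeSet<usize>>) -> Vec<BTreeSet<usize>> {
    let mut bags = Vec::new();
    loop {
        // Variables still present in the working hypergraph.
        let vars: BTreeSet<usize> = atoms.iter().flatten().copied().collect();
        // Score each variable by how many atoms its neighborhood N(v) touches.
        let best = vars
            .iter()
            .map(|&v| {
                let nbhd: BTreeSet<usize> = atoms
                    .iter()
                    .filter(|a| a.contains(&v))
                    .flatten()
                    .copied()
                    .collect();
                let touched = atoms.iter().filter(|a| !a.is_disjoint(&nbhd)).count();
                (touched, v, nbhd)
            })
            .min_by_key(|(touched, v, _)| (*touched, *v));
        let Some((_, _, nbhd)) = best else { break };
        // Atoms that lie entirely inside N(v) are consumed by this bag.
        atoms.retain(|a| !a.is_subset(&nbhd));
        // Covering hyperedge: the variables of N(v) still shared with the remaining atoms.
        let cover: BTreeSet<usize> = nbhd
            .iter()
            .copied()
            .filter(|u| atoms.iter().any(|a| a.contains(u)))
            .collect();
        if !cover.is_empty() {
            atoms.push(cover);
        }
        bags.push(nbhd);
    }
    bags
}
```
*/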
//!
//! We then topologically sort the bags and decide which variables are "message variables" and which are private to the bag.
//! The materialized result of each bag has its output keyed on the *message variables* it shares with
//! its parent, and the parent uses that materialization to prune its own search space.
//!
//! When the query hypergraph is a single connected component with no beneficial decomposition, the planner falls back to
//! a `SinglePlan` with no materialization steps.
//!
//! # Join Planning for a Single Bag
//!
//! Once each bag (subquery) is isolated, the planner generates a sequence of `JoinStage` instructions that enumerate
//! all satisfying tuples for that bag. Two heuristics are supported:
//!
//! - **Generic Join** (`PlanStrategy::Gj`): The classic worst-case optimal join algorithm. Each stage picks one variable
//!   and intersects the columns of atoms that correspond to this variable (`JoinStage::Intersect`).
//!
//! - **Free Join** (`PlanStrategy::PureSize` / `PlanStrategy::MinCover`): From Remy's paper. At each stage, the planner
//!   selects a *cover*, i.e. a (sub)atom whose columns span the variables being bound in that step, and
//!   uses it to probe all other atoms that share those variables (`JoinStage::FusedIntersect`). When the cover is an
//!   entire atom and there is only one relation to probe, this degenerates to a hash join; when covers are single-column
//!   scans, it approximately recovers generic join*.
//!
//!   *: Only approximately, because this is not worst-case optimal: it does not necessarily pick the smallest side to scan.
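/*!
The hypothetical `triangles` function below makes Generic Join's one-variable-per-stage structure concrete. It
evaluates `R(x, y), S(y, z), T(z, x)` with the variable order `x, y, z`. Each loop level plays the role of one
`JoinStage::Intersect`: it intersects the candidate values for a single variable across every atom that mentions it.
It illustrates the algorithm only, not this crate's executor. Relations are plain slices scanned linearly, whereas a
real executor would probe indexes.

```rust
use std::collections::BTreeSet;

/// Generic Join for the triangle query R(x, y), S(y, z), T(z, x), binding x, then y, then z.
fn triangles(r: &[(u32, u32)], s: &[(u32, u32)], t: &[(u32, u32)]) -> Vec<(u32, u32, u32)> {
    let mut out = Vec::new();
    // Stage x: intersect R's first column with T's second column.
    let rx: BTreeSet<u32> = r.iter().map(|&(x, _)| x).collect();
    let tx: BTreeSet<u32> = t.iter().map(|&(_, x)| x).collect();
    // S's first column does not depend on x, so compute it once.
    let sy: BTreeSet<u32> = s.iter().map(|p| p.0).collect();
    for &x in rx.intersection(&tx) {
        // Stage y: intersect R's y-values for this x with S's first column.
        let ry: BTreeSet<u32> = r.iter().filter(|p| p.0 == x).map(|p| p.1).collect();
        for &y in ry.intersection(&sy) {
            // Stage z: intersect S's z-values for this y with T's z-values for this x.
            let sz: BTreeSet<u32> = s.iter().filter(|p| p.0 == y).map(|p| p.1).collect();
            let tz: BTreeSet<u32> = t.iter().filter(|p| p.1 == x).map(|p| p.0).collect();
            for &z in sz.intersection(&tz) {
                out.push((x, y, z));
            }
        }
    }
    out
}
```
*/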
//!
//! Both strategies produce a flat list of `JoinStage` instructions that are fused where possible (`JoinStage::fuse`) to
//! reduce the number of passes over the data. A `JoinHeader` is prepended to each plan to apply constant constraints and
//! pre-filter the driving relation before the main join loop begins.
//!
use std::{collections::BTreeMap, iter, mem, sync::Arc};

use crate::{
    TableId,
    free_join::ProcessedConstraints,
    numeric_id::{DenseIdMap, NumericId},
    query::{FunDeps, SymbolMap},
};
use egglog_numeric_id::define_id;
use fixedbitset::FixedBitSet;
use smallvec::{SmallVec, smallvec};

use crate::{
    common::{HashMap, HashSet, IndexSet},
    offsets::Subset,
    pool::Pooled,
    query::{Atom, Query, VarColumnMap},
    table_spec::Constraint,
};

use super::{ActionId, AtomId, ColumnId, SubAtom, VarInfo, Variable};

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ScanSpec {
    pub to_index: SubAtom,
    // Only yield rows where the given constraints match.
    pub constraints: Vec<Constraint>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct SingleScanSpec {
    pub atom: AtomId,
    pub column: ColumnId,
    pub cs: Vec<Constraint>,
}

define_id!(pub(crate) MatId, u32, "An identifier for materialization within a decomposed plan.");

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum MatScanMode {
    Full,
    KeyOnly,
    Value(SmallVec<[Variable; 16]>),
    Lookup(SmallVec<[Variable; 16]>),
}

/// Join headers evaluate constraints on a single atom; they prune the search space before the rest
/// of the join plan is executed.
pub(crate) struct JoinHeader {
    pub atom: AtomId,
    /// We currently aren't using these at all. The plan is to use them to
    /// dedup plan stages later (they also help with debugging).
    #[allow(unused)]
    pub constraints: Pooled<Vec<Constraint>>,
    /// A pre-computed table subset that we can use to filter the table,
    /// given these constraints.
    ///
    /// Why use the constraints at all? Because we want to use them to
    /// discover common plan nodes from different queries (subsets can be
    /// large).
    pub subset: Subset,
}

impl std::fmt::Debug for JoinHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JoinHeader")
            .field("atom", &self.atom)
            .field("constraints", &self.constraints)
            .field(
                "subset",
                &format_args!("Subset(size={})", self.subset.size()),
            )
            .finish()
    }
}

impl Clone for JoinHeader {
    fn clone(&self) -> Self {
        JoinHeader {
            atom: self.atom,
            constraints: Pooled::cloned(&self.constraints),
            subset: self.subset.clone(),
        }
    }
}

#[derive(Debug, Clone)]
pub(crate) enum JoinStage {
    /// `Intersect` takes a variable and intersects a set of atoms
    /// on that variable.
    /// This corresponds to the classic generic join algorithm.
    Intersect {
        var: Variable,
        scans: SmallVec<[SingleScanSpec; 3]>,
    },
    /// `FusedIntersect` takes a "cover" (sub)atom and uses it to probe other (sub)atoms.
    /// This corresponds to the free join algorithm or, when `to_intersect.len() == 1` and the cover is
    /// the entire atom, to a hash join.
    FusedIntersect {
        cover: ScanSpec,
        bind: SmallVec<[(ColumnId, Variable); 2]>,
        // to_intersect.1 is the index into the cover atom.
        to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)>,
        is_leaf_scan: bool,
    },
    FusedIntersectMat {
        cover: MatId,
        mode: MatScanMode,
        bind: SmallVec<[(ColumnId, Variable); 2]>,
        to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)>,
    },
}

impl JoinStage {
    /// Attempt to fuse two stages into one.
    ///
    /// This operation is very conservative right now: it only fuses two scans of
    /// the same atom that probe no other atoms and of which at most one is filtered by constraints.
162    ///
163    /// This operation is very conservative right now, it only fuses multiple
164    /// scans that do no filtering whatsoever.
165    fn fuse(&mut self, other: &JoinStage) -> bool {
166        use JoinStage::*;
167        match (&*self, other) {
168            (
169                FusedIntersect {
170                    cover: cover1,
171                    bind: bind1,
172                    to_intersect: to_intersect1,
173                    is_leaf_scan: is_leaf_scan1,
174                },
175                FusedIntersect {
176                    cover: cover2,
177                    bind: bind2,
178                    to_intersect: to_intersect2,
179                    is_leaf_scan: is_leaf_scan2,
180                },
181            ) if cover1.to_index.atom == cover2.to_index.atom
182                && to_intersect1.is_empty()
183                && to_intersect2.is_empty()
184                && (cover1.constraints.is_empty() || cover2.constraints.is_empty()) =>
185            {
186                assert!(!*is_leaf_scan1 && !is_leaf_scan2);
187                let to_index = SubAtom {
188                    atom: cover1.to_index.atom,
189                    vars: cover1
190                        .to_index
191                        .vars
192                        .iter()
193                        .chain(cover2.to_index.vars.iter())
194                        .copied()
195                        .collect(),
196                };
197                let bind = bind1.iter().chain(bind2.iter()).copied().collect();
198                *self = FusedIntersect {
199                    cover: ScanSpec {
200                        to_index,
201                        constraints: cover1
202                            .constraints
203                            .iter()
204                            .chain(cover2.constraints.iter())
205                            .cloned()
206                            .collect(),
207                    },
208                    bind,
209                    to_intersect: Default::default(),
210                    is_leaf_scan: false,
211                };
212                true
213            }
214            _ => false,
215        }
216    }
217}
218
219#[derive(Debug, Clone)]
220pub(crate) enum Plan {
221    SinglePlan(SinglePlan),
222    DecomposedPlan(DecomposedPlan),
223}
224impl Plan {
225    pub fn actions(&self) -> ActionId {
226        match self {
227            Plan::SinglePlan(p) => p.actions,
228            Plan::DecomposedPlan(p) => p.actions,
229        }
230    }
231
232    pub fn atoms(&self) -> Arc<DenseIdMap<AtomId, Atom>> {
233        match self {
234            Plan::SinglePlan(p) => p.atoms.clone(),
235            Plan::DecomposedPlan(p) => p.atoms.clone(),
236        }
237    }
238
239    pub(crate) fn to_report(&self, _symbol_map: &SymbolMap) -> egglog_reports::Plan {
240        match self {
241            Plan::SinglePlan(p) => p.to_report(_symbol_map),
242            Plan::DecomposedPlan(_) => {
243                todo!()
244            }
245        }
246    }
247
248    pub(crate) fn header(&self) -> &[JoinHeader] {
249        match self {
250            Plan::SinglePlan(p) => &p.header,
251            Plan::DecomposedPlan(p) => &p.header,
252        }
253    }
254}
255
256#[derive(Debug, Clone)]
257pub(crate) struct SinglePlan {
258    pub atoms: Arc<DenseIdMap<AtomId, Atom>>,
259    pub header: Vec<JoinHeader>,
260    pub stages: JoinStages,
261    pub actions: ActionId,
262}
263
264#[derive(Debug, Clone)]
265pub(crate) struct JoinStages {
266    pub instrs: Arc<Vec<JoinStage>>,
267}
268
269/// Specification of the materialization of the intermediate results, as required by tree decomposition.
270/// A materialization has two parts. The message variables are variables that are passed to and joined with later stages,
271/// and the value/private variables are variables that only occur in the current (and maybe previous) bags.
272///
273/// A materialization thus looks like a map from values of the message variables to sets of values of the private variables,
274/// and when we evaluate other bags, only the keys (message variables) are looked up or enumerated. This is because
275/// the private variables are not relevant to the evaluation of other bags. A key idea of tree decomposition is to separate
276/// independent parts of a query and make sure they are evaluated independently.
277#[derive(Debug, Clone)]
278pub(crate) struct MatSpec {
279    // Variables that are used by later stages
280    pub msg_vars: SmallVec<[Variable; 16]>,
281    // Variables that are not used by later stages.
282    pub val_vars: SmallVec<[Variable; 16]>,
283}
284
285#[derive(Debug, Clone)]
286pub(crate) struct JoinStageBlocks {
287    // each block is a list of instructions and how to yield
288    pub blocks: Vec<(JoinStages, MatSpec)>,
289}
290
291#[derive(Debug, Clone)]
292pub(crate) struct DecomposedPlan {
293    pub atoms: Arc<DenseIdMap<AtomId, Atom>>,
294    pub header: Vec<JoinHeader>,
295    pub stages: JoinStageBlocks,
296    pub result_block: JoinStages,
297    pub actions: ActionId,
298}
299
300impl SinglePlan {
301    pub(crate) fn to_report(&self, symbol_map: &SymbolMap) -> egglog_reports::Plan {
302        use egglog_reports::{
303            Plan as ReportPlan, Scan as ReportScan, SingleScan as ReportSingleScan,
304            Stage as ReportStage,
305        };
306        const INTERNAL_PREFIX: &str = "@";
307        let get_var = |var: Variable| {
308            symbol_map
309                .vars
310                .get(&var)
311                .map(|s| s.to_string())
312                .unwrap_or_else(|| format!("{INTERNAL_PREFIX}x{var:?}"))
313        };
314        let get_atom = |atom: AtomId| {
315            symbol_map
316                .atoms
317                .get(&atom)
318                .map(|s| s.to_string())
319                .unwrap_or_else(|| format!("{INTERNAL_PREFIX}R{atom:?}"))
320        };
321        let mut stages = Vec::new();
322        for (i, stage) in self.stages.instrs.iter().enumerate() {
323            let report_stage = match stage {
324                JoinStage::Intersect { var, scans } => {
325                    let var_name = get_var(*var);
326                    let report_scans = scans
327                        .iter()
328                        .map(|scan| {
329                            let atom_name = get_atom(scan.atom);
330                            ReportSingleScan(
331                                atom_name,
332                                (var_name.clone(), scan.column.index() as i64),
333                            )
334                        })
335                        .collect();
336                    ReportStage::Intersect {
337                        scans: report_scans,
338                    }
339                }
340                JoinStage::FusedIntersect {
341                    cover,
342                    bind: _,
343                    to_intersect,
344                    is_leaf_scan: _,
345                } => {
346                    let cover_atom_name = get_atom(cover.to_index.atom);
347                    let cover_cols: Vec<(String, i64)> = cover
348                        .to_index
349                        .vars
350                        .iter()
351                        .map(|col| {
352                            let var_name =
353                                get_var(self.atoms[cover.to_index.atom].get_var(*col).unwrap());
354                            (var_name, col.index() as i64)
355                        })
356                        .collect();
357                    let report_cover = ReportScan(cover_atom_name, cover_cols);
358                    let report_to_intersect = to_intersect
359                        .iter()
360                        .map(|(scan, key_spec)| {
361                            let atom_name = get_atom(scan.to_index.atom);
362                            let cols: Vec<(String, i64)> = key_spec
363                                .iter()
364                                .map(|col| {
365                                    let var_name = get_var(
366                                        self.atoms[scan.to_index.atom].get_var(*col).unwrap(),
367                                    );
368                                    (var_name, col.index() as i64)
369                                })
370                                .collect();
371                            ReportScan(atom_name, cols)
372                        })
373                        .collect();
374                    ReportStage::FusedIntersect {
375                        cover: report_cover,
376                        to_intersect: report_to_intersect,
377                    }
378                }
379                JoinStage::FusedIntersectMat {
380                    cover: _,
381                    mode: _,
382                    bind: _,
383                    to_intersect: _,
384                } => {
385                    todo!("materialization")
386                }
387            };
388            let next = if i == self.stages.instrs.len() - 1 {
389                vec![]
390            } else {
391                vec![i + 1]
392            };
393            stages.push((report_stage, None, next));
394        }
395        ReportPlan { stages }
396    }
397}
398
399/// The algorithm used to produce a join plan.
400#[derive(Default, Copy, Clone)]
401pub enum PlanStrategy {
402    /// Free Join: Iteratively pick the smallest atom as the cover for the next
403    /// stage, until all subatoms have been visited.
404    PureSize,
405
406    /// Free Join: Pick an approximate minimal set of covers, then order those
407    /// covers in increasing order of size.
408    ///
409    /// This is similar to PureSize but we first limit the potential atoms that
410    /// can act as covers so as to minimize the total number of stages in the
411    /// plan. This is only an approximate minimum: the problem of finding the
412    /// exact minimum ("set cover") is NP-hard.
413    MinCover,
414
415    /// Generate a plan for the classic Generic Join algorithm, constraining a
416    /// single variable per stage.
417    #[default]
418    Gj,
419}
420
421/// Pick the next variable to eliminate and computes its neighborhood.
422///
423/// Each time, we pick a variable that has the least number of occurrences and find its neighborhood* (i.e.,
424/// the set of variables that share an atom with it). We pick the neighborhood based on the "min-fill" heuristic,
425/// which tries to eliminate neighborhood that would introduce the least number of new hyperedges.
426/// A hyperedge is introduced during variable elimination if two variables that don't share an atom before are in the same neighborhood.
427///
428/// *: We find the closure of the neighborhood under functional dependencies, since these variables are "for free".
429fn next_var_to_eliminate(
430    vars: &DenseIdMap<Variable, VarInfo>,
431    atoms: &DenseIdMap<AtomId, Atom>,
432    fun_deps: &FunDeps,
433) -> Option<IndexSet<Variable>> {
434    let (_var, subquery_vars) = vars
435        .iter()
436        .map(|(var, _vinfo)| {
437            let subquery_vars = atoms
438                .iter()
439                // every atom that contains this variable
440                .filter(|(_, atom)| atom.get_col(var).is_some())
441                // every variable of those atoms
442                .flat_map(|(_, atom)| atom.vars());
443
444            // Optimization: use functional dependencies to find all variables inferred by the
445            // current neightborhood.
446            let subquery_vars = fun_deps.closure(subquery_vars);
447
448            let occ = atoms
449                .iter()
450                .filter(|(_, atom)| atom.vars().any(|v| subquery_vars.contains_key(v)))
451                .count();
452            (occ, var, subquery_vars)
453        })
454        .min_by_key(|a| a.0)
455        .map(|a| (a.1, a.2))?;
456    Some(IndexSet::from_iter(
457        subquery_vars.into_iter().map(|(var, _)| var),
458    ))
459}
460
461/// It updates the hypergraph with the given bag of variables by:
462/// 1. Remove atoms that only contain variables in the bag and remove those atoms from variable's occurrences,
463/// 2. Add a covering hyperedge that contains every non-private variable.
464fn update_hypergraph(
465    subquery_vars: &IndexSet<Variable>,
466    vars: &mut DenseIdMap<Variable, VarInfo>,
467    atoms: &mut DenseIdMap<AtomId, Atom>,
468) {
469    // Build the covering hyperedge before we remove from the hypergraph
470
471    // Find variables that occur not just in the subquery
472    let covering_vars: Vec<_> = subquery_vars
473        .iter()
474        .copied()
475        .filter(|&var| {
476            vars.contains_key(var)
477                && vars[var].occurrences.iter().any(|occ| {
478                    atoms[occ.atom]
479                        .vars()
480                        .any(|ov| !subquery_vars.contains(&ov))
481                })
482        })
483        .collect();
484
485    // Remove atoms from the hypergraph
486    let mut removed = Vec::new();
487    atoms.retain(|atom_id, atom| {
488        if atom.vars().all(|var| subquery_vars.contains(&var)) {
489            removed.push(atom_id);
490            false
491        } else {
492            true
493        }
494    });
495
496    // Update occurrences to reflect removed atoms
497    for &subq_var in subquery_vars.iter() {
498        if vars.contains_key(subq_var) {
499            vars[subq_var]
500                .occurrences
501                .retain(|occ| !removed.contains(&occ.atom));
502
503            if vars[subq_var].occurrences.is_empty() {
504                vars.unwrap_val(subq_var);
505            }
506        }
507    }
508
509    // Add the covering atom to the hypergraph
510    let mut var_columns = VarColumnMap::default();
511    for (ix, var) in covering_vars.iter().enumerate() {
512        var_columns.insert(*var, ColumnId::from_usize(ix));
513    }
514    let fake_atom_id = atoms.push(Atom {
515        var_columns,
516        constraints: ProcessedConstraints::dummy(),
517        table: TableId::dummy(),
518    });
519
520    // Update variable occurrences to include the covering atom
521    for (i, &covering_var) in covering_vars.iter().enumerate() {
522        vars[covering_var].occurrences.push(SubAtom {
523            atom: fake_atom_id,
524            vars: smallvec![ColumnId::from_usize(i)],
525        });
526    }
527}
528
529/// This function does tree decomposition. At a high level, it takes a bag (equivalently, a `PlanningContext`, a subquery, a hypergraph,
530/// or a set of variables + atoms), and returns a list of bags that forms a tree decomposition.
531///
532/// Recall that a bag is equivalent to a hypergraph, where vertices = variables and hyperedges = atoms.
533///
534/// The algorithm is based on the classical variable elimination, where it iteratively removes neighborhoods until no variables are left.
535/// More specifically, it iteratively
536///
537/// 1. Select a variable `v` and its neighborhood `N(v)`, based on the "min-fill" heuristic. (`next_var_to_eliminate`)
538/// 2. Remove the neighborhood from the working hypergraph. (`update_hypergraph`)
539/// 3. Add a covering atom that contains variables `N(v) - {v}` to the working hypergraph. (`update_hypergraph`)
540/// 4. Step 1-3 gives us a set of variables `N(v)`. We need to construct a subquery from it. This step is a bit subtle.
541///
542///    For example, consider the rectangle query `R(x, y), S(y, z), T(z, w), U(w, x)`. Let's say we pick variable `x` to eliminate.
543///    The neighborhood `N(x)` of `x` is {x, y, z}. A naive approach is to subquery would be `R(x, y), S(y, z)`, but this query can have size quadratic,
544///    even when the final output size is small. The issue here is `x` and `z` are not fully constrained in this subquery.
545///    Another approach is to include every atom that contains variables in `N(x)`, but this gives us the entire query as the subquery for that rectangle query,
546///    which is also not ideal because the rectangle query should be broken into two bags.
547///
548///    The solution is to include every atom that contains variables in `N(x)`, but only keep the variables in `N(x)` in those atoms. For the rectangle example,
549///    this would be `R(x, y), S(y, z), T(z, -), U(-, x)`, where `-` means we don't expand this variable during evaluation. As a result, the produced PlanningContext
550///    may have atoms whose variables are not in `PlanningContext::vars`. The query planner for a single bag handles this correctly.
551///
552/// Now we have collected a list of bags, but they are very redundant. (Remember the variable elimination loop is run |vars| steps, because each iteration eliminates
553/// only one variable.) We need to prune these bags. See the comments in the code for details.
554///
555/// Another invariant we maintain is higher-indexed bags are heavier (closer to the root of the tree decomposition), so they will be evaluated later and constrained
556/// by evaluation of earlier bags.
557fn decompose_into_bags(original_ctx: &PlanningContext) -> Vec<PlanningContext> {
558    let mut atoms = original_ctx.atoms.clone();
559    let mut vars = original_ctx.vars.clone();
560
561    // Prune variables with no occurrences
562    for (var, vinfo) in original_ctx.vars.iter() {
563        if vinfo.occurrences.is_empty() {
564            vars.take(var).unwrap();
565        }
566    }
567
568    let mut bags = Vec::new();
569
570    // Variable elimination loop
571    while let Some(subquery_vars) = next_var_to_eliminate(&vars, &atoms, &original_ctx.fun_deps) {
572        // Create a fake covering atom to bridge back to the main query
573        // Remove hyperedges that only contain subquery variables.
574        update_hypergraph(&subquery_vars, &mut vars, &mut atoms);
575
576        // Collect atoms that only contain subquery variables.
577        let subquery_atoms: DenseIdMap<AtomId, Atom> = original_ctx
578            .atoms
579            .iter()
580            .filter(|(_, atom)| atom.vars().any(|var| subquery_vars.contains(&var)))
581            .map(|(atom_id, atom)| (atom_id, atom.clone()))
582            .collect();
583
584        let subquery_var_map = DenseIdMap::from_iter(subquery_vars.iter().map(|var| {
585            let mut var_info = original_ctx.vars[*var].clone();
586            // NB: used_in_rhs is handled in [`plan_single_bag`]
587            var_info
588                .occurrences
589                .retain(|occ| subquery_atoms.contains_key(occ.atom));
590            (*var, var_info)
591        }));
592
593        bags.push(PlanningContext {
594            vars: subquery_var_map,
595            atoms: subquery_atoms,
596            fun_deps: original_ctx.fun_deps.clone(),
597        });
598    }
599
600    assert!(
601        !atoms.iter().any(|(_, atom_info)| {
602            !atom_info.table.is_dummy() && !atom_info.var_columns.is_empty()
603        }),
604        "All atoms should be put into bags"
605    );
606
607    // Iteratively prune the query
608    let mut changed = true;
609    while changed {
610        changed = false;
611        // Pruning 1: Remove bags that are subsumed by others. A bag is subsumed by another bag if all of its variables are contained in the other bag,
612        // so the output of this bag must be a subset of the bigger bag.
613        let mut pruned_bags: Vec<PlanningContext> = Vec::with_capacity(bags.len());
614        for mut bag1 in bags.into_iter() {
615            pruned_bags.retain_mut(|bag2| {
616                let leq = bag1.is_subsumed_by(bag2);
617                let geq = bag2.is_subsumed_by(&bag1);
618                if leq || geq {
619                    bag1.merge_bag(bag2);
620                    changed = true;
621                    false
622                } else {
623                    true
624                }
625            });
626            pruned_bags.push(bag1);
627        }
628
629        // Pruning 2: Find "ears" and merge them with other bags. A bag is an ear if one of its atoms covers all of its variables, i.e., it only has one useful
630        // relation. We can safely remove an ear if it shares variables with only one bag - in this case, that bag is necessarily the parent in the tree decomposition.
631        //
632        // Why removing ears? Let's say an ear has the form R(x, y, z) with message variable {x}. The evaluation of its parent will already intersect on `x` with `R(x, y, z)`,
633        // so if `y` and `z` are expanded at the innermost loop of the evaluation, this does not incur any overhead. Versus if we keep this ear as a separate bag,
634        // we would need to first build a map x -> (y, z) only to enumerate each x to get the corresponding (y, z) values.
635        bags = pruned_bags;
636        let is_ear = |bag: &PlanningContext| {
637            bag.atoms.iter().any(|(_atom_id, atom)| {
638                let all_vars = original_ctx.fun_deps.closure(atom.vars());
639                bag.is_subsumed_by_vars(&all_vars)
640            })
641            // HACK: this weird condition says if there's exactly one atom whose variables are all wanted, then we can also treat it as an ear,
642            // because other atoms in the bag are likely added only to constrain the bag. This is approximately what a bag is, but not really.
643            // However, removing this condition makes some benchmark much worse...
644            || bag
645                .atoms
646                .iter()
647                .filter(|(_atom_id, atom)| bag.has_vars(atom.vars()))
648                .count()
649                == 1
650        };
651
652        let mut i = 0;
653        while i < bags.len() {
654            if !is_ear(&bags[i]) {
655                i += 1;
656                continue;
657            }
658
659            // Find the bag that shares the most variables with this ear bag, and merge the ear bag into it.
660            let parent = bags
661                .iter()
662                .enumerate()
663                .rev()
664                .filter(|(j, _)| *j != i)
665                .map(|(j, b)| (j, b.common_vars_with(&bags[i]).count()))
666                .collect::<Vec<_>>();
667
668            let j = parent.into_iter().max_by_key(|(_, count)| *count);
669            if j.is_none() || j.unwrap().1 == 0 {
670                i += 1;
671                continue;
672            }
673            let j = j.unwrap().0;
674
675            // Invariant: bigger-numbered bags are heavier and should stay at the root of the tree
676            if i < j {
677                let bag = mem::take(&mut bags[i]);
678                bags[j].merge_bag(&bag);
679                bags.remove(i);
680            } else {
681                let bag = mem::take(&mut bags[j]);
682                bags[i].merge_bag(&bag);
683                bags.remove(j);
684                i += 1;
685            }
686            changed = true;
687        }
688    }
689    bags
690}
691
692/// Topologically sorts bags based on variable dependencies, and merges bags so
693/// that the final result is a *chain*. This means `plan_single_bag` only ever
694/// needs a single prologue per bag and never an epilogue. This is because the
695/// epilogues do not participate in joins and are checked only after the main
696/// join loop, so they can easily lead to cartesian products.
697///
698/// At every DFS node we pick one child as the chain continuation. Every other reachable bag —
699/// siblings *and* their entire sub-trees — gets absorbed into the current chain node. The
700/// continuation is picked in a way that minimizes the maximum number of atoms in a bag, i.e.,
701/// the pathwidth.
702///
703/// The pathwidth of a path decomposition is the maximum bag size (minus one) over all bags,
704/// and the size of a bag is measured as the number of atoms in the bag.
705fn topologically_sort_bags(bags: Vec<PlanningContext>) -> Vec<PlanningContext> {
706    let mut all_children_list: Vec<Vec<usize>> = vec![vec![]; bags.len()];
707    // best_pathwidth[i] = the best pathwidth of the chain if we pick bag i
708    // to be the chain child.
709    let mut best_pathwidth = vec![usize::MAX; bags.len()];
710    let mut full = vec![HashSet::default(); bags.len()];
711    let mut choice = vec![usize::MAX; bags.len()];
712    for i in 0..bags.len() {
713        let mut full_i: HashSet<AtomId> =
714            bags[i].atoms.iter().map(|(atom_id, _)| atom_id).collect();
715        for child in all_children_list[i].iter() {
716            full_i.extend(full[*child].iter().copied());
717        }
718        full[i] = full_i;
719        best_pathwidth[i] = full[i].len();
720        for chain_child in all_children_list[i].iter() {
            // If `chain_child` continues the chain, bag i absorbs its own atoms plus the
            // full sub-trees of its other children (the siblings of `chain_child`).
            let mut chain_score: HashSet<_> =
                bags[i].atoms.iter().map(|(atom_id, _)| atom_id).collect();
            chain_score.extend(
                all_children_list[i]
                    .iter()
                    .filter(|child| *child != chain_child)
                    .flat_map(|child| full[*child].iter().copied()),
            );
            let s = chain_score.len().max(best_pathwidth[*chain_child]);
            if s <= best_pathwidth[i] {
                best_pathwidth[i] = s;
                choice[i] = *chain_child;
            }
        }

        // Find the parent of this bag: among the higher-numbered bags that share at least
        // one variable with it, the one sharing the most variables, with ties broken in
        // favor of the lowest-numbered bag.
        let parent = bags
            .iter()
            .enumerate()
            .skip(i + 1)
            .map(|(j, b)| (j, b.common_vars_with(&bags[i]).count()))
            .filter(|(_, count)| *count > 0)
            .max_by_key(|(j, count)| (*count, -(*j as isize)));
        if let Some((j, _count)) = parent {
            all_children_list[j].push(i);
        }
    }

    let mut bags_opt = bags.into_iter().map(Some).collect::<Vec<_>>();
    let mut bags_topo = Vec::<PlanningContext>::with_capacity(bags_opt.len());
    let mut visited = vec![false; bags_opt.len()];
    // Stack entries: (bag_id, parent). `parent` is None for chain nodes (the bag is
    // pushed to `bags_topo` as a new standalone entry) and Some(idx) for nodes being
    // absorbed into `bags_topo[idx]`.
    let mut stack: Vec<(usize, Option<usize>)> = Vec::new();

    // Start from the last bag: early bags are more likely to be leaves, and we don't
    // want a leafy bag to become a root.
    for i in (0..bags_opt.len()).rev() {
        if visited[i] {
            continue;
        }
        stack.push((i, None));
        visited[i] = true;

        while let Some((bag_id, parent)) = stack.pop() {
            let bag = mem::take(&mut bags_opt[bag_id]).unwrap();

            let this;
            if let Some(parent) = parent {
                bags_topo[parent].merge_bag(&bag);
                this = parent;
            } else {
                this = bags_topo.len();
            }

            let all_children = &mut all_children_list[bag_id];

            if parent.is_some() {
                // This bag is being absorbed into `bags_topo[this]`. To keep the
                // result a chain, every descendant of this bag is also absorbed;
                // none of them get to spawn a new chain node.
                for &i in all_children.iter() {
                    visited[i] = true;
                    stack.push((i, Some(this)));
                }
            } else {
                // This bag is a chain node. The child that minimizes pathwidth continues the
                // chain; the rest (and all their descendants, via the branch above)
                // are absorbed into this chain node.
                if !all_children.is_empty() {
                    for &i in all_children.iter() {
                        if i == choice[bag_id] {
                            continue;
                        }
                        visited[i] = true;
                        stack.push((i, Some(this)));
                    }
                    // Push the chosen child last so it is popped next and becomes the
                    // following chain node.
                    visited[choice[bag_id]] = true;
                    stack.push((choice[bag_id], None));
                }
            }

            if parent.is_none() {
                bags_topo.push(bag);
            }
        }
    }

    bags_topo.reverse();
    bags_topo
}
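
// The dynamic program in `topologically_sort_bags` can be illustrated on a tiny model.
// The sketch below is a hypothetical simplification, not code used by the planner:
// `choose_chain` is an illustrative name, atoms are plain `u32` ids, and the tree is
// given as explicit child lists (children always have smaller indices than their
// parent). It follows the rule in the doc comment above: a chain node absorbs its own
// atoms plus the full sub-trees of every sibling of the chosen continuation.

```rust
use std::collections::HashSet;

/// Hypothetical model: bag `i` owns the atom ids in `atoms[i]`, and `children[i]`
/// lists its children, each with a smaller index than `i`. Returns the best chain
/// width (in atoms) per sub-tree and, per bag, the child that continues the chain
/// (`usize::MAX` for leaves).
fn choose_chain(atoms: &[Vec<u32>], children: &[Vec<usize>]) -> (Vec<usize>, Vec<usize>) {
    let n = atoms.len();
    let mut full: Vec<HashSet<u32>> = vec![HashSet::new(); n];
    let mut best = vec![usize::MAX; n];
    let mut choice = vec![usize::MAX; n];
    for i in 0..n {
        // All atoms in the sub-tree rooted at `i` (children are already processed).
        let mut full_i: HashSet<u32> = atoms[i].iter().copied().collect();
        for &c in &children[i] {
            full_i.extend(full[c].iter().copied());
        }
        // Absorbing the whole sub-tree into bag `i` is always possible.
        best[i] = full_i.len();
        full[i] = full_i;
        for &chain_child in &children[i] {
            // Bag `i` absorbs its own atoms and every sibling sub-tree of `chain_child`.
            let mut absorbed: HashSet<u32> = atoms[i].iter().copied().collect();
            for &sib in children[i].iter().filter(|&&c| c != chain_child) {
                absorbed.extend(full[sib].iter().copied());
            }
            let width = absorbed.len().max(best[chain_child]);
            if width <= best[i] {
                best[i] = width;
                choice[i] = chain_child;
            }
        }
    }
    (best, choice)
}
```

// For a root with two single-atom leaf children, either child can continue the chain
// while the root absorbs the other, so the best width is 3 atoms (`<=` lets ties go to
// the later child). On a plain path, every bag keeps a single atom.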

/// Counts how many bags each variable appears in.
///
/// This is used to determine whether a variable should be passed as a message
/// variable (if used in later bags) or a value variable (if only used in the current bag).
fn count_variable_usage_per_bag(bags: &[PlanningContext]) -> DenseIdMap<Variable, usize> {
    let mut n_used_in_bag = DenseIdMap::new();
    for bag in bags {
        for (var, _vinfo) in bag.vars.iter() {
            if !n_used_in_bag.contains_key(var) {
                n_used_in_bag.insert(var, 0);
            }
            n_used_in_bag[var] += 1;
        }
    }
    n_used_in_bag
}
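
// To illustrate how these counts drive the message/value split in `plan_single_bag`,
// here is a minimal sketch on a hypothetical model: `classify` is an illustrative name,
// bags are lists of distinct `u32` variable ids in planning order (leaves first), and
// the `used_in_rhs` set stands in for the per-variable `VarInfo::used_in_rhs` flag.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical model: each bag is a list of distinct variable ids, in the order the
/// bags are planned (leaves first, root last). Returns `(msg_vars, val_vars)` per bag.
fn classify(bags: &[Vec<u32>], used_in_rhs: &HashSet<u32>) -> Vec<(Vec<u32>, Vec<u32>)> {
    // Same counting as `count_variable_usage_per_bag`.
    let mut remaining: HashMap<u32, usize> = HashMap::new();
    for bag in bags {
        for &v in bag {
            *remaining.entry(v).or_insert(0) += 1;
        }
    }
    let mut earlier_msg: HashSet<u32> = HashSet::new();
    let mut out = Vec::new();
    for bag in bags {
        let (mut msg, mut val) = (Vec::new(), Vec::new());
        for &v in bag {
            let n = remaining.get_mut(&v).unwrap();
            *n -= 1;
            if *n > 0 {
                // A later bag still uses `v`: pass it on as a message variable.
                msg.push(v);
            } else if used_in_rhs.contains(&v) || earlier_msg.contains(&v) {
                // Last use, but the RHS or an earlier materialization needs it.
                val.push(v);
            }
            // Otherwise `v` is dropped from this bag's output.
        }
        earlier_msg.extend(msg.iter().copied());
        out.push((msg, val));
    }
    out
}
```

// With bags `[0, 1]` and `[1, 2]` and only `2` used on the RHS, variable `0` is dropped
// from the first bag, `1` is a message variable of the first bag and a value variable
// of the second, and `2` is a value variable of the second.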

/// Plans the execution stages for a single bag.
///
/// This involves:
/// - Dividing variables into message variables (passed to later stages) and value variables
/// - Planning join stages within the bag
/// - Adding prologue and epilogue instructions so that the bag is constrained by previous materializations.
///
/// This function also updates the `used_in_rhs` field: within this bag, a variable counts
/// as used in the RHS exactly when the bag outputs it (as a message or value variable).
/// Variables that are not needed by a later bag, by an earlier block's message variables,
/// or by the rule's right-hand side are dropped from the bag's output.
fn plan_single_bag(
    bag: &mut PlanningContext,
    blocks: &[(JoinStages, MatSpec)],
    // has_block_contributed[i]: whether block i has already been used to prune a bag
    has_block_contributed: &mut [bool],
    n_used_in_bag: &mut DenseIdMap<Variable, usize>,
    strat: PlanStrategy,
) -> (Vec<JoinHeader>, JoinStages, MatSpec) {
    let mut msg_vars = smallvec![];
    let mut val_vars = smallvec![];

    // Classify variables as message or value variables
    for (var, vinfo) in bag.vars.iter_mut() {
        n_used_in_bag[var] -= 1;
        if n_used_in_bag[var] > 0 {
            // This is a public variable used by a later bag, so we need to pass it on anyway
            vinfo.used_in_rhs = true;
            msg_vars.push(var);
        } else {
            // This is the last bag using this variable. If no earlier block exposes it
            // as a message variable and it is not used on the right-hand side, it
            // doesn't need to be expanded.
            if !vinfo.used_in_rhs
                && blocks.iter().all(|(_, spec)| !spec.msg_vars.contains(&var))
            {
                continue;
            }
            val_vars.push(var);
            vinfo.used_in_rhs = true;
        }
    }

    let mut stripped_bag = bag.clone();

    // Add prologue and epilogue instructions to look up previous materialized bags.
    // These are constraints from child blocks: every earlier block that has not yet
    // pruned a bag and whose message variables all occur in this bag. The first such
    // block (scanning from the most recent one) becomes the prologue; any others become
    // epilogue instructions that filter at the end, which is less efficient.
    let mut prologue = None;
    let mut epilogue = Vec::new();
    for (i, prev_block) in blocks.iter().enumerate().rev() {
        if prev_block.1.msg_vars.is_empty() {
            continue;
        }
        if !has_block_contributed[i]
            && prev_block
                .1
                .msg_vars
                .iter()
                .all(|var| bag.vars.contains_key(*var))
        {
            has_block_contributed[i] = true;
            if prologue.is_none() {
                let bind = prev_block
                    .1
                    .msg_vars
                    .iter()
                    .enumerate()
                    .map(|(j, var)| (ColumnId::from_usize(j), *var))
                    .collect();
                let mut to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)> = vec![];
                for (col, var) in prev_block.1.msg_vars.iter().enumerate() {
                    let vinfo = &bag.vars[*var];
                    for occ in vinfo.occurrences.iter() {
                        let isect = match to_intersect
                            .iter_mut()
                            .find(|(spec, _)| spec.to_index.atom == occ.atom)
                        {
                            Some(isect) => isect,
                            None => {
                                to_intersect.push((
                                    ScanSpec {
                                        to_index: SubAtom {
                                            atom: occ.atom,
                                            vars: smallvec![],
                                        },
                                        constraints: vec![],
                                    },
                                    smallvec![],
                                ));
                                to_intersect.last_mut().unwrap()
                            }
                        };
                        isect.0.to_index.vars.extend(occ.vars.iter().copied());
                        isect
                            .1
                            .extend(occ.vars.iter().map(|_| ColumnId::from_usize(col)));
                    }
                }

                prologue = Some(JoinStage::FusedIntersectMat {
                    cover: MatId::from_usize(i),
                    mode: MatScanMode::KeyOnly,
                    bind,
                    to_intersect,
                });

                stripped_bag
                    .vars
                    .retain(|var, _vinfo| !prev_block.1.msg_vars.contains(&var));
            } else {
                epilogue.push(JoinStage::FusedIntersectMat {
                    cover: MatId::from_usize(i),
                    mode: MatScanMode::Lookup(prev_block.1.msg_vars.clone()),
                    bind: smallvec![],
                    to_intersect: vec![],
                });
            }
        }
    }

    let (header, mut instrs) = plan_stages(&stripped_bag, strat);
    instrs.splice(0..0, prologue);
    instrs.extend(epilogue);
    // `plan_gj` decides `is_leaf_scan` based only on the stages it produces, so it cannot
    // see the prologue/epilogue stages we just spliced in. A leaf scan factorizes its
    // bound variables into `binding_info.binding_sets` rather than `binding_info.bindings`,
    // which breaks any later `FusedIntersectMat::Value`/`Lookup` that reads those variables
    // directly from `bindings`. Downgrade bad leaf scans now that the full instruction
    // sequence is known.
    revert_bad_leaf_scans(&mut instrs);

    let stages = JoinStages {
        instrs: Arc::new(instrs),
    };

    (header, stages, MatSpec { msg_vars, val_vars })
}
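
// The prologue built above groups, per atom, every column where one of the earlier
// block's message variables occurs, together with the materialization key column that
// probes it. A minimal sketch of just that grouping, on a hypothetical model:
// `group_by_atom` is an illustrative name, plain integers stand in for `AtomId`,
// `ColumnId`, and `Variable`, and a tuple stands in for `ScanSpec`.

```rust
/// Hypothetical model: `msg_vars[k]` is key column `k` of the earlier materialization,
/// and `occurrences(v)` returns `(atom, atom_columns)` pairs for variable `v`.
/// Returns one `(atom, atom_columns, key_columns)` entry per atom, in first-seen order.
fn group_by_atom(
    msg_vars: &[u32],
    occurrences: impl Fn(u32) -> Vec<(u32, Vec<usize>)>,
) -> Vec<(u32, Vec<usize>, Vec<usize>)> {
    let mut groups: Vec<(u32, Vec<usize>, Vec<usize>)> = Vec::new();
    for (key_col, &var) in msg_vars.iter().enumerate() {
        for (atom, cols) in occurrences(var) {
            // Reuse this atom's entry if an earlier key column already created it.
            let idx = match groups.iter().position(|(a, _, _)| *a == atom) {
                Some(idx) => idx,
                None => {
                    groups.push((atom, Vec::new(), Vec::new()));
                    groups.len() - 1
                }
            };
            let (_, atom_cols, key_cols) = &mut groups[idx];
            // Every column where `var` occurs is probed with the same key column.
            key_cols.extend(cols.iter().map(|_| key_col));
            atom_cols.extend(cols);
        }
    }
    groups
}
```

// If message variable 10 occurs in atoms 0 and 1 and message variable 20 occurs in
// atom 0, atom 0 is intersected on two columns keyed by materialization columns 0 and 1,
// and atom 1 on a single column keyed by column 0.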

/// Builds the final result block that collects results from all materialized bags.
///
/// This performs a top-down pass through the materialized bags: it scans the root
/// (last) block in full, then reads each earlier block keyed on its message variables,
/// binding its value variables and gathering results. Each block is scanned at most
/// once, and blocks that would bind no new variables are skipped.
fn build_result_block(blocks: &[(JoinStages, MatSpec)]) -> JoinStages {
    let mut result_block = Vec::new();
    let mut pinned_vars = DenseIdMap::<Variable, ()>::new();

    for (i, (_stages, mat_spec)) in blocks.iter().enumerate().rev() {
        let to_bind: SmallVec<[(ColumnId, Variable); 2]> = mat_spec
            .val_vars
            .iter()
            .copied()
            .enumerate()
            .filter(|(_, var)| !pinned_vars.contains_key(*var))
            .map(|(col, var)| (ColumnId::from_usize(col), var))
            .collect();

        if to_bind.is_empty() {
            continue;
        }

        for (_, var) in to_bind.iter() {
            pinned_vars.insert(*var, ());
        }

        result_block.push(JoinStage::FusedIntersectMat {
            cover: MatId::from_usize(i),
            mode: if i == blocks.len() - 1 {
                MatScanMode::Full
            } else {
                MatScanMode::Value(mat_spec.msg_vars.clone())
            },
            bind: to_bind,
            to_intersect: vec![],
        });
    }

    JoinStages {
        instrs: Arc::new(result_block),
    }
}
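
// A minimal sketch of the stage selection in `build_result_block`, on a hypothetical
// model: `result_stages` and `Mode` are illustrative stand-ins for the real
// `FusedIntersectMat` stages and `MatScanMode`, and each block is reduced to its
// `(msg_vars, val_vars)` pair, with the root block last.

```rust
use std::collections::HashSet;

/// Simplified stand-in for the two `MatScanMode`s used by the result block.
#[derive(Debug, PartialEq)]
enum Mode {
    Full,
    Value(Vec<u32>),
}

/// Hypothetical model: `blocks[i] = (msg_vars, val_vars)`, root block last.
/// Returns `(block, mode, [(column, var)])` for every block that gets scanned.
fn result_stages(blocks: &[(Vec<u32>, Vec<u32>)]) -> Vec<(usize, Mode, Vec<(usize, u32)>)> {
    let mut pinned: HashSet<u32> = HashSet::new();
    let mut stages = Vec::new();
    // Walk from the root (last block) toward the leaves.
    for (i, (msg_vars, val_vars)) in blocks.iter().enumerate().rev() {
        // Enumerate before filtering: the column index is the position in `val_vars`.
        let to_bind: Vec<(usize, u32)> = val_vars
            .iter()
            .copied()
            .enumerate()
            .filter(|(_, v)| !pinned.contains(v))
            .collect();
        if to_bind.is_empty() {
            continue; // nothing new to bind, so the block is not scanned
        }
        pinned.extend(to_bind.iter().map(|&(_, v)| v));
        let mode = if i == blocks.len() - 1 {
            Mode::Full
        } else {
            Mode::Value(msg_vars.clone())
        };
        stages.push((i, mode, to_bind));
    }
    stages
}
```

// In the test below, the root is scanned in full, the middle block has no value
// variables and is skipped, and the leaf block is read keyed on its message variable.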

/// The last stage and the result block have the following structure:
///
/// ```text
/// for ...
///    yield [] -> x1, x2, ... as Mn
///
/// for x1, x2, ... in Mn:
///   ...
/// ```
///
/// These two loops can be fused into one.
///
/// This is currently unused because, empirically, iterating the materialized `RowBuffer`
/// is much faster than iterating the table.
#[allow(unused)]
fn fuse_last_stage(
    mut blocks: Vec<(JoinStages, MatSpec)>,
    result_block: JoinStages,
) -> (Vec<(JoinStages, MatSpec)>, JoinStages) {
    if blocks.is_empty() {
        return (blocks, result_block);
    }

    let last_block = blocks.pop().unwrap();
    assert!(last_block.1.msg_vars.is_empty());
    if !matches!(
        result_block.instrs[0],
        JoinStage::FusedIntersectMat {
            cover,
            mode: MatScanMode::Full,
            ..
        } if cover == MatId::from_usize(blocks.len())
    ) {
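        // The first stage of the result block does not scan the last materialization, so
        // there is nothing to fuse: put the last block back and return the input unchanged.
        blocks.push(last_block);
        return (blocks, result_block);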
    }

    // Fuse the instructions
    let mut last_block = last_block.0;
    let mut instrs = Arc::unwrap_or_clone(last_block.instrs);
    instrs.extend(result_block.instrs[1..].iter().cloned());
    last_block.instrs = Arc::new(instrs);

    (blocks, last_block)
}

/// Downgrade any `FusedIntersect { is_leaf_scan: true, .. }` whose factorized bindings
/// would be observed as missing by a later stage.
///
/// `is_leaf_scan` instructs the executor to factorize the stage's bindings into
/// `BindingInfo::binding_sets` instead of materializing them into `BindingInfo::bindings`.
/// That is safe only if no subsequent stage in the same `JoinStages` reads any of those
/// variables as a scalar value from `bindings`. The only stages that do so are
/// `FusedIntersectMat::Value` / `Lookup`, whose `index_vars` are looked up directly.
///
/// `plan_gj` performs its own (cover-atom-based) check, but that runs before prologue /
/// epilogue stages are spliced into the instruction list, so it cannot detect
/// `FusedIntersectMat` lookups that come from the epilogue. This pass closes that gap.
///
/// # Example
///
/// Suppose the last stage `plan_gj` produces for a bag is a `FusedIntersect` (shown as
/// stage 0) that binds variable `y` from atom `A` via a leaf scan (no atom later in
/// `plan_gj`'s output reuses `A`):
///
/// ```text
///   stage 0: FusedIntersect { cover: A, bind: [(col, y)], is_leaf_scan: true, ... }
/// ```
///
/// `plan_single_bag` then splices an epilogue that looks up an earlier block's
/// materialization keyed on `y`:
///
/// ```text
///   stage 0: FusedIntersect { cover: A, bind: [(col, y)], is_leaf_scan: true, ... }
///   stage 1: FusedIntersectMat { mode: Lookup([y]), ... }
/// ```
///
/// With `is_leaf_scan: true`, stage 0 puts `y` into `binding_sets` instead of
/// `bindings`. Stage 1 then tries `binding_info.bindings[y]` and panics on a missing
/// entry. This pass detects the overlap (`y` is both in `bound_vars` and among the
/// `Lookup`'s index variables) and flips stage 0's `is_leaf_scan` back to `false`.
fn revert_bad_leaf_scans(stages: &mut [JoinStage]) {
    for i in 0..stages.len() {
        let bound_vars: SmallVec<[Variable; 4]> = match &stages[i] {
            JoinStage::FusedIntersect {
                bind,
                is_leaf_scan: true,
                ..
            } => bind.iter().map(|(_, v)| *v).collect(),
            _ => continue,
        };

        let needs_scalar = ((i + 1)..stages.len()).any(|j| match &stages[j] {
            JoinStage::FusedIntersectMat {
                mode: MatScanMode::Value(vars) | MatScanMode::Lookup(vars),
                ..
            } => vars.iter().any(|v| bound_vars.contains(v)),
            _ => false,
        });

        if needs_scalar && let JoinStage::FusedIntersect { is_leaf_scan, .. } = &mut stages[i] {
            *is_leaf_scan = false;
        }
    }
}

/// Hoists materialization lookups as early as possible.
///
/// A `Lookup` stage binds nothing and only reads variables that are already bound, so it
/// can be moved above any preceding stage that binds none of its variables. For example,
/// in the following, the lookup of `x` in `Mat` can be lifted above the loop over `z`:
///
/// ```text
/// for x in R isec S:
///  R = R[x]; S = S[x]
///  for z in R:
///   if x in Mat:
///     yield
/// ```
fn loop_lifting(stages: JoinStages) -> JoinStages {
    let mut instrs = Arc::unwrap_or_clone(stages.instrs);
    for i in 1..instrs.len() {
        if let JoinStage::FusedIntersectMat {
            cover: _,
            mode: MatScanMode::Lookup(vars),
            bind,
            to_intersect,
        } = &instrs[i]
        {
            assert!(bind.is_empty() && to_intersect.is_empty());
            let vars = vars.clone();
            let mut j = i;
            // Bubble the lookup upward past every stage that binds none of `vars`.
            while j > 0 {
                if matches!(
                    &instrs[j - 1],
                    JoinStage::FusedIntersect { bind, .. }
                        | JoinStage::FusedIntersectMat { bind, .. }
                        if bind.iter().all(|(_, var)| !vars.contains(var))
                ) || matches!(
                    &instrs[j - 1],
                    JoinStage::Intersect { var, .. } if !vars.contains(var)
                ) {
                    instrs.swap(j - 1, j);
                    j -= 1;
                } else {
                    break;
                }
            }
        }
    }
    JoinStages {
        instrs: Arc::new(instrs),
    }
}
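
// The hoisting rule in `loop_lifting` amounts to bubbling each lookup upward past stages
// that bind none of its variables. A minimal sketch on a hypothetical two-variant stage
// type: `Stage::Bind` stands in for `Intersect`/`FusedIntersect`, `Stage::Lookup` for a
// `FusedIntersectMat` in `Lookup` mode (which binds nothing), and `lift_lookups` is an
// illustrative name.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Stage {
    /// Binds the listed variables (stand-in for `Intersect` / `FusedIntersect`).
    Bind(Vec<char>),
    /// Membership test on already-bound variables; binds nothing (stand-in for a
    /// `FusedIntersectMat` in `Lookup` mode).
    Lookup(Vec<char>),
}

/// Moves every `Lookup` as far up as possible without passing a stage that binds
/// one of its variables.
fn lift_lookups(stages: &mut [Stage]) {
    for i in 1..stages.len() {
        let vars = match &stages[i] {
            Stage::Lookup(vars) => vars.clone(),
            _ => continue,
        };
        let mut j = i;
        while j > 0 {
            let independent = match &stages[j - 1] {
                Stage::Bind(bound) => bound.iter().all(|v| !vars.contains(v)),
                Stage::Lookup(_) => true,
            };
            if !independent {
                break;
            }
            stages.swap(j - 1, j);
            j -= 1;
        }
    }
}
```

// This reproduces the doc-comment example: the lookup on `x` moves above the loop over
// `z`, but not above the stage that binds `x`.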

/// Plans a query using tree decomposition. Queries with at most two atoms, or whose
/// decomposition collapses into a single bag, fall back to a `SinglePlan`.
pub(crate) fn tree_decompose_and_plan(
    ctx: PlanningContext,
    strat: PlanStrategy,
    actions: ActionId,
) -> Plan {
    macro_rules! fast_path {
        () => {{
            let (header, instrs) = plan_stages(&ctx, strat);
            let stages = JoinStages {
                instrs: Arc::new(instrs),
            };

            Plan::SinglePlan(SinglePlan {
                atoms: Arc::new(ctx.atoms),
                header,
                stages,
                actions,
            })
        }};
    }
    if ctx.atoms.len() <= 2 {
        return fast_path!();
    }

    // Step 1: Decompose the query into tree-structured bags
    let bags = decompose_into_bags(&ctx);
    if bags.len() <= 1 {
        // Don't do Yannakakis if it's just one bag
        return fast_path!();
    }

    // Step 2: Sort bags topologically and merge them into a chain
    let mut bags = topologically_sort_bags(bags);

    if bags.len() <= 1 {
        return fast_path!();
    }

    // Step 3: Count variable usage across bags. Used for deciding if a variable is public
    // (i.e., a message variable) or private.
    let mut n_used_in_bag = count_variable_usage_per_bag(&bags);
    let mut has_block_contributed = vec![false; bags.len()];

    // Step 4: Plan each bag and create materialization blocks
    let mut blocks = Vec::new();
    let mut header = vec![];
    for bag in bags.iter_mut() {
        let (bag_header, stages, mat_spec) = plan_single_bag(
            bag,
            &blocks,
            &mut has_block_contributed,
            &mut n_used_in_bag,
            strat,
        );
        blocks.push((stages, mat_spec));
        header.extend(bag_header);
    }

    // Step 5: Build the final result block
    let result_block = build_result_block(&blocks);

    // Optimization that avoids the last materialization (currently disabled; see
    // `fuse_last_stage`)
    // let (blocks, result_block) = fuse_last_stage(blocks, result_block);

    // Step 6: Hoist materialization lookups as early as possible
    let blocks = blocks
        .into_iter()
        .map(|(stages, mat_spec)| (loop_lifting(stages), mat_spec))
        .collect::<Vec<_>>();
    let result_block = loop_lifting(result_block);

    Plan::DecomposedPlan(DecomposedPlan {
        atoms: Arc::new(ctx.atoms),
        header,
        stages: JoinStageBlocks { blocks },
        result_block,
        actions,
    })
}

pub(crate) fn plan_query(query: Query) -> Plan {
    let atoms = query.atoms;
    let ctx = PlanningContext {
        vars: query.var_info,
        atoms,
        fun_deps: Arc::new(query.fun_deps),
    };
    tree_decompose_and_plan(ctx, query.plan_strategy, query.action)
}

/// StageInfo is an intermediate stage used to describe the ordering of
/// operations. One of these contains enough information to "expand" it to a
/// JoinStage, but it still contains variable information.
///
/// This separation makes it easier for us to iterate with different planning
/// algorithms while sharing the same "backend" that generates a concrete plan.
#[derive(Debug)]
struct StageInfo {
    cover: SubAtom,
    vars: SmallVec<[Variable; 1]>,
    filters: Vec<(
        SubAtom,                 /* the subatom to index */
        SmallVec<[ColumnId; 2]>, /* how to build a key for that index from the cover atom */
    )>,
}

/// Context for query planning: the variables, atoms, and functional dependencies of a
/// query, or of a single bag during tree decomposition.
#[derive(Debug, Clone, Default)]
pub(crate) struct PlanningContext {
    vars: DenseIdMap<Variable, VarInfo>,
    atoms: DenseIdMap<AtomId, Atom>,
    fun_deps: Arc<FunDeps>,
}

impl PlanningContext {
    /// Returns true if every variable of this bag also appears in `bag2`.
    fn is_subsumed_by(&self, bag2: &PlanningContext) -> bool {
        self.is_subsumed_by_vars(&bag2.vars)
    }

    fn is_subsumed_by_vars<I>(&self, bag2: &DenseIdMap<Variable, I>) -> bool {
        self.vars.iter().all(|(var, _)| bag2.contains_key(var))
    }

    /// Merges `bag2` into this bag. Shared variables gain `bag2`'s occurrences in atoms
    /// they do not already occur in; new variables and atoms are copied over, and atoms
    /// already present are kept as-is.
    fn merge_bag(&mut self, bag2: &PlanningContext) {
        for (var, vinfo) in bag2.vars.iter() {
            if self.vars.contains_key(var) {
                for new_occ in vinfo.occurrences.iter().cloned() {
                    if !self.vars[var]
                        .occurrences
                        .iter()
                        .any(|occ| occ.atom == new_occ.atom)
                    {
                        self.vars[var].occurrences.push(new_occ);
                    }
                }
            } else {
                self.vars.insert(var, vinfo.clone());
            }
        }
        for (atom_id, atom) in bag2.atoms.iter() {
            // atoms don't need to be merged
            if !self.atoms.contains_key(atom_id) {
                self.atoms.insert(atom_id, atom.clone());
            }
        }
    }

    /// Iterates over the variables this bag shares with `other`.
    fn common_vars_with<'a>(
        &'a self,
        other: &'a PlanningContext,
    ) -> impl Iterator<Item = Variable> + 'a {
        self.vars
            .iter()
            .filter(|(var, _)| other.vars.contains_key(*var))
            .map(|(var, _)| var)
    }

    /// Returns true if this bag contains all of `vars`.
    fn has_vars(&self, mut vars: impl Iterator<Item = Variable>) -> bool {
        vars.all(|var| self.vars.contains_key(var))
    }
}

type VarSet = FixedBitSet;
type AtomSet = FixedBitSet;

/// Mutable state tracked during query planning.
#[derive(Clone)]
pub(crate) struct PlanningState {
    used_vars: VarSet,
    constrained_atoms: AtomSet,
}

impl PlanningState {
    fn new(n_vars: usize, n_atoms: usize) -> Self {
        Self {
            used_vars: VarSet::with_capacity(n_vars),
            constrained_atoms: AtomSet::with_capacity(n_atoms),
        }
    }

    fn mark_var_used(&mut self, var: Variable) {
        self.used_vars.insert(var.index());
    }

    fn is_var_used(&self, var: Variable) -> bool {
        self.used_vars.contains(var.index())
    }

    fn mark_atom_constrained(&mut self, atom: AtomId) {
        self.constrained_atoms.insert(atom.index());
    }

    fn is_atom_constrained(&self, atom: AtomId) -> bool {
        self.constrained_atoms.contains(atom.index())
    }
}

/// Data structure used to greedily solve the set cover problem for a given free
/// join plan.
struct BucketQueue<'a> {
    var_info: &'a DenseIdMap<Variable, VarInfo>,
    /// Variables covered by the atoms popped so far.
    cover: VarSet,
    /// For each atom, the variables it references that are not yet covered.
    atom_info: DenseIdMap<AtomId, VarSet>,
    /// Atoms bucketed by their number of uncovered variables.
    sizes: BTreeMap<usize, IndexSet<AtomId>>,
}

impl<'a> BucketQueue<'a> {
    fn new(var_info: &'a DenseIdMap<Variable, VarInfo>, atoms: &DenseIdMap<AtomId, Atom>) -> Self {
        let cover = VarSet::with_capacity(var_info.n_ids());
        let mut atom_info = DenseIdMap::with_capacity(atoms.n_ids());
        let mut sizes = BTreeMap::<usize, IndexSet<AtomId>>::new();
        for (id, atom) in atoms.iter() {
            let mut bitset = VarSet::with_capacity(var_info.n_ids());
            for var in atom.vars() {
                bitset.insert(var.index());
            }
            sizes.entry(bitset.count_ones(..)).or_default().insert(id);
            atom_info.insert(id, bitset);
        }
        BucketQueue {
            var_info,
            cover,
            atom_info,
            sizes,
        }
    }

    /// Return the atom with the largest number of uncovered variables (despite the
    /// name, this pops from the *largest* bucket). A variable is "covered" if a previous
    /// call to `pop_min` returned an atom referencing that variable.
    fn pop_min(&mut self) -> Option<AtomId> {
        // Pick an arbitrary atom from the largest bucket, dropping the bucket once empty.
        let (&size, atoms) = self.sizes.iter_mut().next_back()?;
        let res = atoms.pop().unwrap();
        if atoms.is_empty() {
            self.sizes.remove(&size);
        }
        let vars = self.atom_info[res].clone();
        // For each variable that we added to the cover, remove it from the
        // entries in atom_info referencing it and update `sizes` to reflect the
        // new ordering.
        for new_var in vars.difference(&self.cover).map(Variable::from_usize) {
            for subatom in &self.var_info[new_var].occurrences {
                let cur_set = &mut self.atom_info[subatom.atom];
                let old_size = cur_set.count_ones(..);
                cur_set.difference_with(&vars);
                let new_size = cur_set.count_ones(..);
                if old_size == new_size {
                    continue;
                }
                if let Some(old_size_set) = self.sizes.get_mut(&old_size) {
                    old_size_set.swap_remove(&subatom.atom);
                    if old_size_set.is_empty() {
                        self.sizes.remove(&old_size);
                    }
                }
                if new_size > 0 {
                    self.sizes.entry(new_size).or_default().insert(subatom.atom);
                }
            }
        }
        self.cover.union_with(&vars);
        Some(res)
    }
}
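
// `BucketQueue` implements the classic greedy set-cover heuristic: repeatedly take the
// atom with the most still-uncovered variables. The sketch below is a hypothetical,
// simplified version that rescans every atom on each round instead of maintaining
// buckets incrementally. `greedy_cover` is an illustrative name, atoms are lists of
// distinct `u32` variable ids, and ties go to the lowest index (the real queue breaks
// ties arbitrarily).

```rust
use std::cmp::Reverse;
use std::collections::HashSet;

/// Hypothetical, simplified greedy set cover: each atom is a list of distinct variable
/// ids. Returns the indices of the picked atoms, in pick order.
fn greedy_cover(atoms: &[Vec<u32>]) -> Vec<usize> {
    let mut covered: HashSet<u32> = HashSet::new();
    let mut picked = Vec::new();
    loop {
        // The atom with the most uncovered variables; ties go to the lowest index.
        let best = atoms
            .iter()
            .enumerate()
            .map(|(i, vars)| (i, vars.iter().filter(|v| !covered.contains(*v)).count()))
            .filter(|&(_, uncovered)| uncovered > 0)
            .max_by_key(|&(i, uncovered)| (uncovered, Reverse(i)));
        // Stop once every variable is covered.
        let Some((i, _)) = best else { break };
        covered.extend(atoms[i].iter().copied());
        picked.push(i);
    }
    picked
}
```

// Keeping atoms bucketed by their uncovered-variable count, as `BucketQueue` does,
// avoids this full rescan on every round.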

/// Builds join headers from each atom's fast constraints and collects the remaining
/// (slow) constraints, with an approximate subset size per atom, for planning.
/// Returns a `(headers, remaining_constraints)` tuple.
fn plan_headers(
    ctx: &PlanningContext,
) -> (
    Vec<JoinHeader>,
    DenseIdMap<
        AtomId,
        (
            usize, /* The approx size of the subset matching the constraints. */
            &Pooled<Vec<Constraint>>,
        ),
    >,
) {
    let mut header = Vec::new();
    let mut remaining_constraints: DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)> =
        Default::default();

    for (atom, atom_info) in ctx.atoms.iter() {
        remaining_constraints.insert(
            atom,
            (
                atom_info.constraints.approx_size(),
                &atom_info.constraints.slow,
            ),
        );
        if !atom_info.constraints.fast.is_empty() {
            header.push(JoinHeader {
                atom,
                constraints: Pooled::cloned(&atom_info.constraints.fast),
                subset: atom_info.constraints.subset.clone(),
            });
        }
    }

    (header, remaining_constraints)
}

/// Plans query execution stages using the specified strategy.
/// Returns a `(header, instructions)` tuple rather than a `Plan`, because the caller may
/// want to further modify the stages before assembling the plan.
fn plan_stages(ctx: &PlanningContext, strat: PlanStrategy) -> (Vec<JoinHeader>, Vec<JoinStage>) {
    let (header, remaining_constraints) = plan_headers(ctx);
    let mut instrs = Vec::new();
    let mut state = PlanningState::new(ctx.vars.n_ids(), ctx.atoms.n_ids());

    match strat {
        PlanStrategy::PureSize | PlanStrategy::MinCover => {
            plan_free_join(ctx, &mut state, strat, &remaining_constraints, &mut instrs)
        }
        PlanStrategy::Gj => plan_gj(ctx, &mut state, &remaining_constraints, &mut instrs),
    };

    (header, instrs)
}

/// Plans a free join using the pure-size or the minimal-cover strategy: candidate cover
/// atoms (for `MinCover`, only those picked by the greedy set cover) are visited in
/// increasing order of approximate size.
fn plan_free_join(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    strat: PlanStrategy,
    remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
    stages: &mut Vec<JoinStage>,
) {
    let mut size_info = Vec::<(AtomId, usize)>::new();

    match strat {
        PlanStrategy::PureSize => {
            for (atom, (size, _)) in remaining_constraints.iter() {
                size_info.push((atom, *size));
            }
        }
        PlanStrategy::MinCover => {
            let mut eligible_covers = HashSet::default();
            let mut queue = BucketQueue::new(&ctx.vars, &ctx.atoms);
            while let Some(atom) = queue.pop_min() {
                eligible_covers.insert(atom);
            }
            for (atom, (size, _)) in remaining_constraints
                .iter()
                .filter(|(atom, _)| eligible_covers.contains(atom))
            {
                size_info.push((atom, *size));
            }
        }
        PlanStrategy::Gj => unreachable!(),
    };

    size_info.sort_by_key(|(_, size)| *size);
    let mut atoms = size_info.iter().map(|(atom, _)| *atom);

    while let Some(info) = get_next_freejoin_stage(ctx, state, &mut atoms) {
        let stage = compile_stage(ctx, state, info);
        stages.push(stage);
    }
}

/// Generate the next free join stage by taking atoms from `ordering`.
///
/// Atoms whose variables are all bound by earlier stages are skipped. The
/// chosen atom becomes the stage's cover and binds its unbound variables;
/// every other atom that mentions one of those variables becomes a filter.
/// `state` is updated in place. Returns `None` once `ordering` is exhausted.
fn get_next_freejoin_stage(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    ordering: &mut impl Iterator<Item = AtomId>,
) -> Option<StageInfo> {
    let mut scratch_subatom: HashMap<AtomId, SmallVec<[ColumnId; 2]>> = Default::default();

    loop {
        let mut binds_new_var = false;
        let atom = ordering.next()?;
        let atom_info = &ctx.atoms[atom];
        let mut cover = SubAtom::new(atom);
        let mut vars = SmallVec::<[Variable; 1]>::new();

        for (ix, var) in atom_info.var_columns.iter() {
            if state.is_var_used(var) {
                continue;
            }
            // This atom is not completely covered by previous stages.
            binds_new_var = true;
            state.mark_var_used(var);
            vars.push(var);
            cover.vars.push(ix);

            // Record the columns of every other atom that mentions `var`;
            // those atoms become filters for this stage.
            for subatom in ctx.vars[var].occurrences.iter() {
                if subatom.atom == atom {
                    continue;
                }
                scratch_subatom
                    .entry(subatom.atom)
                    .or_default()
                    .extend(subatom.vars.iter().copied());
            }
        }

        if !binds_new_var {
            // All of this atom's variables are already bound; try the next atom.
            continue;
        }

        let mut filters = Vec::new();
        for (atom, cols) in scratch_subatom.drain() {
            let mut form_key = SmallVec::<[ColumnId; 2]>::new();
            for var_ix in &cols {
                let var = ctx.atoms[atom].get_var(*var_ix).unwrap();
                // form_key is an index _into the subatom forming the cover_.
                let cover_col = vars.iter().position(|v| *v == var).unwrap();
                form_key.push(ColumnId::from_usize(cover_col));
            }
            filters.push((SubAtom { atom, vars: cols }, form_key));
        }

        return Some(StageInfo {
            cover,
            vars,
            filters,
        });
    }
}
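
// Illustrative sketch, not part of the planner: the greedy stage selection in
// `get_next_freejoin_stage`, restated over plain data. Assumptions: an atom is
// a `Vec<u32>` of variable ids (one per column) and bound variables live in a
// `HashSet`. The module, struct, and field names are hypothetical.
#[allow(dead_code)]
mod free_join_stage_sketch {
    use std::collections::{BTreeMap, HashSet};

    // Hypothetical stand-in for `StageInfo`.
    #[derive(Debug, PartialEq)]
    pub struct Stage {
        pub cover: usize,           // atom that binds the new variables
        pub cover_cols: Vec<usize>, // columns of `cover` holding them
        pub vars: Vec<u32>,         // variables newly bound by this stage
        // (filter atom, its columns, positions of those variables in `vars`)
        pub filters: Vec<(usize, Vec<usize>, Vec<usize>)>,
    }

    /// Take atoms from `order`, skipping any whose variables are all bound.
    /// Returns `None` once `order` is exhausted.
    pub fn next_stage(
        atoms: &[Vec<u32>],
        used: &mut HashSet<u32>,
        order: &mut impl Iterator<Item = usize>,
    ) -> Option<Stage> {
        loop {
            let atom = order.next()?;
            let mut stage = Stage {
                cover: atom,
                cover_cols: Vec::new(),
                vars: Vec::new(),
                filters: Vec::new(),
            };
            // BTreeMap (instead of a HashMap) keeps the filter order deterministic.
            let mut scratch: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
            for (col, &var) in atoms[atom].iter().enumerate() {
                if !used.insert(var) {
                    continue; // already bound by an earlier stage
                }
                stage.vars.push(var);
                stage.cover_cols.push(col);
                for (other, other_vars) in atoms.iter().enumerate() {
                    if other == atom {
                        continue;
                    }
                    for (c, &v) in other_vars.iter().enumerate() {
                        if v == var {
                            scratch.entry(other).or_default().push(c);
                        }
                    }
                }
            }
            if stage.vars.is_empty() {
                continue; // fully covered: try the next atom
            }
            for (other, cols) in scratch {
                let key: Vec<usize> = cols
                    .iter()
                    .map(|&c| stage.vars.iter().position(|&v| v == atoms[other][c]).unwrap())
                    .collect();
                stage.filters.push((other, cols, key));
            }
            return Some(stage);
        }
    }
}
// On a triangle R0(x0, x1), R1(x1, x2), R2(x2, x0) taken in order 0, 1, 2, the
// first stage binds x0 and x1 from R0, the second binds x2 from R1, and R2 is
// skipped because all of its variables are already bound.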

/// Plan a generic join query: each variable gets its own stage, and a new
/// stage is fused into the previous one whenever `JoinStage::fuse` succeeds.
fn plan_gj(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
    stages: &mut Vec<JoinStage>,
) {
    // First, map each variable to the size of the smallest atom in which it appears.
    let mut min_sizes = Vec::with_capacity(ctx.vars.n_ids());
    let mut atoms_hit = AtomSet::with_capacity(ctx.atoms.n_ids());
    for (var, var_info) in ctx.vars.iter() {
        let n_occs = var_info.occurrences.len();
        if n_occs == 1 && !var_info.used_in_rhs {
            // Skip this variable for now; the second pass below may still plan it.
            continue;
        }
        if let Some(min_size) = var_info
            .occurrences
            .iter()
            .map(|subatom| {
                atoms_hit.set(subatom.atom.index(), true);
                remaining_constraints[subatom.atom].0
            })
            .min()
        {
            min_sizes.push((var, min_size, n_occs));
        }
        // If the variable has no occurrences, it may be bound on the RHS of a
        // rule (or it may just be unused). Either way, we will ignore it when
        // planning the query.
    }
    for (var, var_info) in ctx.vars.iter() {
        if var_info.occurrences.len() == 1 && !var_info.used_in_rhs {
            // We skipped this variable the first time around because it
            // looks "unused". If no variable planned in the first pass touches
            // its atom, though, we still need to plan it.
            let atom = var_info.occurrences[0].atom;
            if !atoms_hit.contains(atom.index()) {
                min_sizes.push((var, remaining_constraints[atom].0, 1));
            }
        }
    }
    // Sort ascending by size, then descending by number of occurrences.
    min_sizes.sort_by_key(|(_, size, occs)| (*size, -(*occs as i64)));
    for (var, _, _) in min_sizes {
        // The first occurrence is the cover; every other occurrence is a filter
        // keyed on the stage's single variable (column 0).
        let occ = ctx.vars[var].occurrences[0].clone();
        let mut info = StageInfo {
            cover: occ,
            vars: smallvec![var],
            filters: Default::default(),
        };
        for occ in &ctx.vars[var].occurrences[1..] {
            info.filters
                .push((occ.clone(), smallvec![ColumnId::new(0); occ.vars.len()]));
        }

        let next_stage = compile_stage(ctx, state, info);
        if let Some(prev) = stages.last_mut()
            && prev.fuse(&next_stage)
        {
            continue;
        }
        stages.push(next_stage);
    }
    // A `FusedIntersect` stage with nothing to intersect becomes a leaf scan
    // when no later `Intersect` scans its cover atom and no later
    // `FusedIntersect` covers that same atom.
    for i in 0..stages.len() {
        if let JoinStage::FusedIntersect {
            cover,
            to_intersect,
            ..
        } = &stages[i]
            && to_intersect.is_empty()
        {
            let cover_atom = cover.to_index.atom;
            let used_later = stages[i + 1..].iter().any(|later_stage| match later_stage {
                JoinStage::Intersect { scans, .. } => {
                    scans.iter().any(|scan| scan.atom == cover_atom)
                }
                JoinStage::FusedIntersect { cover, .. } => cover.to_index.atom == cover_atom,
                JoinStage::FusedIntersectMat { .. } => unreachable!(),
            });
            if !used_later {
                let JoinStage::FusedIntersect { is_leaf_scan, .. } = &mut stages[i] else {
                    unreachable!();
                };
                *is_leaf_scan = true;
            }
        }
    }
}
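
// Illustrative sketch, not part of the planner: the variable ordering computed
// by `plan_gj`, restated over plain data. Assumptions: a variable is the list
// of atom ids it occurs in plus a `used_in_rhs` flag, and atom sizes are a
// slice indexed by atom id. `std::cmp::Reverse` replaces the negated
// occurrence count used above. All names here are hypothetical.
#[allow(dead_code)]
mod gj_order_sketch {
    use std::cmp::Reverse;

    // Hypothetical stand-in for a variable's planning info.
    pub struct Var {
        pub occurrences: Vec<usize>, // ids of the atoms this variable occurs in
        pub used_in_rhs: bool,
    }

    /// Return variable ids in the order generic join would bind them.
    pub fn order_vars(vars: &[Var], atom_sizes: &[usize]) -> Vec<usize> {
        let mut atoms_hit = vec![false; atom_sizes.len()];
        let mut min_sizes: Vec<(usize, usize, usize)> = Vec::new();
        // Pass 1: variables that join atoms or that the RHS needs.
        for (id, var) in vars.iter().enumerate() {
            let n_occs = var.occurrences.len();
            if n_occs == 1 && !var.used_in_rhs {
                continue;
            }
            for &atom in &var.occurrences {
                atoms_hit[atom] = true;
            }
            if let Some(min) = var.occurrences.iter().map(|&atom| atom_sizes[atom]).min() {
                min_sizes.push((id, min, n_occs));
            }
        }
        // Pass 2: a lone, unused variable is still planned if no variable from
        // pass 1 touches its atom.
        for (id, var) in vars.iter().enumerate() {
            if var.occurrences.len() == 1 && !var.used_in_rhs {
                let atom = var.occurrences[0];
                if !atoms_hit[atom] {
                    min_sizes.push((id, atom_sizes[atom], 1));
                }
            }
        }
        // Smallest atom first; among equal sizes, more occurrences first.
        min_sizes.sort_by_key(|&(_, size, occs)| (size, Reverse(occs)));
        min_sizes.into_iter().map(|(id, _, _)| id).collect()
    }
}
// Design note: the two passes mean a variable that appears once and is never
// used on the RHS only costs a stage when its atom would otherwise go unscanned.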

/// Compile a stage info into a concrete join stage, updating constraint state.
fn compile_stage(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    StageInfo {
        cover,
        vars,
        filters,
    }: StageInfo,
) -> JoinStage {
    /// Return the atom's slow constraints the first time it is scanned and
    /// nothing on later scans, so they are attached to the first scan only.
    fn take_atom_constraints_if_new(
        ctx: &PlanningContext,
        state: &mut PlanningState,
        atom: AtomId,
    ) -> Vec<Constraint> {
        if state.is_atom_constrained(atom) {
            Default::default()
        } else {
            state.mark_atom_constrained(atom);
            ctx.atoms[atom].constraints.slow.clone()
        }
    }

    // A single variable shared by the cover and at least one filter atom
    // compiles to an `Intersect` over all of those atoms.
    if vars.len() == 1 && !filters.is_empty() {
        let scans = SmallVec::<[SingleScanSpec; 3]>::from_iter(
            iter::once(&cover)
                .chain(filters.iter().map(|(x, _)| x))
                .map(|subatom| {
                    let atom = subatom.atom;
                    SingleScanSpec {
                        atom,
                        column: subatom.vars[0],
                        cs: take_atom_constraints_if_new(ctx, state, atom),
                    }
                }),
        );

        return JoinStage::Intersect {
            var: vars[0],
            scans,
        };
    }

    // Otherwise (`FusedIntersect`): bind every variable in `vars` from the
    // cover atom, then probe each filter atom using its key.
    let atom = cover.atom;

    let cover_spec = ScanSpec {
        to_index: cover,
        constraints: take_atom_constraints_if_new(ctx, state, atom),
    };

    let mut bind = SmallVec::new();
    for var in vars {
        bind.push((ctx.atoms[atom].get_col(var).unwrap(), var));
    }

    let mut to_intersect = Vec::with_capacity(filters.len());
    for (subatom, key_spec) in filters {
        let atom = subatom.atom;
        let scan = ScanSpec {
            to_index: subatom,
            constraints: take_atom_constraints_if_new(ctx, state, atom),
        };
        to_intersect.push((scan, key_spec));
    }

    JoinStage::FusedIntersect {
        cover: cover_spec,
        bind,
        to_intersect,
        is_leaf_scan: false,
    }
}
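
// Illustrative sketch, not part of the planner: how `compile_stage` chooses
// between an `Intersect` and a `FusedIntersect` stage and attaches each atom's
// constraints only on its first scan. Assumptions: atoms are `usize` ids,
// constraints are plain strings, and `Stage`/`Planner` are hypothetical
// stand-ins for `JoinStage`/`PlanningState`.
#[allow(dead_code)]
mod compile_stage_sketch {
    use std::collections::HashSet;

    pub type Scan = (usize, Vec<&'static str>); // (atom, constraints attached here)

    // Hypothetical stand-in for `JoinStage`.
    #[derive(Debug, PartialEq)]
    pub enum Stage {
        // One variable, intersected across the cover and every filter atom.
        Intersect { scans: Vec<Scan> },
        // Several variables bound by the cover, then each filter atom probed.
        FusedIntersect { cover: Scan, probes: Vec<Scan> },
    }

    // Hypothetical stand-in for `PlanningState` plus per-atom constraints.
    pub struct Planner {
        pub constraints: Vec<Vec<&'static str>>,
        pub constrained: HashSet<usize>,
    }

    impl Planner {
        // Mirrors `take_atom_constraints_if_new`: constraints only on first use.
        fn take_if_new(&mut self, atom: usize) -> Vec<&'static str> {
            if self.constrained.insert(atom) {
                self.constraints[atom].clone()
            } else {
                Vec::new()
            }
        }

        pub fn compile(&mut self, cover: usize, n_vars: usize, filters: &[usize]) -> Stage {
            if n_vars == 1 && !filters.is_empty() {
                let scans: Vec<Scan> = std::iter::once(cover)
                    .chain(filters.iter().copied())
                    .map(|atom| (atom, self.take_if_new(atom)))
                    .collect();
                return Stage::Intersect { scans };
            }
            let cover = (cover, self.take_if_new(cover));
            let probes: Vec<Scan> = filters
                .iter()
                .map(|&atom| (atom, self.take_if_new(atom)))
                .collect();
            Stage::FusedIntersect { cover, probes }
        }
    }
}
// Compiling an intersect over atoms 0 and 2 and then a fused stage covering
// atom 2 attaches atom 2's constraints to the first stage only.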