egglog_core_relations/free_join/plan.rs

1use std::{collections::BTreeMap, iter, mem, sync::Arc};
2
3use crate::{
4    numeric_id::{DenseIdMap, NumericId},
5    query::SymbolMap,
6};
7use fixedbitset::FixedBitSet;
8use smallvec::{SmallVec, smallvec};
9
10use crate::{
11    common::{HashMap, HashSet, IndexSet},
12    offsets::Subset,
13    pool::Pooled,
14    query::{Atom, Query},
15    table_spec::Constraint,
16};
17
18use super::{ActionId, AtomId, ColumnId, SubAtom, VarInfo, Variable};
19
/// A scan over one atom, indexed on a subset of its columns.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ScanSpec {
    /// The atom to scan together with the columns of it to index on.
    pub to_index: SubAtom,
    /// Only yield rows where the given constraints match. The planner in this module
    /// leaves this empty when an earlier compiled scan of the same atom already
    /// carries the atom's constraints.
    pub constraints: Vec<Constraint>,
}
26
/// A scan of a single column of one atom. `JoinStage::Intersect` holds one of these
/// per atom taking part in the intersection.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct SingleScanSpec {
    /// The atom being scanned.
    pub atom: AtomId,
    /// The column of `atom` that holds the variable being intersected.
    pub column: ColumnId,
    /// Constraints to apply to `atom` during this scan (empty when an earlier scan
    /// of the same atom already applies them).
    pub cs: Vec<Constraint>,
}
33
/// Join headers evaluate constraints on a single atom; they prune the search space before the rest
/// of the join plan is executed.
pub(crate) struct JoinHeader {
    /// The atom whose rows this header filters.
    pub atom: AtomId,
    /// We currently aren't using these at all. The plan is to use this to
    /// dedup plan stages later (it also helps for debugging).
    #[allow(unused)] // kept for future stage deduplication and for `Debug` output
    pub constraints: Pooled<Vec<Constraint>>,
    /// A pre-computed table subset that we can use to filter the table,
    /// given these constraints.
    ///
    /// Why use the constraints at all? Because we want to use them to
    /// discover common plan nodes from different queries (subsets can be
    /// large).
    pub subset: Subset,
}
50
51impl std::fmt::Debug for JoinHeader {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("JoinHeader")
54            .field("atom", &self.atom)
55            .field("constraints", &self.constraints)
56            .field(
57                "subset",
58                &format_args!("Subset(size={})", self.subset.size()),
59            )
60            .finish()
61    }
62}
63
64impl Clone for JoinHeader {
65    fn clone(&self) -> Self {
66        JoinHeader {
67            atom: self.atom,
68            constraints: Pooled::cloned(&self.constraints),
69            subset: self.subset.clone(),
70        }
71    }
72}
73
/// One step of a join plan.
#[derive(Debug, Clone)]
pub(crate) enum JoinStage {
    /// `Intersect` takes a variable and intersects a set of atoms
    /// on that variable.
    /// This corresponds to the classic generic join algorithm.
    Intersect {
        /// The variable bound by this stage.
        var: Variable,
        /// One single-column scan per atom containing `var`.
        scans: SmallVec<[SingleScanSpec; 3]>,
    },
    /// `FusedIntersect` takes a "cover" (sub)atom and uses it to probe other (sub)atoms.
    /// This corresponds to the free join algorithm, or when to_intersect.len() == 1 and cover is
    /// the entire atom, a hash join.
    FusedIntersect {
        /// The (sub)atom whose rows drive this stage.
        cover: ScanSpec,
        /// Pairs of (column of the cover atom, variable bound from that column).
        bind: SmallVec<[(ColumnId, Variable); 2]>,
        // to_intersect.1 is the index into the cover atom.
        // (More precisely: positions into `cover.to_index.vars`, used to build a key
        // for probing `to_intersect.0`.)
        to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)>,
    },
}
93
94impl JoinStage {
95    /// Attempt to fuse two stages into one.
96    ///
97    /// This operation is very conservative right now, it only fuses multiple
98    /// scans that do no filtering whatsoever.
99    fn fuse(&mut self, other: &JoinStage) -> bool {
100        use JoinStage::*;
101        match (self, other) {
102            (
103                FusedIntersect {
104                    cover,
105                    bind,
106                    to_intersect,
107                },
108                Intersect { var, scans },
109            ) if to_intersect.is_empty()
110                && scans.len() == 1
111                && cover.to_index.atom == scans[0].atom
112                && scans[0].cs.is_empty() =>
113            {
114                let col = scans[0].column;
115                bind.push((col, *var));
116                cover.to_index.vars.push(col);
117                true
118            }
119            (
120                x,
121                Intersect {
122                    var: var2,
123                    scans: scans2,
124                },
125            ) => {
126                // This is all somewhat mangled because of the borrowing rules
127                // when we pass &mut self into a tuple.
128                let (var1, mut scans1) = if let Intersect {
129                    var: var1,
130                    scans: scans1,
131                } = x
132                {
133                    if !(scans1.len() == 1
134                        && scans2.len() == 1
135                        && scans1[0].atom == scans2[0].atom
136                        && scans2[0].cs.is_empty())
137                    {
138                        return false;
139                    }
140                    (*var1, mem::take(scans1))
141                } else {
142                    return false;
143                };
144                let atom = scans1[0].atom;
145                let col1 = scans1[0].column;
146                let col2 = scans2[0].column;
147                *x = FusedIntersect {
148                    cover: ScanSpec {
149                        to_index: SubAtom {
150                            atom,
151                            vars: smallvec![col1, col2],
152                        },
153                        constraints: mem::take(&mut scans1[0].cs),
154                    },
155                    bind: smallvec![(col1, var1), (col2, *var2)],
156                    to_intersect: Default::default(),
157                };
158                true
159            }
160            _ => false,
161        }
162    }
163}
164
/// A compiled join plan for a query.
#[derive(Debug, Clone)]
pub(crate) struct Plan {
    /// The query's atoms, indexed by `AtomId`.
    pub atoms: Arc<DenseIdMap<AtomId, Atom>>,
    /// The headers, join stages, and associated action for this plan.
    pub stages: JoinStages,
}
170impl Plan {
171    pub(crate) fn to_report(&self, symbol_map: &SymbolMap) -> egglog_reports::Plan {
172        use egglog_reports::{
173            Plan as ReportPlan, Scan as ReportScan, SingleScan as ReportSingleScan,
174            Stage as ReportStage,
175        };
176        const INTERNAL_PREFIX: &str = "@";
177        let get_var = |var: Variable| {
178            symbol_map
179                .vars
180                .get(&var)
181                .map(|s| s.to_string())
182                .unwrap_or_else(|| format!("{INTERNAL_PREFIX}x{var:?}"))
183        };
184        let get_atom = |atom: AtomId| {
185            symbol_map
186                .atoms
187                .get(&atom)
188                .map(|s| s.to_string())
189                .unwrap_or_else(|| format!("{INTERNAL_PREFIX}R{atom:?}"))
190        };
191        let mut stages = Vec::new();
192        for (i, stage) in self.stages.instrs.iter().enumerate() {
193            let report_stage = match stage {
194                JoinStage::Intersect { var, scans } => {
195                    let var_name = get_var(*var);
196                    let report_scans = scans
197                        .iter()
198                        .map(|scan| {
199                            let atom_name = get_atom(scan.atom);
200                            ReportSingleScan(
201                                atom_name,
202                                (var_name.clone(), scan.column.index() as i64),
203                            )
204                        })
205                        .collect();
206                    ReportStage::Intersect {
207                        scans: report_scans,
208                    }
209                }
210                JoinStage::FusedIntersect {
211                    cover,
212                    bind: _,
213                    to_intersect,
214                } => {
215                    let cover_atom_name = get_atom(cover.to_index.atom);
216                    let cover_cols: Vec<(String, i64)> = cover
217                        .to_index
218                        .vars
219                        .iter()
220                        .map(|col| {
221                            let var_name =
222                                get_var(self.atoms[cover.to_index.atom].column_to_var[*col]);
223                            (var_name, col.index() as i64)
224                        })
225                        .collect();
226                    let report_cover = ReportScan(cover_atom_name, cover_cols);
227                    let report_to_intersect = to_intersect
228                        .iter()
229                        .map(|(scan, key_spec)| {
230                            let atom_name = get_atom(scan.to_index.atom);
231                            let cols: Vec<(String, i64)> = key_spec
232                                .iter()
233                                .map(|col| {
234                                    let var_name =
235                                        get_var(self.atoms[scan.to_index.atom].column_to_var[*col]);
236                                    (var_name, col.index() as i64)
237                                })
238                                .collect();
239                            ReportScan(atom_name, cols)
240                        })
241                        .collect();
242                    ReportStage::FusedIntersect {
243                        cover: report_cover,
244                        to_intersect: report_to_intersect,
245                    }
246                }
247            };
248            let next = if i == self.stages.instrs.len() - 1 {
249                vec![]
250            } else {
251                vec![i + 1]
252            };
253            stages.push((report_stage, None, next));
254        }
255        ReportPlan { stages }
256    }
257}
258
/// The executable part of a `Plan`.
#[derive(Debug, Clone)]
pub(crate) struct JoinStages {
    /// Single-atom constraint filters that prune the search space before `instrs` run.
    pub header: Vec<JoinHeader>,
    /// The join stages, in order.
    pub instrs: Arc<Vec<JoinStage>>,
    /// The action associated with the query (taken from `Query::action` when planning).
    pub actions: ActionId,
}
265
/// Bitset of variables, indexed by `Variable::index()`.
type VarSet = FixedBitSet;
/// Bitset of atoms, indexed by `AtomId::index()`.
type AtomSet = FixedBitSet;
268
/// The algorithm used to produce a join plan.
///
/// `Gj` is the default strategy.
#[derive(Default, Copy, Clone)]
pub enum PlanStrategy {
    /// Free Join: Iteratively pick the smallest atom as the cover for the next
    /// stage, until all subatoms have been visited.
    PureSize,

    /// Free Join: Pick an approximate minimal set of covers, then order those
    /// covers in increasing order of size.
    ///
    /// This is similar to PureSize but we first limit the potential atoms that
    /// can act as covers so as to minimize the total number of stages in the
    /// plan. This is only an approximate minimum: the problem of finding the
    /// exact minimum ("set cover") is NP-hard.
    MinCover,

    /// Generate a plan for the classic Generic Join algorithm, constraining a
    /// single variable per stage.
    #[default]
    Gj,
}
290
291pub(crate) fn plan_query(query: Query) -> Plan {
292    let atoms = query.atoms;
293    let ctx = PlanningContext {
294        vars: query.var_info,
295        atoms,
296    };
297    let (header, instrs) = plan_stages(&ctx, query.plan_strategy);
298
299    Plan {
300        atoms: Arc::new(ctx.atoms),
301        stages: JoinStages {
302            header,
303            instrs: Arc::new(instrs),
304            actions: query.action,
305        },
306    }
307}
308
/// StageInfo is an intermediate stage used to describe the ordering of
/// operations. One of these contains enough information to "expand" it to a
/// JoinStage, but it still contains variable information.
///
/// This separation makes it easier for us to iterate with different planning
/// algorithms while sharing the same "backend" that generates a concrete plan.
#[derive(Debug)]
struct StageInfo {
    /// The (sub)atom that drives this stage.
    cover: SubAtom,
    /// The variables bound by this stage; the planners push these in step with
    /// `cover.vars`, so `vars[i]` is the variable at `cover.vars[i]`.
    vars: SmallVec<[Variable; 1]>,
    filters: Vec<(
        SubAtom,                 /* the subatom to index */
        SmallVec<[ColumnId; 2]>, /* how to build a key for that index from the cover atom */
    )>,
}
324
/// Immutable context for query planning containing references to query metadata.
struct PlanningContext {
    /// Per-variable information (occurrences, RHS usage), indexed by `Variable`.
    vars: DenseIdMap<Variable, VarInfo>,
    /// The query's atoms, indexed by `AtomId`.
    atoms: DenseIdMap<AtomId, Atom>,
}
330
/// Mutable state tracked during query planning.
#[derive(Clone)]
struct PlanningState {
    /// Variables already bound by a previously planned stage (maintained by the
    /// free join planner).
    used_vars: VarSet,
    /// Atoms whose slow constraints have already been attached to a compiled scan.
    constrained_atoms: AtomSet,
}
337
338impl PlanningState {
339    fn new(n_vars: usize, n_atoms: usize) -> Self {
340        Self {
341            used_vars: VarSet::with_capacity(n_vars),
342            constrained_atoms: AtomSet::with_capacity(n_atoms),
343        }
344    }
345
346    fn mark_var_used(&mut self, var: Variable) {
347        self.used_vars.insert(var.index());
348    }
349
350    fn is_var_used(&self, var: Variable) -> bool {
351        self.used_vars.contains(var.index())
352    }
353
354    fn mark_atom_constrained(&mut self, atom: AtomId) {
355        self.constrained_atoms.insert(atom.index());
356    }
357
358    fn is_atom_constrained(&self, atom: AtomId) -> bool {
359        self.constrained_atoms.contains(atom.index())
360    }
361}
362
/// Data structure used to greedily solve the set cover problem for a given free
/// join plan.
struct BucketQueue<'a> {
    /// Variable metadata, used to find every atom mentioning a variable.
    var_info: &'a DenseIdMap<Variable, VarInfo>,
    /// Variables covered by atoms returned so far.
    cover: VarSet,
    /// For each atom, the set of its variables that are not yet covered.
    atom_info: DenseIdMap<AtomId, VarSet>,
    /// Atoms bucketed by their number of uncovered variables.
    sizes: BTreeMap<usize, IndexSet<AtomId>>,
}
371
372impl<'a> BucketQueue<'a> {
373    fn new(var_info: &'a DenseIdMap<Variable, VarInfo>, atoms: &DenseIdMap<AtomId, Atom>) -> Self {
374        let cover = VarSet::with_capacity(var_info.n_ids());
375        let mut atom_info = DenseIdMap::with_capacity(atoms.n_ids());
376        let mut sizes = BTreeMap::<usize, IndexSet<AtomId>>::new();
377        for (id, atom) in atoms.iter() {
378            let mut bitset = VarSet::with_capacity(var_info.n_ids());
379            for (_, var) in atom.column_to_var.iter() {
380                bitset.insert(var.index());
381            }
382            sizes.entry(bitset.count_ones(..)).or_default().insert(id);
383            atom_info.insert(id, bitset);
384        }
385        BucketQueue {
386            var_info,
387            cover,
388            atom_info,
389            sizes,
390        }
391    }
392
393    /// Return the atom with the largest number of uncovered variables. A
394    /// variable is "covered" if a previous call to `pop_min` returned an atom
395    /// referencing that variable.
396    fn pop_min(&mut self) -> Option<AtomId> {
397        // Pick an arbitrary atom from the smallest bucket.
398        let (_, atoms) = self.sizes.iter_mut().next_back()?;
399        let res = atoms.pop().unwrap();
400        let vars = self.atom_info[res].clone();
401        // For each variable that we added to the cover, remove it from the
402        // entries in atom_info referencing it and update `sizes` to reflect the
403        // new ordering.
404        for new_var in vars.difference(&self.cover).map(Variable::from_usize) {
405            for subatom in &self.var_info[new_var].occurrences {
406                let cur_set = &mut self.atom_info[subatom.atom];
407                let old_size = cur_set.count_ones(..);
408                cur_set.difference_with(&vars);
409                let new_size = cur_set.count_ones(..);
410                if old_size == new_size {
411                    continue;
412                }
413                if let Some(old_size_set) = self.sizes.get_mut(&old_size) {
414                    old_size_set.swap_remove(&subatom.atom);
415                    if old_size_set.is_empty() {
416                        self.sizes.remove(&old_size);
417                    }
418                }
419                if new_size > 0 {
420                    self.sizes.entry(new_size).or_default().insert(subatom.atom);
421                }
422            }
423        }
424        self.cover.union_with(&vars);
425        Some(res)
426    }
427}
428
429/// Build join headers from fast constraints and compute remaining constraints for planning.
430/// Returns (headers, remaining_constraints) tuple.
431fn plan_headers(
432    ctx: &PlanningContext,
433) -> (
434    Vec<JoinHeader>,
435    DenseIdMap<
436        AtomId,
437        (
438            usize, /* The approx size of the subset matching the constraints. */
439            &Pooled<Vec<Constraint>>,
440        ),
441    >,
442) {
443    let mut header = Vec::new();
444    let mut remaining_constraints: DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)> =
445        Default::default();
446
447    for (atom, atom_info) in ctx.atoms.iter() {
448        remaining_constraints.insert(
449            atom,
450            (
451                atom_info.constraints.approx_size(),
452                &atom_info.constraints.slow,
453            ),
454        );
455        if !atom_info.constraints.fast.is_empty() {
456            header.push(JoinHeader {
457                atom,
458                constraints: Pooled::cloned(&atom_info.constraints.fast),
459                subset: atom_info.constraints.subset.clone(),
460            });
461        }
462    }
463
464    (header, remaining_constraints)
465}
466
467/// Plan query execution stages using the specified strategy.
468/// Returns (header, instructions) tuple that can be assembled into a Plan by the caller.
469fn plan_stages(ctx: &PlanningContext, strat: PlanStrategy) -> (Vec<JoinHeader>, Vec<JoinStage>) {
470    let (header, remaining_constraints) = plan_headers(ctx);
471    let mut instrs = Vec::new();
472    let mut state = PlanningState::new(ctx.vars.n_ids(), ctx.atoms.n_ids());
473
474    match strat {
475        PlanStrategy::PureSize | PlanStrategy::MinCover => {
476            plan_free_join(ctx, &mut state, strat, &remaining_constraints, &mut instrs)
477        }
478        PlanStrategy::Gj => plan_gj(ctx, &mut state, &remaining_constraints, &mut instrs),
479    };
480
481    (header, instrs)
482}
483
484/// Plan free join queries using pure size or minimal cover strategy.
485fn plan_free_join(
486    ctx: &PlanningContext,
487    state: &mut PlanningState,
488    strat: PlanStrategy,
489    remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
490    stages: &mut Vec<JoinStage>,
491) {
492    let mut size_info = Vec::<(AtomId, usize)>::new();
493
494    match strat {
495        PlanStrategy::PureSize => {
496            for (atom, (size, _)) in remaining_constraints.iter() {
497                size_info.push((atom, *size));
498            }
499        }
500        PlanStrategy::MinCover => {
501            let mut eligible_covers = HashSet::default();
502            let mut queue = BucketQueue::new(&ctx.vars, &ctx.atoms);
503            while let Some(atom) = queue.pop_min() {
504                eligible_covers.insert(atom);
505            }
506            for (atom, (size, _)) in remaining_constraints
507                .iter()
508                .filter(|(atom, _)| eligible_covers.contains(atom))
509            {
510                size_info.push((atom, *size));
511            }
512        }
513        PlanStrategy::Gj => unreachable!(),
514    };
515
516    size_info.sort_by_key(|(_, size)| *size);
517    let mut atoms = size_info.iter().map(|(atom, _)| *atom);
518
519    while let Some(info) = get_next_freejoin_stage(ctx, state, &mut atoms) {
520        let stage = compile_stage(ctx, state, info);
521        stages.push(stage);
522    }
523}
524
525/// Generate the next free join stage by picking an atom from the ordering.
526/// Returns the stage info and updated state, or None if all atoms are covered.
527fn get_next_freejoin_stage(
528    ctx: &PlanningContext,
529    state: &mut PlanningState,
530    ordering: &mut impl Iterator<Item = AtomId>,
531) -> Option<StageInfo> {
532    let mut scratch_subatom: HashMap<AtomId, SmallVec<[ColumnId; 2]>> = Default::default();
533
534    loop {
535        let mut covered = false;
536        let atom = ordering.next()?;
537        let atom_info = &ctx.atoms[atom];
538        let mut cover = SubAtom::new(atom);
539        let mut vars = SmallVec::<[Variable; 1]>::new();
540
541        for (ix, var) in atom_info.column_to_var.iter() {
542            if state.is_var_used(*var) {
543                continue;
544            }
545            // This atom is not completely covered by previous stages.
546            covered = true;
547            state.mark_var_used(*var);
548            vars.push(*var);
549            cover.vars.push(ix);
550
551            for subatom in ctx.vars[*var].occurrences.iter() {
552                if subatom.atom == atom {
553                    continue;
554                }
555                scratch_subatom
556                    .entry(subatom.atom)
557                    .or_default()
558                    .extend(subatom.vars.iter().copied());
559            }
560        }
561
562        if !covered {
563            // Search the next atom.
564            continue;
565        }
566
567        let mut filters = Vec::new();
568        for (atom, cols) in scratch_subatom.drain() {
569            let mut form_key = SmallVec::<[ColumnId; 2]>::new();
570            for var_ix in &cols {
571                let var = ctx.atoms[atom].column_to_var[*var_ix];
572                // form_key is an index _into the subatom forming the cover_.
573                let cover_col = vars.iter().position(|v| *v == var).unwrap();
574                form_key.push(ColumnId::from_usize(cover_col));
575            }
576            filters.push((SubAtom { atom, vars: cols }, form_key));
577        }
578
579        return Some(StageInfo {
580            cover,
581            vars,
582            filters,
583        });
584    }
585}
586
587/// Plan generic join queries (one variable per stage).
588fn plan_gj(
589    ctx: &PlanningContext,
590    state: &mut PlanningState,
591    remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
592    stages: &mut Vec<JoinStage>,
593) {
594    // First, map all variables to the size of the smallest atom in which they appear:
595    let mut min_sizes = Vec::with_capacity(ctx.vars.n_ids());
596    let mut atoms_hit = AtomSet::with_capacity(ctx.atoms.n_ids());
597    for (var, var_info) in ctx.vars.iter() {
598        let n_occs = var_info.occurrences.len();
599        if n_occs == 1 && !var_info.used_in_rhs {
600            // Do not plan this one. Unless (see below).
601            continue;
602        }
603        if let Some(min_size) = var_info
604            .occurrences
605            .iter()
606            .map(|subatom| {
607                atoms_hit.set(subatom.atom.index(), true);
608                remaining_constraints[subatom.atom].0
609            })
610            .min()
611        {
612            min_sizes.push((var, min_size, n_occs));
613        }
614        // If the variable has no ocurrences, it may be bound on the RHS of a
615        // rule (or it may just be unused). Either way, we will ignore it when
616        // planning the query.
617    }
618    for (var, var_info) in ctx.vars.iter() {
619        if var_info.occurrences.len() == 1 && !var_info.used_in_rhs {
620            // We skipped this variable the first time around because it
621            // looks "unused". If it belongs to an atom that otherwise has
622            // gone unmentioned, though, we need to plan it anyway.
623            let atom = var_info.occurrences[0].atom;
624            if !atoms_hit.contains(atom.index()) {
625                min_sizes.push((var, remaining_constraints[atom].0, 1));
626            }
627        }
628    }
629    // Sort ascending by size, then descending by number of occurrences.
630    min_sizes.sort_by_key(|(_, size, occs)| (*size, -(*occs as i64)));
631    for (var, _, _) in min_sizes {
632        let occ = ctx.vars[var].occurrences[0].clone();
633        let mut info = StageInfo {
634            cover: occ,
635            vars: smallvec![var],
636            filters: Default::default(),
637        };
638        for occ in &ctx.vars[var].occurrences[1..] {
639            info.filters
640                .push((occ.clone(), smallvec![ColumnId::new(0); occ.vars.len()]));
641        }
642
643        let next_stage = compile_stage(ctx, state, info);
644        if let Some(prev) = stages.last_mut() {
645            if prev.fuse(&next_stage) {
646                continue;
647            }
648        }
649        stages.push(next_stage);
650    }
651}
652
653/// Compile a stage info into a concrete join stage, updating constraint state.
654fn compile_stage(
655    ctx: &PlanningContext,
656    state: &mut PlanningState,
657    StageInfo {
658        cover,
659        vars,
660        filters,
661    }: StageInfo,
662) -> JoinStage {
663    fn take_atom_constraints_if_new(
664        ctx: &PlanningContext,
665        state: &mut PlanningState,
666        atom: AtomId,
667    ) -> Vec<Constraint> {
668        if state.is_atom_constrained(atom) {
669            Default::default()
670        } else {
671            state.mark_atom_constrained(atom);
672            ctx.atoms[atom].constraints.slow.clone()
673        }
674    }
675
676    if vars.len() == 1 {
677        let scans = SmallVec::<[SingleScanSpec; 3]>::from_iter(
678            iter::once(&cover)
679                .chain(filters.iter().map(|(x, _)| x))
680                .map(|subatom| {
681                    let atom = subatom.atom;
682                    SingleScanSpec {
683                        atom,
684                        column: subatom.vars[0],
685                        cs: take_atom_constraints_if_new(ctx, state, atom),
686                    }
687                }),
688        );
689
690        return JoinStage::Intersect {
691            var: vars[0],
692            scans,
693        };
694    }
695
696    // FusedIntersect case
697    let atom = cover.atom;
698
699    let cover_spec = ScanSpec {
700        to_index: cover,
701        constraints: take_atom_constraints_if_new(ctx, state, atom),
702    };
703
704    let mut bind = SmallVec::new();
705    let var_set = &ctx.atoms[atom].var_to_column;
706    for var in vars {
707        bind.push((var_set[&var], var));
708    }
709
710    let mut to_intersect = Vec::with_capacity(filters.len());
711    for (subatom, key_spec) in filters {
712        let atom = subatom.atom;
713        let scan = ScanSpec {
714            to_index: subatom,
715            constraints: take_atom_constraints_if_new(ctx, state, atom),
716        };
717        to_intersect.push((scan, key_spec));
718    }
719
720    JoinStage::FusedIntersect {
721        cover: cover_spec,
722        bind,
723        to_intersect,
724    }
725}