1use std::{collections::BTreeMap, iter, mem, sync::Arc};
2
3use crate::{
4 numeric_id::{DenseIdMap, NumericId},
5 query::SymbolMap,
6};
7use fixedbitset::FixedBitSet;
8use smallvec::{SmallVec, smallvec};
9
10use crate::{
11 common::{HashMap, HashSet, IndexSet},
12 offsets::Subset,
13 pool::Pooled,
14 query::{Atom, Query},
15 table_spec::Constraint,
16};
17
18use super::{ActionId, AtomId, ColumnId, SubAtom, VarInfo, Variable};
19
20#[derive(Clone, Debug, PartialEq, Eq)]
21pub(crate) struct ScanSpec {
22 pub to_index: SubAtom,
23 pub constraints: Vec<Constraint>,
25}
26
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub(crate) struct SingleScanSpec {
29 pub atom: AtomId,
30 pub column: ColumnId,
31 pub cs: Vec<Constraint>,
32}
33
34pub(crate) struct JoinHeader {
37 pub atom: AtomId,
38 #[allow(unused)]
41 pub constraints: Pooled<Vec<Constraint>>,
42 pub subset: Subset,
49}
50
51impl std::fmt::Debug for JoinHeader {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("JoinHeader")
54 .field("atom", &self.atom)
55 .field("constraints", &self.constraints)
56 .field(
57 "subset",
58 &format_args!("Subset(size={})", self.subset.size()),
59 )
60 .finish()
61 }
62}
63
64impl Clone for JoinHeader {
65 fn clone(&self) -> Self {
66 JoinHeader {
67 atom: self.atom,
68 constraints: Pooled::cloned(&self.constraints),
69 subset: self.subset.clone(),
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
75pub(crate) enum JoinStage {
76 Intersect {
80 var: Variable,
81 scans: SmallVec<[SingleScanSpec; 3]>,
82 },
83 FusedIntersect {
87 cover: ScanSpec,
88 bind: SmallVec<[(ColumnId, Variable); 2]>,
89 to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)>,
91 },
92}
93
94impl JoinStage {
95 fn fuse(&mut self, other: &JoinStage) -> bool {
100 use JoinStage::*;
101 match (self, other) {
102 (
103 FusedIntersect {
104 cover,
105 bind,
106 to_intersect,
107 },
108 Intersect { var, scans },
109 ) if to_intersect.is_empty()
110 && scans.len() == 1
111 && cover.to_index.atom == scans[0].atom
112 && scans[0].cs.is_empty() =>
113 {
114 let col = scans[0].column;
115 bind.push((col, *var));
116 cover.to_index.vars.push(col);
117 true
118 }
119 (
120 x,
121 Intersect {
122 var: var2,
123 scans: scans2,
124 },
125 ) => {
126 let (var1, mut scans1) = if let Intersect {
129 var: var1,
130 scans: scans1,
131 } = x
132 {
133 if !(scans1.len() == 1
134 && scans2.len() == 1
135 && scans1[0].atom == scans2[0].atom
136 && scans2[0].cs.is_empty())
137 {
138 return false;
139 }
140 (*var1, mem::take(scans1))
141 } else {
142 return false;
143 };
144 let atom = scans1[0].atom;
145 let col1 = scans1[0].column;
146 let col2 = scans2[0].column;
147 *x = FusedIntersect {
148 cover: ScanSpec {
149 to_index: SubAtom {
150 atom,
151 vars: smallvec![col1, col2],
152 },
153 constraints: mem::take(&mut scans1[0].cs),
154 },
155 bind: smallvec![(col1, var1), (col2, *var2)],
156 to_intersect: Default::default(),
157 };
158 true
159 }
160 _ => false,
161 }
162 }
163}
164
165#[derive(Debug, Clone)]
166pub(crate) struct Plan {
167 pub atoms: Arc<DenseIdMap<AtomId, Atom>>,
168 pub stages: JoinStages,
169}
170impl Plan {
171 pub(crate) fn to_report(&self, symbol_map: &SymbolMap) -> egglog_reports::Plan {
172 use egglog_reports::{
173 Plan as ReportPlan, Scan as ReportScan, SingleScan as ReportSingleScan,
174 Stage as ReportStage,
175 };
176 const INTERNAL_PREFIX: &str = "@";
177 let get_var = |var: Variable| {
178 symbol_map
179 .vars
180 .get(&var)
181 .map(|s| s.to_string())
182 .unwrap_or_else(|| format!("{INTERNAL_PREFIX}x{var:?}"))
183 };
184 let get_atom = |atom: AtomId| {
185 symbol_map
186 .atoms
187 .get(&atom)
188 .map(|s| s.to_string())
189 .unwrap_or_else(|| format!("{INTERNAL_PREFIX}R{atom:?}"))
190 };
191 let mut stages = Vec::new();
192 for (i, stage) in self.stages.instrs.iter().enumerate() {
193 let report_stage = match stage {
194 JoinStage::Intersect { var, scans } => {
195 let var_name = get_var(*var);
196 let report_scans = scans
197 .iter()
198 .map(|scan| {
199 let atom_name = get_atom(scan.atom);
200 ReportSingleScan(
201 atom_name,
202 (var_name.clone(), scan.column.index() as i64),
203 )
204 })
205 .collect();
206 ReportStage::Intersect {
207 scans: report_scans,
208 }
209 }
210 JoinStage::FusedIntersect {
211 cover,
212 bind: _,
213 to_intersect,
214 } => {
215 let cover_atom_name = get_atom(cover.to_index.atom);
216 let cover_cols: Vec<(String, i64)> = cover
217 .to_index
218 .vars
219 .iter()
220 .map(|col| {
221 let var_name =
222 get_var(self.atoms[cover.to_index.atom].column_to_var[*col]);
223 (var_name, col.index() as i64)
224 })
225 .collect();
226 let report_cover = ReportScan(cover_atom_name, cover_cols);
227 let report_to_intersect = to_intersect
228 .iter()
229 .map(|(scan, key_spec)| {
230 let atom_name = get_atom(scan.to_index.atom);
231 let cols: Vec<(String, i64)> = key_spec
232 .iter()
233 .map(|col| {
234 let var_name =
235 get_var(self.atoms[scan.to_index.atom].column_to_var[*col]);
236 (var_name, col.index() as i64)
237 })
238 .collect();
239 ReportScan(atom_name, cols)
240 })
241 .collect();
242 ReportStage::FusedIntersect {
243 cover: report_cover,
244 to_intersect: report_to_intersect,
245 }
246 }
247 };
248 let next = if i == self.stages.instrs.len() - 1 {
249 vec![]
250 } else {
251 vec![i + 1]
252 };
253 stages.push((report_stage, None, next));
254 }
255 ReportPlan { stages }
256 }
257}
258
259#[derive(Debug, Clone)]
260pub(crate) struct JoinStages {
261 pub header: Vec<JoinHeader>,
262 pub instrs: Arc<Vec<JoinStage>>,
263 pub actions: ActionId,
264}
265
266type VarSet = FixedBitSet;
267type AtomSet = FixedBitSet;
268
/// How `plan_stages` orders the join.
#[derive(Default, Copy, Clone)]
pub enum PlanStrategy {
    /// Free join: atoms are tried as covers in ascending order of their
    /// approximate constraint size.
    PureSize,
    /// Free join: like `PureSize`, but only atoms chosen by the greedy
    /// `BucketQueue` cover computation are tried as covers.
    MinCover,
    /// Generic join: one stage per variable, ordered by the smallest atom the
    /// variable occurs in, with ties going to variables with more occurrences.
    #[default]
    Gj,
}
290
291pub(crate) fn plan_query(query: Query) -> Plan {
292 let atoms = query.atoms;
293 let ctx = PlanningContext {
294 vars: query.var_info,
295 atoms,
296 };
297 let (header, instrs) = plan_stages(&ctx, query.plan_strategy);
298
299 Plan {
300 atoms: Arc::new(ctx.atoms),
301 stages: JoinStages {
302 header,
303 instrs: Arc::new(instrs),
304 actions: query.action,
305 },
306 }
307}
308
309#[derive(Debug)]
316struct StageInfo {
317 cover: SubAtom,
318 vars: SmallVec<[Variable; 1]>,
319 filters: Vec<(
320 SubAtom, SmallVec<[ColumnId; 2]>, )>,
323}
324
325struct PlanningContext {
327 vars: DenseIdMap<Variable, VarInfo>,
328 atoms: DenseIdMap<AtomId, Atom>,
329}
330
331#[derive(Clone)]
333struct PlanningState {
334 used_vars: VarSet,
335 constrained_atoms: AtomSet,
336}
337
338impl PlanningState {
339 fn new(n_vars: usize, n_atoms: usize) -> Self {
340 Self {
341 used_vars: VarSet::with_capacity(n_vars),
342 constrained_atoms: AtomSet::with_capacity(n_atoms),
343 }
344 }
345
346 fn mark_var_used(&mut self, var: Variable) {
347 self.used_vars.insert(var.index());
348 }
349
350 fn is_var_used(&self, var: Variable) -> bool {
351 self.used_vars.contains(var.index())
352 }
353
354 fn mark_atom_constrained(&mut self, atom: AtomId) {
355 self.constrained_atoms.insert(atom.index());
356 }
357
358 fn is_atom_constrained(&self, atom: AtomId) -> bool {
359 self.constrained_atoms.contains(atom.index())
360 }
361}
362
363struct BucketQueue<'a> {
366 var_info: &'a DenseIdMap<Variable, VarInfo>,
367 cover: VarSet,
368 atom_info: DenseIdMap<AtomId, VarSet>,
369 sizes: BTreeMap<usize, IndexSet<AtomId>>,
370}
371
372impl<'a> BucketQueue<'a> {
373 fn new(var_info: &'a DenseIdMap<Variable, VarInfo>, atoms: &DenseIdMap<AtomId, Atom>) -> Self {
374 let cover = VarSet::with_capacity(var_info.n_ids());
375 let mut atom_info = DenseIdMap::with_capacity(atoms.n_ids());
376 let mut sizes = BTreeMap::<usize, IndexSet<AtomId>>::new();
377 for (id, atom) in atoms.iter() {
378 let mut bitset = VarSet::with_capacity(var_info.n_ids());
379 for (_, var) in atom.column_to_var.iter() {
380 bitset.insert(var.index());
381 }
382 sizes.entry(bitset.count_ones(..)).or_default().insert(id);
383 atom_info.insert(id, bitset);
384 }
385 BucketQueue {
386 var_info,
387 cover,
388 atom_info,
389 sizes,
390 }
391 }
392
393 fn pop_min(&mut self) -> Option<AtomId> {
397 let (_, atoms) = self.sizes.iter_mut().next_back()?;
399 let res = atoms.pop().unwrap();
400 let vars = self.atom_info[res].clone();
401 for new_var in vars.difference(&self.cover).map(Variable::from_usize) {
405 for subatom in &self.var_info[new_var].occurrences {
406 let cur_set = &mut self.atom_info[subatom.atom];
407 let old_size = cur_set.count_ones(..);
408 cur_set.difference_with(&vars);
409 let new_size = cur_set.count_ones(..);
410 if old_size == new_size {
411 continue;
412 }
413 if let Some(old_size_set) = self.sizes.get_mut(&old_size) {
414 old_size_set.swap_remove(&subatom.atom);
415 if old_size_set.is_empty() {
416 self.sizes.remove(&old_size);
417 }
418 }
419 if new_size > 0 {
420 self.sizes.entry(new_size).or_default().insert(subatom.atom);
421 }
422 }
423 }
424 self.cover.union_with(&vars);
425 Some(res)
426 }
427}
428
429fn plan_headers(
432 ctx: &PlanningContext,
433) -> (
434 Vec<JoinHeader>,
435 DenseIdMap<
436 AtomId,
437 (
438 usize, &Pooled<Vec<Constraint>>,
440 ),
441 >,
442) {
443 let mut header = Vec::new();
444 let mut remaining_constraints: DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)> =
445 Default::default();
446
447 for (atom, atom_info) in ctx.atoms.iter() {
448 remaining_constraints.insert(
449 atom,
450 (
451 atom_info.constraints.approx_size(),
452 &atom_info.constraints.slow,
453 ),
454 );
455 if !atom_info.constraints.fast.is_empty() {
456 header.push(JoinHeader {
457 atom,
458 constraints: Pooled::cloned(&atom_info.constraints.fast),
459 subset: atom_info.constraints.subset.clone(),
460 });
461 }
462 }
463
464 (header, remaining_constraints)
465}
466
467fn plan_stages(ctx: &PlanningContext, strat: PlanStrategy) -> (Vec<JoinHeader>, Vec<JoinStage>) {
470 let (header, remaining_constraints) = plan_headers(ctx);
471 let mut instrs = Vec::new();
472 let mut state = PlanningState::new(ctx.vars.n_ids(), ctx.atoms.n_ids());
473
474 match strat {
475 PlanStrategy::PureSize | PlanStrategy::MinCover => {
476 plan_free_join(ctx, &mut state, strat, &remaining_constraints, &mut instrs)
477 }
478 PlanStrategy::Gj => plan_gj(ctx, &mut state, &remaining_constraints, &mut instrs),
479 };
480
481 (header, instrs)
482}
483
484fn plan_free_join(
486 ctx: &PlanningContext,
487 state: &mut PlanningState,
488 strat: PlanStrategy,
489 remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
490 stages: &mut Vec<JoinStage>,
491) {
492 let mut size_info = Vec::<(AtomId, usize)>::new();
493
494 match strat {
495 PlanStrategy::PureSize => {
496 for (atom, (size, _)) in remaining_constraints.iter() {
497 size_info.push((atom, *size));
498 }
499 }
500 PlanStrategy::MinCover => {
501 let mut eligible_covers = HashSet::default();
502 let mut queue = BucketQueue::new(&ctx.vars, &ctx.atoms);
503 while let Some(atom) = queue.pop_min() {
504 eligible_covers.insert(atom);
505 }
506 for (atom, (size, _)) in remaining_constraints
507 .iter()
508 .filter(|(atom, _)| eligible_covers.contains(atom))
509 {
510 size_info.push((atom, *size));
511 }
512 }
513 PlanStrategy::Gj => unreachable!(),
514 };
515
516 size_info.sort_by_key(|(_, size)| *size);
517 let mut atoms = size_info.iter().map(|(atom, _)| *atom);
518
519 while let Some(info) = get_next_freejoin_stage(ctx, state, &mut atoms) {
520 let stage = compile_stage(ctx, state, info);
521 stages.push(stage);
522 }
523}
524
/// Builds the next free-join stage from `ordering`, or returns `None` once
/// `ordering` is exhausted.
///
/// Atoms are pulled from `ordering` in turn. An atom all of whose variables
/// are already marked used in `state` is skipped. The first atom with at least
/// one unused variable becomes the stage's cover. Each of its unused variables
/// is marked used and bound by the stage, and a variable repeated within the
/// atom is bound only once. Every *other* atom in which a newly bound variable
/// occurs becomes a filter, keyed by the position in `vars` of the variable
/// held by each of its columns.
fn get_next_freejoin_stage(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    ordering: &mut impl Iterator<Item = AtomId>,
) -> Option<StageInfo> {
    // Columns of non-cover atoms that mention a newly bound variable, grouped
    // by atom. It is only filled when an atom is accepted as the cover, so it
    // is still empty whenever an atom is skipped below.
    let mut scratch_subatom: HashMap<AtomId, SmallVec<[ColumnId; 2]>> = Default::default();

    loop {
        let mut covered = false;
        let atom = ordering.next()?;
        let atom_info = &ctx.atoms[atom];
        let mut cover = SubAtom::new(atom);
        let mut vars = SmallVec::<[Variable; 1]>::new();

        for (ix, var) in atom_info.column_to_var.iter() {
            // Marking happens inside this loop, so a later column holding the
            // same variable is skipped here too.
            if state.is_var_used(*var) {
                continue;
            }
            covered = true;
            state.mark_var_used(*var);
            vars.push(*var);
            cover.vars.push(ix);

            for subatom in ctx.vars[*var].occurrences.iter() {
                if subatom.atom == atom {
                    continue;
                }
                scratch_subatom
                    .entry(subatom.atom)
                    .or_default()
                    .extend(subatom.vars.iter().copied());
            }
        }

        // Nothing new to bind from this atom; try the next one.
        if !covered {
            continue;
        }

        let mut filters = Vec::new();
        for (atom, cols) in scratch_subatom.drain() {
            let mut form_key = SmallVec::<[ColumnId; 2]>::new();
            for var_ix in &cols {
                let var = ctx.atoms[atom].column_to_var[*var_ix];
                // Assumes every collected column holds one of the variables
                // bound above (it came from that variable's occurrences);
                // the unwrap relies on this.
                let cover_col = vars.iter().position(|v| *v == var).unwrap();
                form_key.push(ColumnId::from_usize(cover_col));
            }
            filters.push((SubAtom { atom, vars: cols }, form_key));
        }

        return Some(StageInfo {
            cover,
            vars,
            filters,
        });
    }
}
586
/// Plans a generic join: one stage per planned variable, fusing adjacent
/// stages when `JoinStage::fuse` allows it.
///
/// A variable is planned if it has more than one occurrence or is
/// `used_in_rhs`. A variable with a single occurrence that is not used in the
/// RHS is planned only if no variable from the first pass touched its atom, so
/// every atom is still scanned. Variables are ordered by the smallest
/// approximate size among the atoms they occur in, with ties going to
/// variables with more occurrences. Each stage covers the variable's first
/// occurrence and filters on the rest.
fn plan_gj(
    ctx: &PlanningContext,
    state: &mut PlanningState,
    remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
    stages: &mut Vec<JoinStage>,
) {
    let mut min_sizes = Vec::with_capacity(ctx.vars.n_ids());
    // Atoms touched by at least one variable planned in the first pass.
    let mut atoms_hit = AtomSet::with_capacity(ctx.atoms.n_ids());
    for (var, var_info) in ctx.vars.iter() {
        let n_occs = var_info.occurrences.len();
        if n_occs == 1 && !var_info.used_in_rhs {
            continue;
        }
        // `min()` drives the map over every occurrence, so each atom
        // mentioning this variable is marked as hit.
        if let Some(min_size) = var_info
            .occurrences
            .iter()
            .map(|subatom| {
                atoms_hit.set(subatom.atom.index(), true);
                remaining_constraints[subatom.atom].0
            })
            .min()
        {
            min_sizes.push((var, min_size, n_occs));
        }
    }
    // Second pass: make sure atoms not yet hit are still scanned.
    // NOTE(review): `atoms_hit` is not updated here, so every such variable in
    // the same untouched atom gets its own stage. Confirm whether one per atom
    // is intended.
    for (var, var_info) in ctx.vars.iter() {
        if var_info.occurrences.len() == 1 && !var_info.used_in_rhs {
            let atom = var_info.occurrences[0].atom;
            if !atoms_hit.contains(atom.index()) {
                min_sizes.push((var, remaining_constraints[atom].0, 1));
            }
        }
    }
    // Smallest atom first; among equal sizes, more occurrences first.
    min_sizes.sort_by_key(|(_, size, occs)| (*size, -(*occs as i64)));
    for (var, _, _) in min_sizes {
        // Every planned variable has at least one occurrence (both passes
        // above guarantee it), so indexing [0] is safe.
        let occ = ctx.vars[var].occurrences[0].clone();
        let mut info = StageInfo {
            cover: occ,
            vars: smallvec![var],
            filters: Default::default(),
        };
        // The stage binds a single variable, so every key column refers to
        // position 0 of the cover's bound variables.
        for occ in &ctx.vars[var].occurrences[1..] {
            info.filters
                .push((occ.clone(), smallvec![ColumnId::new(0); occ.vars.len()]));
        }

        let next_stage = compile_stage(ctx, state, info);
        if let Some(prev) = stages.last_mut() {
            if prev.fuse(&next_stage) {
                continue;
            }
        }
        stages.push(next_stage);
    }
}
652
653fn compile_stage(
655 ctx: &PlanningContext,
656 state: &mut PlanningState,
657 StageInfo {
658 cover,
659 vars,
660 filters,
661 }: StageInfo,
662) -> JoinStage {
663 fn take_atom_constraints_if_new(
664 ctx: &PlanningContext,
665 state: &mut PlanningState,
666 atom: AtomId,
667 ) -> Vec<Constraint> {
668 if state.is_atom_constrained(atom) {
669 Default::default()
670 } else {
671 state.mark_atom_constrained(atom);
672 ctx.atoms[atom].constraints.slow.clone()
673 }
674 }
675
676 if vars.len() == 1 {
677 let scans = SmallVec::<[SingleScanSpec; 3]>::from_iter(
678 iter::once(&cover)
679 .chain(filters.iter().map(|(x, _)| x))
680 .map(|subatom| {
681 let atom = subatom.atom;
682 SingleScanSpec {
683 atom,
684 column: subatom.vars[0],
685 cs: take_atom_constraints_if_new(ctx, state, atom),
686 }
687 }),
688 );
689
690 return JoinStage::Intersect {
691 var: vars[0],
692 scans,
693 };
694 }
695
696 let atom = cover.atom;
698
699 let cover_spec = ScanSpec {
700 to_index: cover,
701 constraints: take_atom_constraints_if_new(ctx, state, atom),
702 };
703
704 let mut bind = SmallVec::new();
705 let var_set = &ctx.atoms[atom].var_to_column;
706 for var in vars {
707 bind.push((var_set[&var], var));
708 }
709
710 let mut to_intersect = Vec::with_capacity(filters.len());
711 for (subatom, key_spec) in filters {
712 let atom = subatom.atom;
713 let scan = ScanSpec {
714 to_index: subatom,
715 constraints: take_atom_constraints_if_new(ctx, state, atom),
716 };
717 to_intersect.push((scan, key_spec));
718 }
719
720 JoinStage::FusedIntersect {
721 cover: cover_spec,
722 bind,
723 to_intersect,
724 }
725}