1use std::{iter::once, sync::Arc};
4
5use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
6use smallvec::SmallVec;
7use thiserror::Error;
8
9use crate::{
10 BaseValueId, CounterId, ExternalFunctionId, PoolSet,
11 action::{Instr, QueryEntry, WriteVal},
12 common::HashMap,
13 free_join::{
14 ActionId, AtomId, Database, ProcessedConstraints, SubAtom, TableId, TableInfo, VarInfo,
15 Variable,
16 plan::{JoinHeader, JoinStages, Plan, PlanStrategy},
17 },
18 pool::{Pooled, with_pool_set},
19 table_spec::{ColumnId, Constraint},
20};
21
// `RuleId` is the key type used to index the per-rule entries of `RuleSet::plans`.
define_id!(pub RuleId, u32, "An identifier for a rule in a rule set");
23
/// Display names for the atoms and variables of a single rule.
///
/// Only atoms whose table has a name, and variables that were created with a
/// name, get an entry; everything else is simply absent from the maps.
// NOTE(review): `dead_code` is presumably allowed because these maps are not
// read everywhere they are built -- confirm and document the reason.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SymbolMap {
    /// Name recorded for each atom, keyed by the atom's id.
    pub atoms: HashMap<AtomId, Arc<str>>,
    /// Name recorded for each named variable.
    pub vars: HashMap<Variable, Arc<str>>,
}
31
/// A standalone copy of one rule taken from a [`RuleSet`].
///
/// Produced by [`RuleSet::build_cached_plan`] and consumed by
/// [`RuleSetBuilder::add_rule_from_cached_plan`], which re-registers the rule
/// (optionally with extra constraints) in another rule set.
pub struct CachedPlan {
    /// The join plan cloned from the source rule.
    plan: Plan,
    /// The rule's description string.
    desc: Arc<str>,
    /// Names of the rule's atoms and variables.
    symbol_map: SymbolMap,
    /// A copy of the action information the rule runs.
    actions: ActionInfo,
}
39
/// The right-hand side of a rule: its instructions plus the variables they
/// read from the query.
#[derive(Debug, Clone)]
pub(crate) struct ActionInfo {
    /// Variables that the instructions use but do not themselves define.
    pub(crate) used_vars: SmallVec<[Variable; 4]>,
    /// The instruction sequence, shared behind an `Arc`.
    pub(crate) instrs: Arc<Pooled<Vec<Instr>>>,
}
45
/// A collection of planned rules together with the actions they run.
#[derive(Default)]
pub struct RuleSet {
    /// One entry per rule: its plan, its description, its symbol map, and the
    /// id of its entry in `actions`.
    pub(crate) plans: IdVec<RuleId, (Plan, Arc<str> , SymbolMap, ActionId)>,
    /// Action information, keyed by the `ActionId` stored alongside each plan.
    pub(crate) actions: DenseIdMap<ActionId, ActionInfo>,
}
61
62impl RuleSet {
63 pub fn build_cached_plan(&self, rule_id: RuleId) -> CachedPlan {
64 let (plan, desc, symbol_map, action_id) = self.plans.get(rule_id).expect("rule must exist");
65 let actions = self
66 .actions
67 .get(*action_id)
68 .expect("action must exist")
69 .clone();
70 CachedPlan {
71 plan: plan.clone(),
72 desc: desc.clone(),
73 symbol_map: symbol_map.clone(),
74 actions,
75 }
76 }
77}
78
/// Builder that accumulates rules into a [`RuleSet`] while holding mutable
/// access to the [`Database`] the rules are planned against.
pub struct RuleSetBuilder<'outer> {
    /// The rule set under construction; returned by `build`.
    rule_set: RuleSet,
    /// The database used to process constraints and plan queries.
    db: &'outer mut Database,
}
95
96impl<'outer> RuleSetBuilder<'outer> {
97 pub fn new(db: &'outer mut Database) -> Self {
98 Self {
99 rule_set: Default::default(),
100 db,
101 }
102 }
103
    /// Forwards to the database's `estimate_size` for `table`, passing the
    /// optional constraint `c` through unchanged.
    pub fn estimate_size(&self, table: TableId, c: Option<Constraint>) -> usize {
        self.db.estimate_size(table, c)
    }
111
112 pub fn new_rule<'a>(&'a mut self) -> QueryBuilder<'outer, 'a> {
114 let instrs = with_pool_set(PoolSet::get);
115 QueryBuilder {
116 rsb: self,
117 instrs,
118 query: Query {
119 var_info: Default::default(),
120 atoms: Default::default(),
121 action: ActionId::new(u32::MAX),
123 plan_strategy: Default::default(),
124 },
125 }
126 }
127
128 pub fn add_rule_from_cached_plan(
129 &mut self,
130 cached: &CachedPlan,
131 extra_constraints: &[(AtomId, Constraint)],
132 ) -> RuleId {
133 let action_id = self.rule_set.actions.push(cached.actions.clone());
135 let mut plan = Plan {
136 atoms: cached.plan.atoms.clone(),
137 stages: JoinStages {
138 header: Default::default(),
139 instrs: cached.plan.stages.instrs.clone(),
140 actions: action_id,
141 },
142 };
143
144 for (atom_id, constraint) in extra_constraints {
146 let atom_info = plan.atoms.get(*atom_id).expect("atom must exist in plan");
147 let table = atom_info.table;
148 let processed = self
149 .db
150 .process_constraints(table, std::slice::from_ref(constraint));
151 if !processed.slow.is_empty() {
152 panic!(
153 "Cached plans only support constraints with a fast pushdown. Got: {constraint:?} for table {table:?}",
154 );
155 }
156 plan.stages.header.push(JoinHeader {
157 atom: *atom_id,
158 constraints: processed.fast,
159 subset: processed.subset,
160 });
161 }
162
163 for JoinHeader {
166 atom, constraints, ..
167 } in &cached.plan.stages.header
168 {
169 let atom_info = plan.atoms.get(*atom).expect("atom must exist in plan");
170 let table = atom_info.table;
171 let processed = self.db.process_constraints(table, constraints);
172 if !processed.slow.is_empty() {
173 panic!(
174 "Cached plans only support constraints with a fast pushdown. Got: {constraints:?} for table {table:?}",
175 );
176 }
177 plan.stages.header.push(JoinHeader {
178 atom: *atom,
179 constraints: processed.fast,
180 subset: processed.subset,
181 });
182 }
183
184 self.rule_set.plans.push((
185 plan,
186 cached.desc.clone(),
187 cached.symbol_map.clone(),
188 action_id,
189 ))
190 }
191
192 pub fn build(self) -> RuleSet {
194 self.rule_set
195 }
196}
197
/// Builder for the query (left-hand side) of a rule.
///
/// Atoms and variables are added here; calling `build` moves on to a
/// [`RuleBuilder`], which adds the rule's instructions.
pub struct QueryBuilder<'outer, 'a> {
    /// The rule-set builder this rule will be added to.
    rsb: &'a mut RuleSetBuilder<'outer>,
    /// The query assembled so far (variables, atoms, plan strategy).
    query: Query,
    /// Instruction buffer filled in by the `RuleBuilder` methods.
    instrs: Pooled<Vec<Instr>>,
}
207
208impl<'outer, 'a> QueryBuilder<'outer, 'a> {
    /// Finishes the query and wraps this builder in a [`RuleBuilder`], through
    /// which the rule's instructions are added.
    pub fn build(self) -> RuleBuilder<'outer, 'a> {
        RuleBuilder { qb: self }
    }
213
    /// Sets the plan strategy stored on the query, replacing the default.
    pub fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
        self.query.plan_strategy = strategy;
    }
218
219 pub fn new_var(&mut self) -> Variable {
221 self.query.var_info.push(VarInfo {
222 occurrences: Default::default(),
223 used_in_rhs: false,
224 defined_in_rhs: false,
225 name: None,
226 })
227 }
228
229 pub fn new_var_named(&mut self, name: &str) -> Variable {
230 self.query.var_info.push(VarInfo {
231 occurrences: Default::default(),
232 used_in_rhs: false,
233 defined_in_rhs: false,
234 name: Some(name.into()),
235 })
236 }
237
238 fn mark_used<'b>(&mut self, entries: impl IntoIterator<Item = &'b QueryEntry>) {
239 for entry in entries {
240 if let QueryEntry::Var(v) = entry {
241 self.query.var_info[*v].used_in_rhs = true;
242 }
243 }
244 }
245
246 fn mark_defined(&mut self, entry: &QueryEntry) {
247 if let QueryEntry::Var(v) = entry {
249 self.query.var_info[*v].defined_in_rhs = true;
250 }
251 }
252
253 pub fn add_atom<'b>(
266 &mut self,
267 table_id: TableId,
268 vars: &[QueryEntry],
269 cs: impl IntoIterator<Item = &'b Constraint>,
270 ) -> Result<AtomId, QueryError> {
271 let info = &self.rsb.db.tables[table_id];
272 let arity = info.spec.arity();
273 let check_constraint = |c: &Constraint| {
274 let process_col = |col: &ColumnId| -> Result<(), QueryError> {
275 if col.index() >= arity {
276 Err(QueryError::InvalidConstraint {
277 constraint: c.clone(),
278 column: col.index(),
279 table: table_id,
280 arity,
281 })
282 } else {
283 Ok(())
284 }
285 };
286 match c {
287 Constraint::Eq { l_col, r_col } => {
288 process_col(l_col)?;
289 process_col(r_col)
290 }
291 Constraint::EqConst { col, .. }
292 | Constraint::LtConst { col, .. }
293 | Constraint::GtConst { col, .. }
294 | Constraint::LeConst { col, .. }
295 | Constraint::GeConst { col, .. } => process_col(col),
296 }
297 };
298 if arity != vars.len() {
299 return Err(QueryError::BadArity {
300 table: table_id,
301 expected: arity,
302 got: vars.len(),
303 });
304 }
305 let cs = Vec::from_iter(
306 cs.into_iter()
307 .cloned()
308 .chain(vars.iter().enumerate().filter_map(|(i, qe)| match qe {
309 QueryEntry::Var(_) => None,
310 QueryEntry::Const(c) => Some(Constraint::EqConst {
311 col: ColumnId::from_usize(i),
312 val: *c,
313 }),
314 })),
315 );
316 cs.iter().try_fold((), |_, c| check_constraint(c))?;
317 let processed = self.rsb.db.process_constraints(table_id, &cs);
318 let mut atom = Atom {
319 table: table_id,
320 var_to_column: Default::default(),
321 column_to_var: Default::default(),
322 constraints: processed,
323 };
324 let next_atom = AtomId::from_usize(self.query.atoms.n_ids());
325 let mut subatoms = HashMap::<Variable, SubAtom>::default();
326 for (i, qe) in vars.iter().enumerate() {
327 let var = match qe {
328 QueryEntry::Var(var) => *var,
329 QueryEntry::Const(_) => {
330 continue;
331 }
332 };
333 if var == Variable::placeholder() {
334 continue;
335 }
336 let col = ColumnId::from_usize(i);
337 if let Some(prev) = atom.var_to_column.insert(var, col) {
338 atom.constraints.slow.push(Constraint::Eq {
339 l_col: col,
340 r_col: prev,
341 })
342 };
343 atom.column_to_var.insert(col, var);
344 subatoms
345 .entry(var)
346 .or_insert_with(|| SubAtom::new(next_atom))
347 .vars
348 .push(col);
349 }
350 for (var, subatom) in subatoms {
351 self.query
352 .var_info
353 .get_mut(var)
354 .expect("all variables must be bound in current query")
355 .occurrences
356 .push(subatom);
357 }
358 Ok(self.query.atoms.push(atom))
359 }
360}
361
/// Errors reported while building queries and rules.
#[derive(Debug, Error)]
pub enum QueryError {
    /// The number of key entries supplied differs from the table's `n_keys`.
    #[error("table {table:?} has {expected:?} keys but got {got:?}")]
    KeyArityMismatch {
        table: TableId,
        expected: usize,
        got: usize,
    },
    /// The values supplied for a row do not fit the table's arity.
    #[error("table {table:?} has {expected:?} columns but got {got:?}")]
    TableArityMismatch {
        table: TableId,
        expected: usize,
        got: usize,
    },

    /// A counter was used in a column declared as a base value.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant -- confirm where it is raised.
    #[error(
        "counter used in column {column_id:?} of table {table:?}, which is declared as a base value"
    )]
    CounterUsedInBaseColumn {
        table: TableId,
        column_id: ColumnId,
        base: BaseValueId,
    },

    /// The two slices passed to a multi-value comparison have different lengths.
    #[error("attempt to compare two groups of values, one of length {l}, another of length {r}")]
    MultiComparisonMismatch { l: usize, r: usize },

    /// An atom was given a number of entries different from the table's arity.
    #[error("table {table:?} expected {expected:?} columns but got {got:?}")]
    BadArity {
        table: TableId,
        expected: usize,
        got: usize,
    },

    /// A schema has an unexpected number of columns.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant -- confirm where it is raised.
    #[error("expected {expected:?} columns in schema but got {got:?}")]
    InvalidSchema { expected: usize, got: usize },

    /// A constraint references a column index outside the table's arity.
    #[error(
        "constraint {constraint:?} on table {table:?} references column {column:?}, but the table has arity {arity:?}"
    )]
    InvalidConstraint {
        constraint: Constraint,
        column: usize,
        table: TableId,
        arity: usize,
    },
}
409
/// Builder for the right-hand side of a rule.
///
/// Wraps the finished [`QueryBuilder`]; its methods append instructions to the
/// rule and keep track of which variables those instructions use and define.
pub struct RuleBuilder<'outer, 'a> {
    /// The query builder holding the query, instruction buffer and rule-set builder.
    qb: QueryBuilder<'outer, 'a>,
}
416
417impl RuleBuilder<'_, '_> {
    /// Looks up the database's `TableInfo` for `table`.
    fn table_info(&self, table: TableId) -> &TableInfo {
        self.qb.rsb.db.get_table_info(table)
    }
421
    /// Finishes the rule with an empty description; see `build_with_description`.
    pub fn build(self) -> RuleId {
        self.build_with_description("")
    }
426
427 fn build_symbol_map(&self) -> SymbolMap {
428 let var_info = &self.qb.query.var_info;
429 SymbolMap {
430 atoms: self
431 .qb
432 .query
433 .atoms
434 .iter()
435 .filter_map(|(id, atom)| {
436 let name = self.table_info(atom.table).name.clone();
437 name.map(|name| (id, name))
438 })
439 .collect(),
440 vars: var_info
441 .iter()
442 .filter_map(|(id, info)| info.name.as_ref().map(|name| (id, name.clone())))
443 .collect(),
444 }
445 }
446
447 pub fn build_with_description(mut self, desc: impl Into<String>) -> RuleId {
448 let var_info = &self.qb.query.var_info;
449 let symbol_map = self.build_symbol_map();
450 let used_vars = SmallVec::from_iter(var_info.iter().filter_map(|(v, info)| {
452 if info.used_in_rhs && !info.defined_in_rhs {
453 Some(v)
454 } else {
455 None
456 }
457 }));
458 let action_id = self.qb.rsb.rule_set.actions.push(ActionInfo {
459 instrs: Arc::new(self.qb.instrs),
460 used_vars,
461 });
462 self.qb.query.action = action_id;
463 let plan = self.qb.rsb.db.plan_query(self.qb.query);
465 let desc: String = desc.into();
466 self.qb
468 .rsb
469 .rule_set
470 .plans
471 .push((plan, desc.into(), symbol_map, action_id))
472 }
473
474 pub fn read_counter(&mut self, counter: CounterId) -> Variable {
476 let dst = self.qb.new_var();
477 self.qb.instrs.push(Instr::ReadCounter { counter, dst });
478 self.qb.mark_defined(&dst.into());
479 dst
480 }
481
482 pub fn lookup_or_insert(
489 &mut self,
490 table: TableId,
491 args: &[QueryEntry],
492 default_vals: &[WriteVal],
493 dst_col: ColumnId,
494 ) -> Result<Variable, QueryError> {
495 let table_info = self.table_info(table);
496 self.validate_keys(table, table_info, args)?;
497 self.validate_vals(table, table_info, default_vals.iter())?;
498 let res = self.qb.new_var();
499 self.qb.instrs.push(Instr::LookupOrInsertDefault {
500 table,
501 args: args.to_vec(),
502 default: default_vals.to_vec(),
503 dst_col,
504 dst_var: res,
505 });
506 self.qb.mark_used(args);
507 self.qb
508 .mark_used(default_vals.iter().filter_map(|x| match x {
509 WriteVal::QueryEntry(qe) => Some(qe),
510 WriteVal::IncCounter(_) | WriteVal::CurrentVal(_) => None,
511 }));
512 self.qb.mark_defined(&res.into());
513 Ok(res)
514 }
515
516 pub fn lookup_with_default(
523 &mut self,
524 table: TableId,
525 args: &[QueryEntry],
526 default: QueryEntry,
527 dst_col: ColumnId,
528 ) -> Result<Variable, QueryError> {
529 let table_info = self.table_info(table);
530 self.validate_keys(table, table_info, args)?;
531 let res = self.qb.new_var();
532 self.qb.instrs.push(Instr::LookupWithDefault {
533 table,
534 args: args.to_vec(),
535 dst_col,
536 dst_var: res,
537 default,
538 });
539 self.qb.mark_used(args);
540 self.qb.mark_used(&[default]);
541 self.qb.mark_defined(&res.into());
542 Ok(res)
543 }
544
545 pub fn lookup(
552 &mut self,
553 table: TableId,
554 args: &[QueryEntry],
555 dst_col: ColumnId,
556 ) -> Result<Variable, QueryError> {
557 let table_info = self.table_info(table);
558 self.validate_keys(table, table_info, args)?;
559 let res = self.qb.new_var();
560 self.qb.instrs.push(Instr::Lookup {
561 table,
562 args: args.to_vec(),
563 dst_col,
564 dst_var: res,
565 });
566 self.qb.mark_used(args);
567 self.qb.mark_defined(&res.into());
568 Ok(res)
569 }
570
571 pub fn insert(&mut self, table: TableId, vals: &[QueryEntry]) -> Result<(), QueryError> {
573 let table_info = self.table_info(table);
574 self.validate_row(table, table_info, vals)?;
575 self.qb.instrs.push(Instr::Insert {
576 table,
577 vals: vals.to_vec(),
578 });
579 self.qb.mark_used(vals);
580 Ok(())
581 }
582
583 pub fn insert_if_eq(
585 &mut self,
586 table: TableId,
587 l: QueryEntry,
588 r: QueryEntry,
589 vals: &[QueryEntry],
590 ) -> Result<(), QueryError> {
591 let table_info = self.table_info(table);
592 self.validate_row(table, table_info, vals)?;
593 self.qb.instrs.push(Instr::InsertIfEq {
594 table,
595 l,
596 r,
597 vals: vals.to_vec(),
598 });
599 self.qb
600 .mark_used(vals.iter().chain(once(&l)).chain(once(&r)));
601 Ok(())
602 }
603
604 pub fn remove(&mut self, table: TableId, args: &[QueryEntry]) -> Result<(), QueryError> {
606 let table_info = self.table_info(table);
607 self.validate_keys(table, table_info, args)?;
608 self.qb.instrs.push(Instr::Remove {
609 table,
610 args: args.to_vec(),
611 });
612 self.qb.mark_used(args);
613 Ok(())
614 }
615
616 pub fn call_external(
618 &mut self,
619 func: ExternalFunctionId,
620 args: &[QueryEntry],
621 ) -> Result<Variable, QueryError> {
622 let res = self.qb.new_var();
623 self.qb.instrs.push(Instr::External {
624 func,
625 args: args.to_vec(),
626 dst: res,
627 });
628 self.qb.mark_used(args);
629 self.qb.mark_defined(&res.into());
630 Ok(res)
631 }
632
633 pub fn lookup_with_fallback(
637 &mut self,
638 table: TableId,
639 key: &[QueryEntry],
640 dst_col: ColumnId,
641 func: ExternalFunctionId,
642 func_args: &[QueryEntry],
643 ) -> Result<Variable, QueryError> {
644 let table_info = self.table_info(table);
645 self.validate_keys(table, table_info, key)?;
646 let res = self.qb.new_var();
647 self.qb.instrs.push(Instr::LookupWithFallback {
648 table,
649 table_key: key.to_vec(),
650 func,
651 func_args: func_args.to_vec(),
652 dst_var: res,
653 dst_col,
654 });
655 self.qb.mark_used(key);
656 self.qb.mark_used(func_args);
657 self.qb.mark_defined(&res.into());
658 Ok(res)
659 }
660
661 pub fn call_external_with_fallback(
662 &mut self,
663 f1: ExternalFunctionId,
664 args1: &[QueryEntry],
665 f2: ExternalFunctionId,
666 args2: &[QueryEntry],
667 ) -> Result<Variable, QueryError> {
668 let res = self.qb.new_var();
669 self.qb.instrs.push(Instr::ExternalWithFallback {
670 f1,
671 args1: args1.to_vec(),
672 f2,
673 args2: args2.to_vec(),
674 dst: res,
675 });
676 self.qb.mark_used(args1);
677 self.qb.mark_used(args2);
678 self.qb.mark_defined(&res.into());
679 Ok(res)
680 }
681
682 pub fn assert_eq(&mut self, l: QueryEntry, r: QueryEntry) {
684 self.qb.instrs.push(Instr::AssertEq(l, r));
685 self.qb.mark_used(&[l, r]);
686 }
687
688 pub fn assert_ne(&mut self, l: QueryEntry, r: QueryEntry) -> Result<(), QueryError> {
690 self.qb.instrs.push(Instr::AssertNe(l, r));
691 self.qb.mark_used(&[l, r]);
692 Ok(())
693 }
694
695 pub fn assert_any_ne(&mut self, l: &[QueryEntry], r: &[QueryEntry]) -> Result<(), QueryError> {
699 if l.len() != r.len() {
700 return Err(QueryError::MultiComparisonMismatch {
701 l: l.len(),
702 r: r.len(),
703 });
704 }
705
706 let mut ops = Vec::with_capacity(l.len() + r.len());
707 ops.extend_from_slice(l);
708 ops.extend_from_slice(r);
709 self.qb.instrs.push(Instr::AssertAnyNe {
710 ops,
711 divider: l.len(),
712 });
713 self.qb.mark_used(l);
714 self.qb.mark_used(r);
715 Ok(())
716 }
717
718 fn validate_row(
719 &self,
720 table: TableId,
721 info: &TableInfo,
722 vals: &[QueryEntry],
723 ) -> Result<(), QueryError> {
724 if vals.len() != info.spec.arity() {
725 Err(QueryError::TableArityMismatch {
726 table,
727 expected: info.spec.arity(),
728 got: vals.len(),
729 })
730 } else {
731 Ok(())
732 }
733 }
734
735 fn validate_keys(
736 &self,
737 table: TableId,
738 info: &TableInfo,
739 keys: &[QueryEntry],
740 ) -> Result<(), QueryError> {
741 if keys.len() != info.spec.n_keys {
742 Err(QueryError::KeyArityMismatch {
743 table,
744 expected: info.spec.n_keys,
745 got: keys.len(),
746 })
747 } else {
748 Ok(())
749 }
750 }
751
752 fn validate_vals<'b>(
753 &self,
754 table: TableId,
755 info: &TableInfo,
756 vals: impl Iterator<Item = &'b WriteVal>,
757 ) -> Result<(), QueryError> {
758 for (i, _) in vals.enumerate() {
759 let col = i + info.spec.n_keys;
760 if col >= info.spec.arity() {
761 return Err(QueryError::TableArityMismatch {
762 table,
763 expected: info.spec.arity(),
764 got: col,
765 });
766 }
767 }
768 Ok(())
769 }
770}
771
/// One atom of a query: a table together with the variables bound to its
/// columns and the constraints placed on it.
#[derive(Debug, Clone)]
pub(crate) struct Atom {
    /// The table this atom ranges over.
    pub(crate) table: TableId,
    /// For each variable in the atom, a column it is bound to. When a variable
    /// occurs in several columns this holds the last one; the repetition is
    /// recorded as an `Eq` constraint instead.
    pub(crate) var_to_column: HashMap<Variable, ColumnId>,
    /// The variable bound to each column that holds a (non-placeholder) variable.
    pub(crate) column_to_var: DenseIdMap<ColumnId, Variable>,
    /// The processed constraints for this atom.
    pub(crate) constraints: ProcessedConstraints,
}
784
/// The query for a single rule, in the form handed to the database's planner.
pub(crate) struct Query {
    /// Per-variable bookkeeping: occurrences in atoms, right-hand-side usage
    /// flags, and the optional name.
    pub(crate) var_info: DenseIdMap<Variable, VarInfo>,
    /// The atoms of the query.
    pub(crate) atoms: DenseIdMap<AtomId, Atom>,
    /// The action to run for the query; holds a placeholder id until the rule
    /// is built.
    pub(crate) action: ActionId,
    /// The strategy to use when planning the query.
    pub(crate) plan_strategy: PlanStrategy,
}