egglog/
scheduler.rs

1use std::sync::Arc;
2use std::sync::Mutex;
3
4use core_relations::{ExecutionState, ExternalFunction, Value};
5use egglog_bridge::{
6    ColumnTy, DefaultVal, FunctionConfig, FunctionId, MergeFn, RuleId, TableAction,
7};
8use egglog_reports::RunReport;
9use numeric_id::define_id;
10
11use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *};
12
13/// A scheduler decides which matches to be applied for a rule.
14///
15/// The matches that are not chosen in this iteration will be delayed
16/// to the next iteration.
17pub trait Scheduler: dyn_clone::DynClone + Send + Sync {
18    /// Whether or not the rules can be considered as saturated (i.e.,
19    /// `run_report.updated == false`).
20    ///
21    /// This is only called when the runner is otherwise saturated.
22    /// Default implementation just returns `true`.
23    fn can_stop(&mut self, rules: &[&str], ruleset: &str) -> bool {
24        let _ = (rules, ruleset);
25        true
26    }
27
28    /// Filter the matches for a rule.
29    ///
30    /// Return `true` if the scheduler's next run of the rule should feed
31    /// `filter_matches` with a new iteration of matches.
32    fn filter_matches(&mut self, rule: &str, ruleset: &str, matches: &mut Matches) -> bool;
33}
34
35dyn_clone::clone_trait_object!(Scheduler);
36
/// A collection of matches produced by a rule.
/// The user can choose which matches are to be fired.
pub struct Matches {
    /// All matches, flattened: every match occupies `vars.len()`
    /// consecutive values.
    matches: Vec<Value>,
    /// Indices of the matches picked through `choose`.
    chosen: Vec<usize>,
    /// The variables of a match; the `i`-th value of a match belongs to the
    /// `i`-th variable.
    vars: Vec<ResolvedVar>,
    /// Set by `choose_all`. When set, every match is fired and `chosen` is
    /// ignored.
    all_chosen: bool,
}
45
46/// A match is a tuple of values corresponding to the variables in a rule.
47/// It allows you to retrieve the value corresponding to a variable in the match.
48pub struct Match<'a> {
49    values: &'a [Value],
50    vars: &'a [ResolvedVar],
51}
52
53impl Match<'_> {
54    /// Get the value corresponding a variable in this match.
55    pub fn get_value(&self, var: &str) -> Value {
56        let idx = self.vars.iter().position(|v| v.name == var).unwrap();
57        self.values[idx]
58    }
59}
60
61impl Matches {
62    fn new(matches: Vec<Value>, vars: Vec<ResolvedVar>) -> Self {
63        let total_len = matches.len();
64        let tuple_len = vars.len();
65        assert!(total_len % tuple_len == 0);
66        Self {
67            matches,
68            vars,
69            chosen: Vec::new(),
70            all_chosen: false,
71        }
72    }
73
74    /// The number of matches in total.
75    pub fn match_size(&self) -> usize {
76        self.matches.len() / self.vars.len()
77    }
78
79    /// The length of a tuple.
80    pub fn tuple_len(&self) -> usize {
81        self.vars.len()
82    }
83
84    /// Get `idx`-th match.
85    pub fn get_match(&self, idx: usize) -> Match<'_> {
86        Match {
87            values: &self.matches[idx * self.tuple_len()..(idx + 1) * self.tuple_len()],
88            vars: &self.vars,
89        }
90    }
91
92    /// Pick the match at `idx` to be fired.
93    pub fn choose(&mut self, idx: usize) {
94        self.chosen.push(idx);
95    }
96
97    /// Pick all matches to be fired.
98    ///
99    /// This is more efficient than calling `choose` for each match.
100    pub fn choose_all(&mut self) {
101        self.all_chosen = true;
102    }
103
104    /// Apply the chosen matches and return the residual matches.
105    fn instantiate(
106        mut self,
107        state: &mut ExecutionState<'_>,
108        mut table_action: TableAction,
109    ) -> Vec<Value> {
110        let tuple_len = self.tuple_len();
111        let unit = state.base_values().get(());
112
113        if self.all_chosen {
114            for row in self.matches.chunks(tuple_len) {
115                table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
116            }
117            vec![]
118        } else {
119            for idx in self.chosen.iter() {
120                let row = &self.matches[idx * tuple_len..(idx + 1) * tuple_len];
121                table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
122            }
123
124            // swap remove the chosen matches
125            self.chosen.sort_unstable();
126            self.chosen.dedup();
127            let mut p = self.match_size();
128            for c in self.chosen.into_iter().rev() {
129                // It's important to decrement `p` first, because otherwise it might underflow when
130                // matches are exhausted.
131                p -= 1;
132                if c != p {
133                    let idx_c = c * tuple_len;
134                    let idx_p = p * tuple_len;
135                    for i in 0..tuple_len {
136                        self.matches.swap(idx_c + i, idx_p + i);
137                    }
138                }
139            }
140            self.matches.truncate(p * tuple_len);
141
142            self.matches
143        }
144    }
145}
146
// Declares `SchedulerId`, the handle returned by `EGraph::add_scheduler` and
// accepted by `EGraph::remove_scheduler` / `EGraph::step_rules_with_scheduler`.
define_id!(
    pub SchedulerId, u32,
    "A unique identifier for a scheduler in the EGraph."
);
151
152impl EGraph {
153    /// Register a new scheduler and return its id.
154    pub fn add_scheduler(&mut self, scheduler: Box<dyn Scheduler>) -> SchedulerId {
155        self.schedulers.push(SchedulerRecord {
156            scheduler,
157            rule_info: Default::default(),
158        })
159    }
160
161    /// Removes a scheduler
162    pub fn remove_scheduler(&mut self, scheduler_id: SchedulerId) -> Option<Box<dyn Scheduler>> {
163        self.schedulers.take(scheduler_id).map(|r| r.scheduler)
164    }
165
166    /// Runs a ruleset for one iteration using the given ruleset
167    pub fn step_rules_with_scheduler(
168        &mut self,
169        scheduler_id: SchedulerId,
170        ruleset: &str,
171    ) -> Result<RunReport, Error> {
172        fn collect_rules<'a>(
173            ruleset: &str,
174            rulesets: &'a IndexMap<String, Ruleset>,
175            ids: &mut Vec<(String, &'a ResolvedCoreRule)>,
176        ) {
177            match &rulesets[ruleset] {
178                Ruleset::Rules(rules) => {
179                    for (rule_name, (core_rule, _)) in rules.iter() {
180                        ids.push((rule_name.clone(), core_rule));
181                    }
182                }
183                Ruleset::Combined(sub_rulesets) => {
184                    for sub_ruleset in sub_rulesets {
185                        collect_rules(sub_ruleset, rulesets, ids);
186                    }
187                }
188            }
189        }
190
191        let mut rules = Vec::new();
192        let rulesets = std::mem::take(&mut self.rulesets);
193        collect_rules(ruleset, &rulesets, &mut rules);
194        let mut schedulers = std::mem::take(&mut self.schedulers);
195
196        // Step 1: build all the query/action rules and worklist if have not already
197        let record = &mut schedulers[scheduler_id];
198        rules.iter().for_each(|(id, rule)| {
199            record
200                .rule_info
201                .entry((*id).to_owned())
202                .or_insert_with(|| SchedulerRuleInfo::new(self, rule, id));
203        });
204
205        // Step 2: run all the queries for one iteration
206        let query_rules = rules
207            .iter()
208            .filter_map(|(rule_id, _rule)| {
209                let rule_info = record.rule_info.get(rule_id).unwrap();
210
211                if rule_info.should_seek {
212                    Some(rule_info.query_rule)
213                } else {
214                    None
215                }
216            })
217            .collect::<Vec<_>>();
218
219        let query_iter_report = self
220            .backend
221            .run_rules(&query_rules)
222            .map_err(|e| Error::BackendError(e.to_string()))?;
223
224        // Step 3: let the scheduler decide which matches need to be kept
225        self.backend.with_execution_state(|state| {
226            for (rule_id, _rule) in rules.iter() {
227                let rule_info = record.rule_info.get_mut(rule_id).unwrap();
228
229                let matches: Vec<Value> =
230                    std::mem::take(rule_info.matches.lock().unwrap().as_mut());
231                let mut matches = Matches::new(matches, rule_info.free_vars.clone());
232                rule_info.should_seek =
233                    record
234                        .scheduler
235                        .filter_matches(rule_id, ruleset, &mut matches);
236                let table_action = TableAction::new(&self.backend, rule_info.decided);
237                *rule_info.matches.lock().unwrap() = matches.instantiate(state, table_action);
238            }
239        });
240        self.backend.flush_updates();
241
242        // Step 4: run the action rules
243        let action_rules = rules
244            .iter()
245            .map(|(rule_id, _rule)| {
246                let rule_info = record.rule_info.get(rule_id).unwrap();
247                rule_info.action_rule
248            })
249            .collect::<Vec<_>>();
250        let action_iter_report = self
251            .backend
252            .run_rules(&action_rules)
253            .map_err(|e| Error::BackendError(e.to_string()))?;
254
255        // Step 5: combine the reports
256        let mut query_report = RunReport::singleton(ruleset, query_iter_report);
257        let mut action_report = RunReport::singleton(ruleset, action_iter_report);
258
259        // query matches don't count
260        query_report.updated = false;
261        query_report.num_matches_per_rule.clear();
262        // if the scheduler says it shouldn't stop, then it's considered updated (unsaturated)
263        action_report.updated = action_report.updated || {
264            let rule_ids = rules.iter().map(|(id, _)| id.as_str()).collect::<Vec<_>>();
265            !record.scheduler.can_stop(&rule_ids, ruleset)
266        };
267
268        query_report.union(action_report);
269
270        self.rulesets = rulesets;
271        self.schedulers = schedulers;
272
273        Ok(query_report)
274    }
275}
276
/// A registered scheduler together with the per-rule state built for it.
#[derive(Clone)]
pub(crate) struct SchedulerRecord {
    /// The scheduling policy supplied through `EGraph::add_scheduler`.
    scheduler: Box<dyn Scheduler>,
    /// State of every rule that has been stepped with this scheduler, keyed
    /// by rule name.
    rule_info: HashMap<String, SchedulerRuleInfo>,
}
282
/// To enable scheduling without modifying the backend,
/// we split a rule `(rule query action)` into a worklist table (`decided`)
/// and two backend rules:
///
/// * a query rule that runs `query` and hands the free variables of every
///   match to the `collect_matches` external function, which buffers them
///   in `matches`;
/// * an action rule that reads the rows placed into `decided` (the matches
///   chosen by the scheduler), runs `action` for them and removes the row
///   from `decided` afterwards.
#[derive(Clone)]
struct SchedulerRuleInfo {
    /// Flattened matches that were collected but not fired yet. The same
    /// buffer is handed to the `CollectMatches` external function.
    matches: Arc<Mutex<Vec<Value>>>,
    /// Whether the query rule is run in the next step: starts as `true` and
    /// is then set to the last result of `Scheduler::filter_matches`.
    should_seek: bool,
    /// The worklist table receiving the matches chosen by the scheduler.
    decided: FunctionId,
    /// Backend rule that searches for matches.
    query_rule: RuleId,
    /// Backend rule that applies the actions to the rows of `decided`.
    action_rule: RuleId,
    /// Free variables of the rule's head; they define the columns of a match.
    free_vars: Vec<ResolvedVar>,
}
296
297struct CollectMatches {
298    matches: Arc<Mutex<Vec<Value>>>,
299}
300
301impl Clone for CollectMatches {
302    fn clone(&self) -> Self {
303        Self {
304            matches: Arc::new(Mutex::new(self.matches.lock().unwrap().clone())),
305        }
306    }
307}
308
309impl CollectMatches {
310    fn new(matches: Arc<Mutex<Vec<Value>>>) -> Self {
311        Self { matches }
312    }
313}
314
315impl ExternalFunction for CollectMatches {
316    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
317        self.matches.lock().unwrap().extend(args.iter().copied());
318        Some(state.base_values().get(()))
319    }
320}
321
322impl SchedulerRuleInfo {
323    fn new(egraph: &mut EGraph, rule: &ResolvedCoreRule, name: &str) -> SchedulerRuleInfo {
324        let free_vars = rule.head.get_free_vars().into_iter().collect::<Vec<_>>();
325        let unit_type = egraph.backend.base_values().get_ty::<()>();
326        let unit = egraph.backend.base_values().get(());
327        let unit_entry = egraph.backend.base_value_constant(());
328
329        let matches = Arc::new(Mutex::new(Vec::new()));
330        let collect_matches = egraph
331            .backend
332            .register_external_func(CollectMatches::new(matches.clone()));
333        let schema = free_vars
334            .iter()
335            .map(|v| v.sort.column_ty(&egraph.backend))
336            .chain(std::iter::once(ColumnTy::Base(unit_type)))
337            .collect();
338        let decided = egraph.backend.add_table(FunctionConfig {
339            schema,
340            default: DefaultVal::Const(unit),
341            merge: MergeFn::AssertEq,
342            name: "backend".to_string(),
343            can_subsume: false,
344        });
345
346        // Step 1: build the query rule
347        let mut qrule_builder = BackendRule::new(
348            egraph.backend.new_rule(name, true),
349            &egraph.functions,
350            &egraph.type_info,
351        );
352        qrule_builder.query(&rule.body, true);
353        let entries = free_vars
354            .iter()
355            .map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
356            .collect::<Vec<_>>();
357        let _var = qrule_builder.rb.call_external_func(
358            collect_matches,
359            &entries,
360            ColumnTy::Base(unit_type),
361            || "collect_matches".to_string(),
362        );
363        let qrule_id = qrule_builder.build();
364
365        // Step 2: build the action rule
366        let mut arule_builder = BackendRule::new(
367            egraph.backend.new_rule(name, false),
368            &egraph.functions,
369            &egraph.type_info,
370        );
371        let mut entries = free_vars
372            .iter()
373            .map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
374            .collect::<Vec<_>>();
375        entries.push(unit_entry);
376        arule_builder
377            .rb
378            .query_table(decided, &entries, None)
379            .unwrap();
380        arule_builder.actions(&rule.head).unwrap();
381        // Remove the entry as it's now done
382        entries.pop();
383        arule_builder.rb.remove(decided, &entries);
384        let arule_id = arule_builder.build();
385
386        SchedulerRuleInfo {
387            free_vars,
388            query_rule: qrule_id,
389            action_rule: arule_id,
390            matches,
391            decided,
392            should_seek: true,
393        }
394    }
395}
396
#[cfg(test)]
mod test {
    use super::*;

    /// Test scheduler firing at most `n` matches of a rule per iteration.
    #[derive(Clone)]
    struct FirstNScheduler {
        n: usize,
    }

    impl Scheduler for FirstNScheduler {
        fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool {
            let available = matches.match_size();
            if available > self.n {
                (0..self.n).for_each(|i| matches.choose(i));
            } else {
                matches.choose_all();
            }
            // Only ask for fresh matches once the backlog is nearly used up.
            available < self.n * 2
        }
    }

    #[test]
    fn test_first_n_scheduler() {
        let mut egraph = EGraph::default();
        let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 10 }));
        let input = r#"
        (relation R (i64))
        (R 0)
        (rule ((R x) (< x 100)) ((R (+ x 1))))
        (run-schedule (saturate (run)))

        (ruleset test)
        (relation S (i64))
        (rule ((R x)) ((S x)) :ruleset test :name "test-rule")
        "#;
        egraph.parse_and_run_program(None, input).unwrap();
        assert_eq!(egraph.get_size("R"), 101);

        // The scheduler releases 10 of the 101 matches per step: ten steps
        // fire 10 matches each, the 11th fires the last one and the 12th
        // finds nothing left and reports saturation.
        let mut iter = 0;
        loop {
            let report = egraph
                .step_rules_with_scheduler(scheduler_id, "test")
                .unwrap();
            iter += 1;

            let table_size = egraph.get_size("S");
            assert_eq!(table_size, std::cmp::min(iter * 10, 101));

            let expected_matches = if iter <= 10 { 10 } else { 12 - iter };
            assert_eq!(
                report.num_matches_per_rule.iter().collect::<Vec<_>>(),
                [(&"test-rule".into(), &expected_matches)]
            );

            // Because of semi-naive, the exact rules that are run are more than just `test-rule`
            let all_timed_rules_belong_to_test_rule = report
                .search_and_apply_time_per_rule
                .keys()
                .all(|k| k.starts_with("test-rule"));
            assert!(all_timed_rules_belong_to_test_rule);

            let merge_rulesets = report.merge_time_per_ruleset.keys().collect::<Vec<_>>();
            assert_eq!(merge_rulesets, [&"test".into()]);

            let search_rulesets = report
                .search_and_apply_time_per_ruleset
                .keys()
                .collect::<Vec<_>>();
            assert_eq!(search_rulesets, [&"test".into()]);

            if !report.updated {
                break;
            }
        }

        assert_eq!(iter, 12);
    }
}