1use std::sync::Arc;
2use std::sync::Mutex;
3
4use core_relations::{ExecutionState, ExternalFunction, Value};
5use egglog_bridge::{
6 ColumnTy, DefaultVal, FunctionConfig, FunctionId, MergeFn, RuleId, TableAction,
7};
8use egglog_reports::RunReport;
9use numeric_id::define_id;
10
11use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *};
12
13pub trait Scheduler: dyn_clone::DynClone + Send + Sync {
18 fn can_stop(&mut self, rules: &[&str], ruleset: &str) -> bool {
24 let _ = (rules, ruleset);
25 true
26 }
27
28 fn filter_matches(&mut self, rule: &str, ruleset: &str, matches: &mut Matches) -> bool;
33}
34
35dyn_clone::clone_trait_object!(Scheduler);
36
37pub struct Matches {
40 matches: Vec<Value>,
41 chosen: Vec<usize>,
42 vars: Vec<ResolvedVar>,
43 all_chosen: bool,
44}
45
46pub struct Match<'a> {
49 values: &'a [Value],
50 vars: &'a [ResolvedVar],
51}
52
53impl Match<'_> {
54 pub fn get_value(&self, var: &str) -> Value {
56 let idx = self.vars.iter().position(|v| v.name == var).unwrap();
57 self.values[idx]
58 }
59}
60
61impl Matches {
62 fn new(matches: Vec<Value>, vars: Vec<ResolvedVar>) -> Self {
63 let total_len = matches.len();
64 let tuple_len = vars.len();
65 assert!(total_len % tuple_len == 0);
66 Self {
67 matches,
68 vars,
69 chosen: Vec::new(),
70 all_chosen: false,
71 }
72 }
73
74 pub fn match_size(&self) -> usize {
76 self.matches.len() / self.vars.len()
77 }
78
79 pub fn tuple_len(&self) -> usize {
81 self.vars.len()
82 }
83
84 pub fn get_match(&self, idx: usize) -> Match<'_> {
86 Match {
87 values: &self.matches[idx * self.tuple_len()..(idx + 1) * self.tuple_len()],
88 vars: &self.vars,
89 }
90 }
91
92 pub fn choose(&mut self, idx: usize) {
94 self.chosen.push(idx);
95 }
96
97 pub fn choose_all(&mut self) {
101 self.all_chosen = true;
102 }
103
104 fn instantiate(
106 mut self,
107 state: &mut ExecutionState<'_>,
108 mut table_action: TableAction,
109 ) -> Vec<Value> {
110 let tuple_len = self.tuple_len();
111 let unit = state.base_values().get(());
112
113 if self.all_chosen {
114 for row in self.matches.chunks(tuple_len) {
115 table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
116 }
117 vec![]
118 } else {
119 for idx in self.chosen.iter() {
120 let row = &self.matches[idx * tuple_len..(idx + 1) * tuple_len];
121 table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
122 }
123
124 self.chosen.sort_unstable();
126 self.chosen.dedup();
127 let mut p = self.match_size();
128 for c in self.chosen.into_iter().rev() {
129 p -= 1;
132 if c != p {
133 let idx_c = c * tuple_len;
134 let idx_p = p * tuple_len;
135 for i in 0..tuple_len {
136 self.matches.swap(idx_c + i, idx_p + i);
137 }
138 }
139 }
140 self.matches.truncate(p * tuple_len);
141
142 self.matches
143 }
144 }
145}
146
// Handle returned by `EGraph::add_scheduler`; pass it to `EGraph::step_rules_with_scheduler`
// or `EGraph::remove_scheduler`.
define_id!(
    pub SchedulerId, u32,
    "A unique identifier for a scheduler in the EGraph."
);
151
152impl EGraph {
153 pub fn add_scheduler(&mut self, scheduler: Box<dyn Scheduler>) -> SchedulerId {
155 self.schedulers.push(SchedulerRecord {
156 scheduler,
157 rule_info: Default::default(),
158 })
159 }
160
161 pub fn remove_scheduler(&mut self, scheduler_id: SchedulerId) -> Option<Box<dyn Scheduler>> {
163 self.schedulers.take(scheduler_id).map(|r| r.scheduler)
164 }
165
166 pub fn step_rules_with_scheduler(
168 &mut self,
169 scheduler_id: SchedulerId,
170 ruleset: &str,
171 ) -> Result<RunReport, Error> {
172 fn collect_rules<'a>(
173 ruleset: &str,
174 rulesets: &'a IndexMap<String, Ruleset>,
175 ids: &mut Vec<(String, &'a ResolvedCoreRule)>,
176 ) {
177 match &rulesets[ruleset] {
178 Ruleset::Rules(rules) => {
179 for (rule_name, (core_rule, _)) in rules.iter() {
180 ids.push((rule_name.clone(), core_rule));
181 }
182 }
183 Ruleset::Combined(sub_rulesets) => {
184 for sub_ruleset in sub_rulesets {
185 collect_rules(sub_ruleset, rulesets, ids);
186 }
187 }
188 }
189 }
190
191 let mut rules = Vec::new();
192 let rulesets = std::mem::take(&mut self.rulesets);
193 collect_rules(ruleset, &rulesets, &mut rules);
194 let mut schedulers = std::mem::take(&mut self.schedulers);
195
196 let record = &mut schedulers[scheduler_id];
198 rules.iter().for_each(|(id, rule)| {
199 record
200 .rule_info
201 .entry((*id).to_owned())
202 .or_insert_with(|| SchedulerRuleInfo::new(self, rule, id));
203 });
204
205 let query_rules = rules
207 .iter()
208 .filter_map(|(rule_id, _rule)| {
209 let rule_info = record.rule_info.get(rule_id).unwrap();
210
211 if rule_info.should_seek {
212 Some(rule_info.query_rule)
213 } else {
214 None
215 }
216 })
217 .collect::<Vec<_>>();
218
219 let query_iter_report = self
220 .backend
221 .run_rules(&query_rules)
222 .map_err(|e| Error::BackendError(e.to_string()))?;
223
224 self.backend.with_execution_state(|state| {
226 for (rule_id, _rule) in rules.iter() {
227 let rule_info = record.rule_info.get_mut(rule_id).unwrap();
228
229 let matches: Vec<Value> =
230 std::mem::take(rule_info.matches.lock().unwrap().as_mut());
231 let mut matches = Matches::new(matches, rule_info.free_vars.clone());
232 rule_info.should_seek =
233 record
234 .scheduler
235 .filter_matches(rule_id, ruleset, &mut matches);
236 let table_action = TableAction::new(&self.backend, rule_info.decided);
237 *rule_info.matches.lock().unwrap() = matches.instantiate(state, table_action);
238 }
239 });
240 self.backend.flush_updates();
241
242 let action_rules = rules
244 .iter()
245 .map(|(rule_id, _rule)| {
246 let rule_info = record.rule_info.get(rule_id).unwrap();
247 rule_info.action_rule
248 })
249 .collect::<Vec<_>>();
250 let action_iter_report = self
251 .backend
252 .run_rules(&action_rules)
253 .map_err(|e| Error::BackendError(e.to_string()))?;
254
255 let mut query_report = RunReport::singleton(ruleset, query_iter_report);
257 let mut action_report = RunReport::singleton(ruleset, action_iter_report);
258
259 query_report.updated = false;
261 query_report.num_matches_per_rule.clear();
262 action_report.updated = action_report.updated || {
264 let rule_ids = rules.iter().map(|(id, _)| id.as_str()).collect::<Vec<_>>();
265 !record.scheduler.can_stop(&rule_ids, ruleset)
266 };
267
268 query_report.union(action_report);
269
270 self.rulesets = rulesets;
271 self.schedulers = schedulers;
272
273 Ok(query_report)
274 }
275}
276
277#[derive(Clone)]
278pub(crate) struct SchedulerRecord {
279 scheduler: Box<dyn Scheduler>,
280 rule_info: HashMap<String, SchedulerRuleInfo>,
281}
282
283#[derive(Clone)]
288struct SchedulerRuleInfo {
289 matches: Arc<Mutex<Vec<Value>>>,
290 should_seek: bool,
291 decided: FunctionId,
292 query_rule: RuleId,
293 action_rule: RuleId,
294 free_vars: Vec<ResolvedVar>,
295}
296
297struct CollectMatches {
298 matches: Arc<Mutex<Vec<Value>>>,
299}
300
301impl Clone for CollectMatches {
302 fn clone(&self) -> Self {
303 Self {
304 matches: Arc::new(Mutex::new(self.matches.lock().unwrap().clone())),
305 }
306 }
307}
308
309impl CollectMatches {
310 fn new(matches: Arc<Mutex<Vec<Value>>>) -> Self {
311 Self { matches }
312 }
313}
314
315impl ExternalFunction for CollectMatches {
316 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
317 self.matches.lock().unwrap().extend(args.iter().copied());
318 Some(state.base_values().get(()))
319 }
320}
321
impl SchedulerRuleInfo {
    /// Builds the backend machinery that lets a scheduler decide when `rule` fires.
    ///
    /// It creates three things:
    /// - a `decided` table with one column per head variable of `rule`, followed by a unit
    ///   column;
    /// - a query rule (named `name`) that runs `rule`'s body and passes the head variables'
    ///   values to a `CollectMatches` external function, which buffers them in `matches`;
    /// - an action rule (named `name`) that, for each row of `decided`, runs `rule`'s head
    ///   actions and then removes that row. It queries only `decided`, not `rule`'s body.
    fn new(egraph: &mut EGraph, rule: &ResolvedCoreRule, name: &str) -> SchedulerRuleInfo {
        // Head variables; their order fixes the column order of `decided` and of each
        // buffered match tuple.
        let free_vars = rule.head.get_free_vars().into_iter().collect::<Vec<_>>();
        let unit_type = egraph.backend.base_values().get_ty::<()>();
        let unit = egraph.backend.base_values().get(());
        let unit_entry = egraph.backend.base_value_constant(());

        // Buffer shared by the external function (which appends) and the returned info
        // (which `step_rules_with_scheduler` drains).
        let matches = Arc::new(Mutex::new(Vec::new()));
        let collect_matches = egraph
            .backend
            .register_external_func(CollectMatches::new(matches.clone()));
        // One column per head variable, then a unit column.
        let schema = free_vars
            .iter()
            .map(|v| v.sort.column_ty(&egraph.backend))
            .chain(std::iter::once(ColumnTy::Base(unit_type)))
            .collect();
        // NOTE(review): every scheduled rule's table gets the name "backend"; confirm
        // whether a rule-specific name would be preferable (e.g. for debugging).
        let decided = egraph.backend.add_table(FunctionConfig {
            schema,
            default: DefaultVal::Const(unit),
            merge: MergeFn::AssertEq,
            name: "backend".to_string(),
            can_subsume: false,
        });

        // Query rule: match `rule.body` and forward the head variables to `collect_matches`.
        // The second argument of `new_rule` is `true` here and `false` for the action rule
        // below; presumably it selects seminaive evaluation. Confirm against
        // `egglog_bridge::EGraph::new_rule`.
        let mut qrule_builder = BackendRule::new(
            egraph.backend.new_rule(name, true),
            &egraph.functions,
            &egraph.type_info,
        );
        qrule_builder.query(&rule.body, true);
        let entries = free_vars
            .iter()
            .map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
            .collect::<Vec<_>>();
        // The call is made only for its side effect of buffering the match; its unit
        // result is unused.
        let _var = qrule_builder.rb.call_external_func(
            collect_matches,
            &entries,
            ColumnTy::Base(unit_type),
            || "collect_matches".to_string(),
        );
        let qrule_id = qrule_builder.build();

        // Action rule: fire once per row of `decided`, run the head, then consume the row.
        let mut arule_builder = BackendRule::new(
            egraph.backend.new_rule(name, false),
            &egraph.functions,
            &egraph.type_info,
        );
        let mut entries = free_vars
            .iter()
            .map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
            .collect::<Vec<_>>();
        entries.push(unit_entry);
        arule_builder
            .rb
            .query_table(decided, &entries, None)
            .unwrap();
        arule_builder.actions(&rule.head).unwrap();
        // Drop the unit entry: the removal is keyed by the head-variable columns only.
        entries.pop();
        arule_builder.rb.remove(decided, &entries);
        let arule_id = arule_builder.build();

        SchedulerRuleInfo {
            free_vars,
            query_rule: qrule_id,
            action_rule: arule_id,
            matches,
            decided,
            // Search on the first step so the scheduler has matches to choose from.
            should_seek: true,
        }
    }
}
396
#[cfg(test)]
mod test {
    use super::*;

    /// Applies at most `n` matches per step. It asks for a new search only when fewer than
    /// `2 * n` matches were pending.
    #[derive(Clone)]
    struct FirstNScheduler {
        n: usize,
    }

    impl Scheduler for FirstNScheduler {
        fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool {
            let available = matches.match_size();
            if available > self.n {
                (0..self.n).for_each(|i| matches.choose(i));
            } else {
                matches.choose_all();
            }
            available < 2 * self.n
        }
    }

    #[test]
    fn test_first_n_scheduler() {
        let program = r#"
        (relation R (i64))
        (R 0)
        (rule ((R x) (< x 100)) ((R (+ x 1))))
        (run-schedule (saturate (run)))

        (ruleset test)
        (relation S (i64))
        (rule ((R x)) ((S x)) :ruleset test :name "test-rule")
        "#;

        let mut egraph = EGraph::default();
        let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 10 }));
        egraph.parse_and_run_program(None, program).unwrap();
        assert_eq!(egraph.get_size("R"), 101);

        let mut iter = 0;
        let mut keep_going = true;
        while keep_going {
            let report = egraph
                .step_rules_with_scheduler(scheduler_id, "test")
                .unwrap();
            iter += 1;

            // Ten new `S` facts per step until all 101 `R` facts have been copied.
            assert_eq!(egraph.get_size("S"), std::cmp::min(iter * 10, 101));

            // Steps 1..=10 apply ten matches each, step 11 the last one, and step 12 none.
            let expected_matches = match iter {
                0..=10 => 10,
                _ => 12 - iter,
            };
            assert_eq!(
                report.num_matches_per_rule.iter().collect::<Vec<_>>(),
                [(&"test-rule".into(), &expected_matches)]
            );

            // Timing entries should mention only the scheduled rule and its ruleset.
            let only_test_rule = report
                .search_and_apply_time_per_rule
                .keys()
                .all(|k| k.starts_with("test-rule"));
            assert!(only_test_rule);
            assert_eq!(
                report.merge_time_per_ruleset.keys().collect::<Vec<_>>(),
                [&"test".into()]
            );
            assert_eq!(
                report
                    .search_and_apply_time_per_ruleset
                    .keys()
                    .collect::<Vec<_>>(),
                [&"test".into()]
            );

            keep_going = report.updated;
        }

        assert_eq!(iter, 12);
    }
}