1use crate::termdag::{Term, TermDag};
2use crate::util::{HashMap, HashSet};
3use crate::*;
4use std::collections::VecDeque;
5
6pub trait CostModel<C: Cost> {
16 fn fold(&self, head: &str, children_cost: &[C], head_cost: C) -> C;
18
19 fn enode_cost(&self, egraph: &EGraph, func: &Function, row: &egglog_bridge::FunctionRow) -> C;
21
22 fn container_cost(
26 &self,
27 egraph: &EGraph,
28 sort: &ArcSort,
29 value: Value,
30 element_costs: &[C],
31 ) -> C {
32 let _egraph = egraph;
33 let _sort = sort;
34 let _value = value;
35 element_costs
36 .iter()
37 .fold(C::identity(), |s, c| s.combine(c))
38 }
39
40 fn base_value_cost(&self, egraph: &EGraph, sort: &ArcSort, value: Value) -> C {
44 let _egraph = egraph;
45 let _sort = sort;
46 let _value = value;
47 C::unit()
48 }
49}
50
/// Requirements on a type that can be used as an extraction cost.
pub trait Cost {
    /// The neutral starting value used when combining a sequence of costs.
    fn identity() -> Self;

    /// The cost used by default for a single node or base value.
    fn unit() -> Self;

    /// Combines `self` with `other` and returns the merged cost.
    fn combine(self, other: &Self) -> Self;
}
63
64macro_rules! cost_impl_int {
65 ($($cost:ty),*) => {$(
66 impl Cost for $cost {
67 fn identity() -> Self { 0 }
68 fn unit() -> Self { 1 }
69 fn combine(self, other: &Self) -> Self {
70 self.saturating_add(*other)
71 }
72 }
73 )*};
74}
75cost_impl_int!(u8, u16, u32, u64, u128, usize);
76cost_impl_int!(i8, i16, i32, i64, i128, isize);
77
78macro_rules! cost_impl_num {
79 ($($cost:ty),*) => {$(
80 impl Cost for $cost {
81 fn identity() -> Self {
82 use num::Zero;
83 Self::zero()
84 }
85 fn unit() -> Self {
86 use num::One;
87 Self::one()
88 }
89 fn combine(self, other: &Self) -> Self {
90 self + other
91 }
92 }
93 )*};
94}
95cost_impl_num!(num::BigInt, num::BigRational);
96use ordered_float::OrderedFloat;
97cost_impl_num!(f32, f64, OrderedFloat<f32>, OrderedFloat<f64>);
98
99pub type DefaultCost = u64;
100
101#[derive(Default, Clone)]
103pub struct TreeAdditiveCostModel {}
104
105impl CostModel<DefaultCost> for TreeAdditiveCostModel {
106 fn fold(
107 &self,
108 _head: &str,
109 children_cost: &[DefaultCost],
110 head_cost: DefaultCost,
111 ) -> DefaultCost {
112 children_cost.iter().fold(head_cost, |s, c| s.combine(c))
113 }
114
115 fn enode_cost(
116 &self,
117 _egraph: &EGraph,
118 func: &Function,
119 _row: &egglog_bridge::FunctionRow,
120 ) -> DefaultCost {
121 func.decl.cost.unwrap_or(DefaultCost::unit())
122 }
123}
124
125pub struct Extractor<C: Cost + Ord + Eq + Clone + Debug> {
127 rootsorts: Vec<ArcSort>,
128 funcs: Vec<String>,
129 cost_model: Box<dyn CostModel<C>>,
130 costs: HashMap<String, HashMap<Value, C>>,
131 topo_rnk_cnt: usize,
132 topo_rnk: HashMap<String, HashMap<Value, usize>>,
133 parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>>,
134}
135
136impl<C: Cost + Ord + Eq + Clone + Debug> Extractor<C> {
137 pub fn compute_costs_from_rootsorts(
144 rootsorts: Option<Vec<ArcSort>>,
145 egraph: &EGraph,
146 cost_model: impl CostModel<C> + 'static,
147 ) -> Self {
148 let extract_all_sorts = rootsorts.is_none();
150
151 let mut rootsorts = rootsorts.unwrap_or_default();
152
153 let mut rev_index: HashMap<String, Vec<String>> = Default::default();
155 for func in egraph.functions.iter() {
156 if !func.1.decl.unextractable {
157 let func_name = func.0.clone();
158 let output_sort_name = func.1.schema.output.name();
159 if let Some(v) = rev_index.get_mut(output_sort_name) {
160 v.push(func_name);
161 } else {
162 rev_index.insert(output_sort_name.to_owned(), vec![func_name]);
163 if extract_all_sorts {
164 rootsorts.push(func.1.schema.output.clone());
165 }
166 }
167 }
168 }
169
170 let mut q: VecDeque<ArcSort> = VecDeque::new();
172 let mut seen: HashSet<String> = Default::default();
173 for rootsort in rootsorts.iter() {
174 q.push_back(rootsort.clone());
175 seen.insert(rootsort.name().to_owned());
176 }
177
178 let mut funcs_set: HashSet<String> = Default::default();
179 let mut funcs: Vec<String> = Vec::new();
180 while !q.is_empty() {
181 let sort = q.pop_front().unwrap();
182 if sort.is_container_sort() {
183 let inner_sorts = sort.inner_sorts();
184 for s in inner_sorts {
185 if !seen.contains(s.name()) {
186 q.push_back(s.clone());
187 seen.insert(s.name().to_owned());
188 }
189 }
190 } else if sort.is_eq_sort() {
191 if let Some(head_symbols) = rev_index.get(sort.name()) {
192 for h in head_symbols {
193 if !funcs_set.contains(h) {
194 let func = egraph.functions.get(h).unwrap();
195 for ch in &func.schema.input {
196 let ch_name = ch.name();
197 if !seen.contains(ch_name) {
198 q.push_back(ch.clone());
199 seen.insert(ch_name.to_owned());
200 }
201 }
202 funcs_set.insert(h.clone());
203 funcs.push(h.clone());
204 }
205 }
206 }
207 }
208 }
209
210 let mut costs: HashMap<String, HashMap<Value, C>> = Default::default();
212 let mut topo_rnk: HashMap<String, HashMap<Value, usize>> = Default::default();
213 let mut parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>> =
214 Default::default();
215
216 for func_name in funcs.iter() {
217 let func = egraph.functions.get(func_name).unwrap();
218 if !costs.contains_key(func.schema.output.name()) {
219 debug_assert!(func.schema.output.is_eq_sort());
220 costs.insert(func.schema.output.name().to_owned(), Default::default());
221 topo_rnk.insert(func.schema.output.name().to_owned(), Default::default());
222 parent_edge.insert(func.schema.output.name().to_owned(), Default::default());
223 }
224 }
225
226 let mut extractor = Extractor {
227 rootsorts,
228 funcs,
229 cost_model: Box::new(cost_model),
230 costs,
231 topo_rnk_cnt: 0,
232 topo_rnk,
233 parent_edge,
234 };
235
236 extractor.bellman_ford(egraph);
237
238 extractor
239 }
240
241 fn compute_cost_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Option<C> {
245 if sort.is_container_sort() {
246 let elements = sort.inner_values(egraph.backend.container_values(), value);
247 let mut ch_costs: Vec<C> = Vec::new();
248 for ch in elements.iter() {
249 if let Some(c) = self.compute_cost_node(egraph, ch.1, &ch.0) {
250 ch_costs.push(c);
251 } else {
252 return None;
253 }
254 }
255 Some(
256 self.cost_model
257 .container_cost(egraph, sort, value, &ch_costs),
258 )
259 } else if sort.is_eq_sort() {
260 if self
261 .costs
262 .get(sort.name())
263 .is_some_and(|t| t.get(&value).is_some())
264 {
265 Some(
266 self.costs
267 .get(sort.name())
268 .unwrap()
269 .get(&value)
270 .unwrap()
271 .clone(),
272 )
273 } else {
274 None
275 }
276 } else {
277 Some(self.cost_model.base_value_cost(egraph, sort, value))
279 }
280 }
281
282 fn compute_cost_hyperedge(
284 &self,
285 egraph: &EGraph,
286 row: &egglog_bridge::FunctionRow,
287 func: &Function,
288 ) -> Option<C> {
289 let mut ch_costs: Vec<C> = Vec::new();
290 let sorts = &func.schema.input;
291 for (value, sort) in row.vals.iter().zip(sorts.iter()) {
294 if let Some(c) = self.compute_cost_node(egraph, *value, sort) {
295 ch_costs.push(c);
296 } else {
297 return None;
298 }
299 }
300 Some(self.cost_model.fold(
301 &func.decl.name,
302 &ch_costs,
303 self.cost_model.enode_cost(egraph, func, row),
304 ))
305 }
306
307 fn compute_topo_rnk_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> usize {
308 if sort.is_container_sort() {
309 sort.inner_values(egraph.backend.container_values(), value)
310 .iter()
311 .fold(0, |ret, (sort, value)| {
312 usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
313 })
314 } else if sort.is_eq_sort() {
315 if let Some(t) = self.topo_rnk.get(sort.name()) {
316 *t.get(&value).unwrap_or(&usize::MAX)
317 } else {
318 usize::MAX
319 }
320 } else {
321 0
322 }
323 }
324
325 fn compute_topo_rnk_hyperedge(
326 &self,
327 egraph: &EGraph,
328 row: &egglog_bridge::FunctionRow,
329 func: &Function,
330 ) -> usize {
331 let sorts = &func.schema.input;
332 row.vals
333 .iter()
334 .zip(sorts.iter())
335 .fold(0, |ret, (value, sort)| {
336 usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
337 })
338 }
339
340 fn bellman_ford(&mut self, egraph: &EGraph) {
351 let mut ensure_fixpoint = false;
352
353 let funcs = self.funcs.clone();
354
355 while !ensure_fixpoint {
356 ensure_fixpoint = true;
357
358 for func_name in funcs.iter() {
359 let func = egraph.functions.get(func_name).unwrap();
360 let target_sort = func.schema.output.clone();
361
362 let relax_hyperedge = |row: egglog_bridge::FunctionRow| {
363 log::debug!("Relaxing a new hyperedge: {:?}", row);
364 if !row.subsumed {
365 let target = row.vals.last().unwrap();
366 let mut updated = false;
367 if let Some(new_cost) = self.compute_cost_hyperedge(egraph, &row, func) {
368 match self
369 .costs
370 .get_mut(target_sort.name())
371 .unwrap()
372 .entry(*target)
373 {
374 HEntry::Vacant(e) => {
375 updated = true;
376 e.insert(new_cost);
377 }
378 HEntry::Occupied(mut e) => {
379 if new_cost < *(e.get()) {
380 updated = true;
381 e.insert(new_cost);
382 }
383 }
384 }
385 }
386 if updated {
390 ensure_fixpoint = false;
391 self.topo_rnk_cnt += 1;
392 self.topo_rnk
393 .get_mut(target_sort.name())
394 .unwrap()
395 .insert(*target, self.topo_rnk_cnt);
396 }
397 }
398 };
399
400 egraph.backend.for_each(func.backend_id, relax_hyperedge);
401 }
402 }
403
404 for func_name in funcs.iter() {
406 let func = egraph.functions.get(func_name).unwrap();
407 let target_sort = func.schema.output.clone();
408
409 let save_best_parent_edge = |row: egglog_bridge::FunctionRow| {
410 if !row.subsumed {
411 let target = row.vals.last().unwrap();
412 if let Some(best_cost) = self.costs.get(target_sort.name()).unwrap().get(target)
413 {
414 if Some(best_cost.clone())
415 == self.compute_cost_hyperedge(egraph, &row, func)
416 {
417 let target_topo_rnk = *self
419 .topo_rnk
420 .get(target_sort.name())
421 .unwrap()
422 .get(target)
423 .unwrap();
424 if target_topo_rnk > self.compute_topo_rnk_hyperedge(egraph, &row, func)
425 {
426 if let HEntry::Vacant(e) = self
428 .parent_edge
429 .get_mut(target_sort.name())
430 .unwrap()
431 .entry(*target)
432 {
433 e.insert((func.decl.name.clone(), row.vals.to_vec()));
434 }
435 }
436 }
437 }
438 }
439 };
440
441 egraph
442 .backend
443 .for_each(func.backend_id, save_best_parent_edge);
444 }
445 }
446
447 fn reconstruct_termdag_node(
449 &self,
450 egraph: &EGraph,
451 termdag: &mut TermDag,
452 value: Value,
453 sort: &ArcSort,
454 ) -> Term {
455 self.reconstruct_termdag_node_helper(egraph, termdag, value, sort, &mut Default::default())
456 }
457
458 fn reconstruct_termdag_node_helper(
459 &self,
460 egraph: &EGraph,
461 termdag: &mut TermDag,
462 value: Value,
463 sort: &ArcSort,
464 cache: &mut HashMap<(Value, String), Term>,
465 ) -> Term {
466 let key = (value, sort.name().to_owned());
467 if let Some(term) = cache.get(&key) {
468 return term.clone();
469 }
470
471 let term = if sort.is_container_sort() {
472 let elements = sort.inner_values(egraph.backend.container_values(), value);
473 let mut ch_terms: Vec<Term> = Vec::new();
474 for ch in elements.iter() {
475 ch_terms.push(
476 self.reconstruct_termdag_node_helper(egraph, termdag, ch.1, &ch.0, cache),
477 );
478 }
479 sort.reconstruct_termdag_container(
480 egraph.backend.container_values(),
481 value,
482 termdag,
483 ch_terms,
484 )
485 } else if sort.is_eq_sort() {
486 let (func_name, hyperedge) = self
487 .parent_edge
488 .get(sort.name())
489 .unwrap()
490 .get(&value)
491 .unwrap();
492 let mut ch_terms: Vec<Term> = Vec::new();
493 let ch_sorts = &egraph.functions.get(func_name).unwrap().schema.input;
494 for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()) {
495 ch_terms.push(
496 self.reconstruct_termdag_node_helper(egraph, termdag, *value, sort, cache),
497 );
498 }
499 termdag.app(func_name.clone(), ch_terms)
500 } else {
501 sort.reconstruct_termdag_base(egraph.backend.base_values(), value, termdag)
503 };
504
505 cache.insert(key, term.clone());
506 term
507 }
508
509 pub fn extract_best_with_sort(
514 &self,
515 egraph: &EGraph,
516 termdag: &mut TermDag,
517 value: Value,
518 sort: ArcSort,
519 ) -> Option<(C, Term)> {
520 match self.compute_cost_node(egraph, value, &sort) {
521 Some(best_cost) => {
522 log::debug!("Best cost for the extract root: {:?}", best_cost);
523
524 let term = self.reconstruct_termdag_node(egraph, termdag, value, &sort);
525
526 Some((best_cost, term))
527 }
528 None => {
529 log::error!("Unextractable root {:?} with sort {:?}", value, sort,);
530 None
531 }
532 }
533 }
534
535 pub fn extract_best(
539 &self,
540 egraph: &EGraph,
541 termdag: &mut TermDag,
542 value: Value,
543 ) -> Option<(C, Term)> {
544 assert!(
545 self.rootsorts.len() == 1,
546 "extract_best requires a single rootsort"
547 );
548 self.extract_best_with_sort(
549 egraph,
550 termdag,
551 value,
552 self.rootsorts.first().unwrap().clone(),
553 )
554 }
555
556 pub fn extract_variants_with_sort(
561 &self,
562 egraph: &EGraph,
563 termdag: &mut TermDag,
564 value: Value,
565 nvariants: usize,
566 sort: ArcSort,
567 ) -> Vec<(C, Term)> {
568 debug_assert!(self.rootsorts.iter().any(|s| { s.name() == sort.name() }));
569
570 if sort.is_eq_sort() {
571 let mut root_variants: Vec<(C, String, Vec<Value>)> = Vec::new();
572
573 let mut root_funcs: Vec<String> = Vec::new();
574
575 for func_name in self.funcs.iter() {
576 if sort.name()
578 == egraph
579 .functions
580 .get(func_name)
581 .unwrap()
582 .schema
583 .output
584 .name()
585 {
586 root_funcs.push(func_name.clone());
587 }
588 }
589
590 for func_name in root_funcs.iter() {
591 let func = egraph.functions.get(func_name).unwrap();
592
593 let find_root_variants = |row: egglog_bridge::FunctionRow| {
594 if !row.subsumed {
595 let target = row.vals.last().unwrap();
596 if *target == value {
597 let cost = self.compute_cost_hyperedge(egraph, &row, func).unwrap();
598 root_variants.push((cost, func_name.clone(), row.vals.to_vec()));
599 }
600 }
601 };
602
603 egraph.backend.for_each(func.backend_id, find_root_variants);
604 }
605
606 let mut res: Vec<(C, Term)> = Vec::new();
607 root_variants.sort();
608 root_variants.truncate(nvariants);
609 for (cost, func_name, hyperedge) in root_variants {
610 let mut ch_terms: Vec<Term> = Vec::new();
611 let ch_sorts = &egraph.functions.get(&func_name).unwrap().schema.input;
612 for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()) {
614 ch_terms.push(self.reconstruct_termdag_node(egraph, termdag, *value, sort));
615 }
616 res.push((cost, termdag.app(func_name, ch_terms)));
617 }
618
619 res
620 } else {
621 log::warn!(
622 "extracting multiple variants for containers or primitives is not implemented, returning a single variant."
623 );
624 if let Some(res) = self.extract_best_with_sort(egraph, termdag, value, sort) {
625 vec![res]
626 } else {
627 vec![]
628 }
629 }
630 }
631
632 pub fn extract_variants(
636 &self,
637 egraph: &EGraph,
638 termdag: &mut TermDag,
639 value: Value,
640 nvariants: usize,
641 ) -> Vec<(C, Term)> {
642 assert!(
643 self.rootsorts.len() == 1,
644 "extract_variants requires a single rootsort"
645 );
646 self.extract_variants_with_sort(
647 egraph,
648 termdag,
649 value,
650 nvariants,
651 self.rootsorts.first().unwrap().clone(),
652 )
653 }
654}