egglog/
extract.rs

1use crate::termdag::{Term, TermDag};
2use crate::util::{HashMap, HashSet};
3use crate::*;
4use std::collections::VecDeque;
5
6/// An interface for custom cost model.
7///
8/// To use it with the default extractor, the cost type must also satisfy `Ord + Eq + Clone + Debug`.
9/// Additionally, the cost model should guarantee that a term has a no-smaller cost
10/// than its subterms to avoid cycles in the extracted terms for common case usages.
11/// For more niche usages, a term can have a cost less than its subterms.
12/// As long as there is no negative cost cycle,
13/// the default extractor is guaranteed to terminate in computing the costs.
14/// However, the user needs to be careful to guarantee acyclicity in the extracted terms.
15pub trait CostModel<C: Cost> {
16    /// The total cost of a term given the cost of the root e-node and its immediate children's total costs.
17    fn fold(&self, head: &str, children_cost: &[C], head_cost: C) -> C;
18
19    /// The cost of an enode (without the cost of children)
20    fn enode_cost(&self, egraph: &EGraph, func: &Function, row: &egglog_bridge::FunctionRow) -> C;
21
22    /// The cost of a container value given the costs of its elements.
23    ///
24    /// The default cost for containers is just the sum of all the elements inside
25    fn container_cost(
26        &self,
27        egraph: &EGraph,
28        sort: &ArcSort,
29        value: Value,
30        element_costs: &[C],
31    ) -> C {
32        let _egraph = egraph;
33        let _sort = sort;
34        let _value = value;
35        element_costs
36            .iter()
37            .fold(C::identity(), |s, c| s.combine(c))
38    }
39
40    /// Compute the cost of a (non-container) primitive value.
41    ///
42    /// The default cost for base values is the constant one
43    fn base_value_cost(&self, egraph: &EGraph, sort: &ArcSort, value: Value) -> C {
44        let _egraph = egraph;
45        let _sort = sort;
46        let _value = value;
47        C::unit()
48    }
49}
50
/// The operations a type must provide to be usable as a cost by a [`CostModel`].
pub trait Cost {
    /// The neutral element for [`Cost::combine`], typically zero.
    fn identity() -> Self;

    /// The default cost of a node that has no children, typically one.
    fn unit() -> Self;

    /// Combines two costs, typically by addition.
    ///
    /// Implementations must NOT overflow or panic, even when given very large values.
    fn combine(self, other: &Self) -> Self;
}

/// Implements [`Cost`] for the listed primitive integer types.
///
/// The identity is `0`, the unit is `1`, and costs are combined with saturating
/// addition so that summing large costs clamps at the type's bounds instead of
/// overflowing.
macro_rules! cost_impl_int {
    ($($int:ty),*) => {
        $(
            impl Cost for $int {
                fn identity() -> Self {
                    0
                }

                fn unit() -> Self {
                    1
                }

                fn combine(self, other: &Self) -> Self {
                    Self::saturating_add(self, *other)
                }
            }
        )*
    };
}
cost_impl_int!(u8, u16, u32, u64, u128, usize);
cost_impl_int!(i8, i16, i32, i64, i128, isize);
77
78macro_rules! cost_impl_num {
79    ($($cost:ty),*) => {$(
80        impl Cost for $cost {
81            fn identity() -> Self {
82                use num::Zero;
83                Self::zero()
84            }
85            fn unit() -> Self {
86                use num::One;
87                Self::one()
88            }
89            fn combine(self, other: &Self) -> Self {
90                self + other
91            }
92        }
93    )*};
94}
95cost_impl_num!(num::BigInt, num::BigRational);
96use ordered_float::OrderedFloat;
97cost_impl_num!(f32, f64, OrderedFloat<f32>, OrderedFloat<f64>);
98
99pub type DefaultCost = u64;
100
101/// A cost model that computes the cost by summing the cost of each node.
102#[derive(Default, Clone)]
103pub struct TreeAdditiveCostModel {}
104
105impl CostModel<DefaultCost> for TreeAdditiveCostModel {
106    fn fold(
107        &self,
108        _head: &str,
109        children_cost: &[DefaultCost],
110        head_cost: DefaultCost,
111    ) -> DefaultCost {
112        children_cost.iter().fold(head_cost, |s, c| s.combine(c))
113    }
114
115    fn enode_cost(
116        &self,
117        _egraph: &EGraph,
118        func: &Function,
119        _row: &egglog_bridge::FunctionRow,
120    ) -> DefaultCost {
121        func.decl.cost.unwrap_or(DefaultCost::unit())
122    }
123}
124
125/// The default, Bellman-Ford like extractor. This extractor is optimal for [`CostModel`].
126pub struct Extractor<C: Cost + Ord + Eq + Clone + Debug> {
127    rootsorts: Vec<ArcSort>,
128    funcs: Vec<String>,
129    cost_model: Box<dyn CostModel<C>>,
130    costs: HashMap<String, HashMap<Value, C>>,
131    topo_rnk_cnt: usize,
132    topo_rnk: HashMap<String, HashMap<Value, usize>>,
133    parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>>,
134}
135
136impl<C: Cost + Ord + Eq + Clone + Debug> Extractor<C> {
137    /// Bulk of the computation happens at initialization time.
138    /// The later extractions only reuses saved results.
139    /// This means a new extractor must be created if the egraph changes.
140    /// Holding a reference to the egraph would enforce this but prevents the extractor being reused.
141    ///
142    /// For convenience, if the rootsorts is `None`, it defaults to extract all extractable rootsorts.
143    pub fn compute_costs_from_rootsorts(
144        rootsorts: Option<Vec<ArcSort>>,
145        egraph: &EGraph,
146        cost_model: impl CostModel<C> + 'static,
147    ) -> Self {
148        // We filter out tables unreachable from the root sorts
149        let extract_all_sorts = rootsorts.is_none();
150
151        let mut rootsorts = rootsorts.unwrap_or_default();
152
153        // Built a reverse index from output sort to function head symbols
154        let mut rev_index: HashMap<String, Vec<String>> = Default::default();
155        for func in egraph.functions.iter() {
156            if !func.1.decl.unextractable {
157                let func_name = func.0.clone();
158                let output_sort_name = func.1.schema.output.name();
159                if let Some(v) = rev_index.get_mut(output_sort_name) {
160                    v.push(func_name);
161                } else {
162                    rev_index.insert(output_sort_name.to_owned(), vec![func_name]);
163                    if extract_all_sorts {
164                        rootsorts.push(func.1.schema.output.clone());
165                    }
166                }
167            }
168        }
169
170        // Do a BFS to find reachable tables
171        let mut q: VecDeque<ArcSort> = VecDeque::new();
172        let mut seen: HashSet<String> = Default::default();
173        for rootsort in rootsorts.iter() {
174            q.push_back(rootsort.clone());
175            seen.insert(rootsort.name().to_owned());
176        }
177
178        let mut funcs_set: HashSet<String> = Default::default();
179        let mut funcs: Vec<String> = Vec::new();
180        while !q.is_empty() {
181            let sort = q.pop_front().unwrap();
182            if sort.is_container_sort() {
183                let inner_sorts = sort.inner_sorts();
184                for s in inner_sorts {
185                    if !seen.contains(s.name()) {
186                        q.push_back(s.clone());
187                        seen.insert(s.name().to_owned());
188                    }
189                }
190            } else if sort.is_eq_sort() {
191                if let Some(head_symbols) = rev_index.get(sort.name()) {
192                    for h in head_symbols {
193                        if !funcs_set.contains(h) {
194                            let func = egraph.functions.get(h).unwrap();
195                            for ch in &func.schema.input {
196                                let ch_name = ch.name();
197                                if !seen.contains(ch_name) {
198                                    q.push_back(ch.clone());
199                                    seen.insert(ch_name.to_owned());
200                                }
201                            }
202                            funcs_set.insert(h.clone());
203                            funcs.push(h.clone());
204                        }
205                    }
206                }
207            }
208        }
209
210        // Initialize the tables to have the reachable entries
211        let mut costs: HashMap<String, HashMap<Value, C>> = Default::default();
212        let mut topo_rnk: HashMap<String, HashMap<Value, usize>> = Default::default();
213        let mut parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>> =
214            Default::default();
215
216        for func_name in funcs.iter() {
217            let func = egraph.functions.get(func_name).unwrap();
218            if !costs.contains_key(func.schema.output.name()) {
219                debug_assert!(func.schema.output.is_eq_sort());
220                costs.insert(func.schema.output.name().to_owned(), Default::default());
221                topo_rnk.insert(func.schema.output.name().to_owned(), Default::default());
222                parent_edge.insert(func.schema.output.name().to_owned(), Default::default());
223            }
224        }
225
226        let mut extractor = Extractor {
227            rootsorts,
228            funcs,
229            cost_model: Box::new(cost_model),
230            costs,
231            topo_rnk_cnt: 0,
232            topo_rnk,
233            parent_edge,
234        };
235
236        extractor.bellman_ford(egraph);
237
238        extractor
239    }
240
241    /// Compute the cost of a single enode
242    /// Recurse if container
243    /// Returns None if contains an undefined eqsort term (potentially after unfolding)
244    fn compute_cost_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Option<C> {
245        if sort.is_container_sort() {
246            let elements = sort.inner_values(egraph.backend.container_values(), value);
247            let mut ch_costs: Vec<C> = Vec::new();
248            for ch in elements.iter() {
249                if let Some(c) = self.compute_cost_node(egraph, ch.1, &ch.0) {
250                    ch_costs.push(c);
251                } else {
252                    return None;
253                }
254            }
255            Some(
256                self.cost_model
257                    .container_cost(egraph, sort, value, &ch_costs),
258            )
259        } else if sort.is_eq_sort() {
260            if self
261                .costs
262                .get(sort.name())
263                .is_some_and(|t| t.get(&value).is_some())
264            {
265                Some(
266                    self.costs
267                        .get(sort.name())
268                        .unwrap()
269                        .get(&value)
270                        .unwrap()
271                        .clone(),
272                )
273            } else {
274                None
275            }
276        } else {
277            // Primitive
278            Some(self.cost_model.base_value_cost(egraph, sort, value))
279        }
280    }
281
282    /// A row in a constructor table is a hyperedge from the set of input terms to the constructed output term.
283    fn compute_cost_hyperedge(
284        &self,
285        egraph: &EGraph,
286        row: &egglog_bridge::FunctionRow,
287        func: &Function,
288    ) -> Option<C> {
289        let mut ch_costs: Vec<C> = Vec::new();
290        let sorts = &func.schema.input;
291        //log::debug!("compute_cost_hyperedge head {} sorts {:?}", head, sorts);
292        // Relying on .zip to truncate the values
293        for (value, sort) in row.vals.iter().zip(sorts.iter()) {
294            if let Some(c) = self.compute_cost_node(egraph, *value, sort) {
295                ch_costs.push(c);
296            } else {
297                return None;
298            }
299        }
300        Some(self.cost_model.fold(
301            &func.decl.name,
302            &ch_costs,
303            self.cost_model.enode_cost(egraph, func, row),
304        ))
305    }
306
307    fn compute_topo_rnk_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> usize {
308        if sort.is_container_sort() {
309            sort.inner_values(egraph.backend.container_values(), value)
310                .iter()
311                .fold(0, |ret, (sort, value)| {
312                    usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
313                })
314        } else if sort.is_eq_sort() {
315            if let Some(t) = self.topo_rnk.get(sort.name()) {
316                *t.get(&value).unwrap_or(&usize::MAX)
317            } else {
318                usize::MAX
319            }
320        } else {
321            0
322        }
323    }
324
325    fn compute_topo_rnk_hyperedge(
326        &self,
327        egraph: &EGraph,
328        row: &egglog_bridge::FunctionRow,
329        func: &Function,
330    ) -> usize {
331        let sorts = &func.schema.input;
332        row.vals
333            .iter()
334            .zip(sorts.iter())
335            .fold(0, |ret, (value, sort)| {
336                usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
337            })
338    }
339
340    /// We use Bellman-Ford to compute the costs of the relevant eq sorts' terms
341    /// [Bellman-Ford](https://en.wikipedia.org/wiki/Bellman%E2%80%93Ford_algorithm) is a shortest path algorithm.
342    /// The version implemented here computes the shortest path from any node in a set of sources to all the reachable nodes.
343    /// Computing the minimum cost for terms is treated as a shortest path problem on a hypergraph here.
344    /// In this hypergraph, the nodes corresponde to eclasses, the distances are the costs to extract a term of those eclasses,
345    /// and each enode is a hyperedge that goes from the set of children eclasses to the enode's eclass.
346    /// The sources are the eclasses with known costs from the cost model.
347    /// Additionally, to avoid cycles in the extraction even when the cost model can assign an equal cost to a term and its subterm.
348    /// It computes a topological rank for each eclass
349    /// and only allows each eclass to have children of classes of strictly smaller ranks in the extraction.
350    fn bellman_ford(&mut self, egraph: &EGraph) {
351        let mut ensure_fixpoint = false;
352
353        let funcs = self.funcs.clone();
354
355        while !ensure_fixpoint {
356            ensure_fixpoint = true;
357
358            for func_name in funcs.iter() {
359                let func = egraph.functions.get(func_name).unwrap();
360                let target_sort = func.schema.output.clone();
361
362                let relax_hyperedge = |row: egglog_bridge::FunctionRow| {
363                    log::debug!("Relaxing a new hyperedge: {:?}", row);
364                    if !row.subsumed {
365                        let target = row.vals.last().unwrap();
366                        let mut updated = false;
367                        if let Some(new_cost) = self.compute_cost_hyperedge(egraph, &row, func) {
368                            match self
369                                .costs
370                                .get_mut(target_sort.name())
371                                .unwrap()
372                                .entry(*target)
373                            {
374                                HEntry::Vacant(e) => {
375                                    updated = true;
376                                    e.insert(new_cost);
377                                }
378                                HEntry::Occupied(mut e) => {
379                                    if new_cost < *(e.get()) {
380                                        updated = true;
381                                        e.insert(new_cost);
382                                    }
383                                }
384                            }
385                        }
386                        // record the chronological order of the updates
387                        // which serves as a topological order that avoids cycles
388                        // even when a term has a cost equal to its subterms
389                        if updated {
390                            ensure_fixpoint = false;
391                            self.topo_rnk_cnt += 1;
392                            self.topo_rnk
393                                .get_mut(target_sort.name())
394                                .unwrap()
395                                .insert(*target, self.topo_rnk_cnt);
396                        }
397                    }
398                };
399
400                egraph.backend.for_each(func.backend_id, relax_hyperedge);
401            }
402        }
403
404        // Save the edges for reconstruction
405        for func_name in funcs.iter() {
406            let func = egraph.functions.get(func_name).unwrap();
407            let target_sort = func.schema.output.clone();
408
409            let save_best_parent_edge = |row: egglog_bridge::FunctionRow| {
410                if !row.subsumed {
411                    let target = row.vals.last().unwrap();
412                    if let Some(best_cost) = self.costs.get(target_sort.name()).unwrap().get(target)
413                    {
414                        if Some(best_cost.clone())
415                            == self.compute_cost_hyperedge(egraph, &row, func)
416                        {
417                            // one of the possible best parent edges
418                            let target_topo_rnk = *self
419                                .topo_rnk
420                                .get(target_sort.name())
421                                .unwrap()
422                                .get(target)
423                                .unwrap();
424                            if target_topo_rnk > self.compute_topo_rnk_hyperedge(egraph, &row, func)
425                            {
426                                // one of the parent edges that avoids cycles
427                                if let HEntry::Vacant(e) = self
428                                    .parent_edge
429                                    .get_mut(target_sort.name())
430                                    .unwrap()
431                                    .entry(*target)
432                                {
433                                    e.insert((func.decl.name.clone(), row.vals.to_vec()));
434                                }
435                            }
436                        }
437                    }
438                }
439            };
440
441            egraph
442                .backend
443                .for_each(func.backend_id, save_best_parent_edge);
444        }
445    }
446
447    /// This recursively reconstruct the termdag that gives the minimum cost for eclass value.
448    fn reconstruct_termdag_node(
449        &self,
450        egraph: &EGraph,
451        termdag: &mut TermDag,
452        value: Value,
453        sort: &ArcSort,
454    ) -> Term {
455        self.reconstruct_termdag_node_helper(egraph, termdag, value, sort, &mut Default::default())
456    }
457
458    fn reconstruct_termdag_node_helper(
459        &self,
460        egraph: &EGraph,
461        termdag: &mut TermDag,
462        value: Value,
463        sort: &ArcSort,
464        cache: &mut HashMap<(Value, String), Term>,
465    ) -> Term {
466        let key = (value, sort.name().to_owned());
467        if let Some(term) = cache.get(&key) {
468            return term.clone();
469        }
470
471        let term = if sort.is_container_sort() {
472            let elements = sort.inner_values(egraph.backend.container_values(), value);
473            let mut ch_terms: Vec<Term> = Vec::new();
474            for ch in elements.iter() {
475                ch_terms.push(
476                    self.reconstruct_termdag_node_helper(egraph, termdag, ch.1, &ch.0, cache),
477                );
478            }
479            sort.reconstruct_termdag_container(
480                egraph.backend.container_values(),
481                value,
482                termdag,
483                ch_terms,
484            )
485        } else if sort.is_eq_sort() {
486            let (func_name, hyperedge) = self
487                .parent_edge
488                .get(sort.name())
489                .unwrap()
490                .get(&value)
491                .unwrap();
492            let mut ch_terms: Vec<Term> = Vec::new();
493            let ch_sorts = &egraph.functions.get(func_name).unwrap().schema.input;
494            for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()) {
495                ch_terms.push(
496                    self.reconstruct_termdag_node_helper(egraph, termdag, *value, sort, cache),
497                );
498            }
499            termdag.app(func_name.clone(), ch_terms)
500        } else {
501            // Base value case
502            sort.reconstruct_termdag_base(egraph.backend.base_values(), value, termdag)
503        };
504
505        cache.insert(key, term.clone());
506        term
507    }
508
509    /// Extract the best term of a value from a given sort.
510    ///
511    /// This function expects the sort to be already computed,
512    /// which can be one of the rootsorts, or reachable from rootsorts, or primitives, or containers of computed sorts.
513    pub fn extract_best_with_sort(
514        &self,
515        egraph: &EGraph,
516        termdag: &mut TermDag,
517        value: Value,
518        sort: ArcSort,
519    ) -> Option<(C, Term)> {
520        match self.compute_cost_node(egraph, value, &sort) {
521            Some(best_cost) => {
522                log::debug!("Best cost for the extract root: {:?}", best_cost);
523
524                let term = self.reconstruct_termdag_node(egraph, termdag, value, &sort);
525
526                Some((best_cost, term))
527            }
528            None => {
529                log::error!("Unextractable root {:?} with sort {:?}", value, sort,);
530                None
531            }
532        }
533    }
534
535    /// A convenience method for extraction.
536    ///
537    /// This expects the value to be of the unique sort the extractor has been initialized with
538    pub fn extract_best(
539        &self,
540        egraph: &EGraph,
541        termdag: &mut TermDag,
542        value: Value,
543    ) -> Option<(C, Term)> {
544        assert!(
545            self.rootsorts.len() == 1,
546            "extract_best requires a single rootsort"
547        );
548        self.extract_best_with_sort(
549            egraph,
550            termdag,
551            value,
552            self.rootsorts.first().unwrap().clone(),
553        )
554    }
555
556    /// Extract variants of an e-class.
557    ///
558    /// The variants are selected by first picking `nvairants` e-nodes with the lowest cost from the e-class
559    /// and then extracting a term from each e-node.
560    pub fn extract_variants_with_sort(
561        &self,
562        egraph: &EGraph,
563        termdag: &mut TermDag,
564        value: Value,
565        nvariants: usize,
566        sort: ArcSort,
567    ) -> Vec<(C, Term)> {
568        debug_assert!(self.rootsorts.iter().any(|s| { s.name() == sort.name() }));
569
570        if sort.is_eq_sort() {
571            let mut root_variants: Vec<(C, String, Vec<Value>)> = Vec::new();
572
573            let mut root_funcs: Vec<String> = Vec::new();
574
575            for func_name in self.funcs.iter() {
576                // Need an eq on sorts
577                if sort.name()
578                    == egraph
579                        .functions
580                        .get(func_name)
581                        .unwrap()
582                        .schema
583                        .output
584                        .name()
585                {
586                    root_funcs.push(func_name.clone());
587                }
588            }
589
590            for func_name in root_funcs.iter() {
591                let func = egraph.functions.get(func_name).unwrap();
592
593                let find_root_variants = |row: egglog_bridge::FunctionRow| {
594                    if !row.subsumed {
595                        let target = row.vals.last().unwrap();
596                        if *target == value {
597                            let cost = self.compute_cost_hyperedge(egraph, &row, func).unwrap();
598                            root_variants.push((cost, func_name.clone(), row.vals.to_vec()));
599                        }
600                    }
601                };
602
603                egraph.backend.for_each(func.backend_id, find_root_variants);
604            }
605
606            let mut res: Vec<(C, Term)> = Vec::new();
607            root_variants.sort();
608            root_variants.truncate(nvariants);
609            for (cost, func_name, hyperedge) in root_variants {
610                let mut ch_terms: Vec<Term> = Vec::new();
611                let ch_sorts = &egraph.functions.get(&func_name).unwrap().schema.input;
612                // zip truncates the row
613                for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()) {
614                    ch_terms.push(self.reconstruct_termdag_node(egraph, termdag, *value, sort));
615                }
616                res.push((cost, termdag.app(func_name, ch_terms)));
617            }
618
619            res
620        } else {
621            log::warn!(
622                "extracting multiple variants for containers or primitives is not implemented, returning a single variant."
623            );
624            if let Some(res) = self.extract_best_with_sort(egraph, termdag, value, sort) {
625                vec![res]
626            } else {
627                vec![]
628            }
629        }
630    }
631
632    /// A convenience method for extracting variants of a value.
633    ///
634    /// This expects the value to be of the unique sort the extractor has been initialized with.
635    pub fn extract_variants(
636        &self,
637        egraph: &EGraph,
638        termdag: &mut TermDag,
639        value: Value,
640        nvariants: usize,
641    ) -> Vec<(C, Term)> {
642        assert!(
643            self.rootsorts.len() == 1,
644            "extract_variants requires a single rootsort"
645        );
646        self.extract_variants_with_sort(
647            egraph,
648            termdag,
649            value,
650            nvariants,
651            self.rootsorts.first().unwrap().clone(),
652        )
653    }
654}