egglog/
extract.rs

1use crate::ast::FunctionSubtype;
2use crate::termdag::{TermDag, TermId};
3use crate::util::{HashMap, HashSet};
4use crate::*;
5use std::collections::VecDeque;
6
7/// An interface for custom cost model.
8///
9/// To use it with the default extractor, the cost type must also satisfy `Ord + Eq + Clone + Debug`.
10/// Additionally, the cost model should guarantee that a term has a no-smaller cost
11/// than its subterms to avoid cycles in the extracted terms for common case usages.
12/// For more niche usages, a term can have a cost less than its subterms.
13/// As long as there is no negative cost cycle,
14/// the default extractor is guaranteed to terminate in computing the costs.
15/// However, the user needs to be careful to guarantee acyclicity in the extracted terms.
16pub trait CostModel<C: Cost> {
17    /// The total cost of a term given the cost of the root e-node and its immediate children's total costs.
18    fn fold(&self, head: &str, children_cost: &[C], head_cost: C) -> C;
19
20    /// The cost of an enode (without the cost of children)
21    fn enode_cost(&self, egraph: &EGraph, func: &Function, row: &egglog_bridge::FunctionRow) -> C;
22
23    /// The cost of a container value given the costs of its elements.
24    ///
25    /// The default cost for containers is just the sum of all the elements inside
26    fn container_cost(
27        &self,
28        egraph: &EGraph,
29        sort: &ArcSort,
30        value: Value,
31        element_costs: &[C],
32    ) -> C {
33        let _egraph = egraph;
34        let _sort = sort;
35        let _value = value;
36        element_costs
37            .iter()
38            .fold(C::identity(), |s, c| s.combine(c))
39    }
40
41    /// Compute the cost of a (non-container) primitive value.
42    ///
43    /// The default cost for base values is the constant one
44    fn base_value_cost(&self, egraph: &EGraph, sort: &ArcSort, value: Value) -> C {
45        let _egraph = egraph;
46        let _sort = sort;
47        let _value = value;
48        C::unit()
49    }
50}
51
/// The algebra a type must provide to be used as a cost by a [`CostModel`].
pub trait Cost {
    /// Returns the neutral element of [`Cost::combine`], usually zero.
    fn identity() -> Self;

    /// Returns the default cost of a node without children, usually one.
    fn unit() -> Self;

    /// Combines two costs, usually by addition.
    ///
    /// Implementations must NOT overflow or panic, even when given very large values.
    fn combine(self, other: &Self) -> Self;
}

/// Implements [`Cost`] for primitive integers: identity `0`, unit `1`, and saturating
/// addition so that combining costs can never overflow.
macro_rules! cost_impl_int {
    ($($int:ty),*) => {
        $(
            impl Cost for $int {
                fn identity() -> Self {
                    0
                }

                fn unit() -> Self {
                    1
                }

                fn combine(self, other: &Self) -> Self {
                    <$int>::saturating_add(self, *other)
                }
            }
        )*
    };
}
cost_impl_int!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize);
78
79macro_rules! cost_impl_num {
80    ($($cost:ty),*) => {$(
81        impl Cost for $cost {
82            fn identity() -> Self {
83                use num::Zero;
84                Self::zero()
85            }
86            fn unit() -> Self {
87                use num::One;
88                Self::one()
89            }
90            fn combine(self, other: &Self) -> Self {
91                self + other
92            }
93        }
94    )*};
95}
96cost_impl_num!(num::BigInt, num::BigRational);
97use ordered_float::OrderedFloat;
98cost_impl_num!(f32, f64, OrderedFloat<f32>, OrderedFloat<f64>);
99
100pub type DefaultCost = u64;
101
102/// A cost model that computes the cost by summing the cost of each node.
103#[derive(Default, Clone)]
104pub struct TreeAdditiveCostModel {}
105
106impl CostModel<DefaultCost> for TreeAdditiveCostModel {
107    fn fold(
108        &self,
109        _head: &str,
110        children_cost: &[DefaultCost],
111        head_cost: DefaultCost,
112    ) -> DefaultCost {
113        children_cost.iter().fold(head_cost, |s, c| s.combine(c))
114    }
115
116    fn enode_cost(
117        &self,
118        egraph: &EGraph,
119        func: &Function,
120        _row: &egglog_bridge::FunctionRow,
121    ) -> DefaultCost {
122        func.extraction_head_cost(egraph)
123    }
124}
125
126/// The default, Bellman-Ford like extractor. This extractor is optimal for [`CostModel`].
127///
128/// Note that this assumes optimal substructure in the cost model, that is, a lower-cost
129/// subterm should always lead to a non-worse superterm, to guarantee the extracted term
130/// being optimal under the given cost model.
131/// If this is not followed, the extractor may panic on reconstruction
132pub struct Extractor<C: Cost + Ord + Eq + Clone + Debug> {
133    rootsorts: Vec<ArcSort>,
134    funcs: Vec<String>,
135    cost_model: Box<dyn CostModel<C>>,
136    costs: HashMap<String, HashMap<Value, C>>,
137    topo_rnk_cnt: usize,
138    topo_rnk: HashMap<String, HashMap<Value, usize>>,
139    parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>>,
140}
141
142/// Options for configuring extraction behavior.
143struct ExtractionOptions<C: Cost> {
144    /// The cost model to use for extraction.
145    cost_model: Box<dyn CostModel<C>>,
146    /// Root sorts to extract from. If None, all extractable root sorts are used.
147    rootsorts: Option<Vec<ArcSort>>,
148    /// Whether to respect the unextractable flag on constructors.
149    /// When true, constructors marked as unextractable will not be used during extraction.
150    respect_unextractable: bool,
151    /// Whether to skip view tables (those with term_constructor annotations).
152    /// When true, view tables are skipped, which is useful for proof extraction
153    /// where we need to extract from the original term tables with their original names.
154    skip_view_tables: bool,
155    /// Whether to respect the hidden flag on constructors.
156    /// When true, constructors marked as hidden will not be used during extraction.
157    respect_hidden: bool,
158}
159
160impl<C: Cost + Ord + Eq + Clone + Debug> Extractor<C> {
161    /// Bulk of the computation happens at initialization time.
162    /// The later extractions only reuses saved results.
163    /// This means a new extractor must be created if the egraph changes.
164    /// Holding a reference to the egraph would enforce this but prevents the extractor being reused.
165    ///
166    /// For convenience, if the rootsorts is `None`, it defaults to extract all extractable rootsorts.
167    pub fn compute_costs_from_rootsorts(
168        rootsorts: Option<Vec<ArcSort>>,
169        egraph: &EGraph,
170        cost_model: impl CostModel<C> + 'static,
171    ) -> Self {
172        // For user extraction: respect unextractable and hidden, but use view tables (they have better names)
173        Self::compute_costs_from_rootsorts_internal(
174            egraph,
175            ExtractionOptions {
176                cost_model: Box::new(cost_model),
177                rootsorts,
178                respect_unextractable: true,
179                skip_view_tables: false,
180                respect_hidden: true,
181            },
182        )
183    }
184
185    /// Like `compute_costs_from_rootsorts`, but ignores the unextractable and hidden flags.
186    /// This is used for proof extraction where we need to extract proofs even
187    /// from terms that are marked unextractable (like global let bindings).
188    /// Also skips view tables (those with term_constructor) since proofs need
189    /// to extract from the original term tables with their original names.
190    pub(crate) fn compute_costs_from_rootsorts_allow_unextractable(
191        rootsorts: Option<Vec<ArcSort>>,
192        egraph: &EGraph,
193        cost_model: impl CostModel<C> + 'static,
194    ) -> Self {
195        Self::compute_costs_from_rootsorts_internal(
196            egraph,
197            ExtractionOptions {
198                cost_model: Box::new(cost_model),
199                rootsorts,
200                respect_unextractable: false,
201                skip_view_tables: true,
202                respect_hidden: false,
203            },
204        )
205    }
206
207    fn compute_costs_from_rootsorts_internal(
208        egraph: &EGraph,
209        options: ExtractionOptions<C>,
210    ) -> Self {
211        // We filter out tables unreachable from the root sorts
212        let extract_all_sorts = options.rootsorts.is_none();
213
214        let mut rootsorts = options.rootsorts.unwrap_or_default();
215
216        // Built a reverse index from output sort to function head symbols
217        // Only include constructors (not regular functions) and respect unextractable flag
218        let mut rev_index: HashMap<String, Vec<String>> = Default::default();
219        for func in egraph.functions.iter() {
220            let unextractable = func.1.decl.unextractable && options.respect_unextractable;
221            let should_skip_view =
222                options.skip_view_tables && func.1.decl.term_constructor.is_some();
223            let hidden = func.1.decl.internal_hidden && options.respect_hidden;
224
225            // only extract constructors (and functions with term_constructor), skip view tables when requested for proof extraction, and respect unextractable/hidden flag
226            if !unextractable
227                && !should_skip_view
228                && !hidden
229                && (func.1.decl.subtype == FunctionSubtype::Constructor
230                    || func.1.decl.term_constructor.is_some())
231            {
232                let func_name = func.0.clone();
233                // For view tables (with term_constructor in proof mode), the e-class is the last input column
234                let output_sort_name = func.1.extraction_output_sort().name();
235                if let Some(v) = rev_index.get_mut(output_sort_name) {
236                    v.push(func_name);
237                } else {
238                    rev_index.insert(output_sort_name.to_owned(), vec![func_name]);
239                    if extract_all_sorts {
240                        rootsorts.push(func.1.extraction_output_sort().clone());
241                    }
242                }
243            }
244        }
245
246        // Do a BFS to find reachable tables
247        let mut q: VecDeque<ArcSort> = VecDeque::new();
248        let mut seen: HashSet<String> = Default::default();
249        for rootsort in rootsorts.iter() {
250            q.push_back(rootsort.clone());
251            seen.insert(rootsort.name().to_owned());
252        }
253
254        let mut funcs_set: HashSet<String> = Default::default();
255        let mut funcs: Vec<String> = Vec::new();
256        while !q.is_empty() {
257            let sort = q.pop_front().unwrap();
258            if sort.is_container_sort() {
259                let inner_sorts = sort.inner_sorts();
260                for s in inner_sorts {
261                    if !seen.contains(s.name()) {
262                        q.push_back(s.clone());
263                        seen.insert(s.name().to_owned());
264                    }
265                }
266            } else if sort.is_eq_sort()
267                && let Some(head_symbols) = rev_index.get(sort.name())
268            {
269                for h in head_symbols {
270                    if !funcs_set.contains(h) {
271                        let func = egraph.functions.get(h).unwrap();
272                        // For view tables, children are all but the last input (which is the e-class)
273                        let num_children = func.extraction_num_children();
274                        for ch in func.schema.input.iter().take(num_children) {
275                            let ch_name = ch.name();
276                            if !seen.contains(ch_name) {
277                                q.push_back(ch.clone());
278                                seen.insert(ch_name.to_owned());
279                            }
280                        }
281                        funcs_set.insert(h.clone());
282                        funcs.push(h.clone());
283                    }
284                }
285            }
286        }
287
288        // Initialize the tables to have the reachable entries
289        let mut costs: HashMap<String, HashMap<Value, C>> = Default::default();
290        let mut topo_rnk: HashMap<String, HashMap<Value, usize>> = Default::default();
291        let mut parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>> =
292            Default::default();
293
294        for func_name in funcs.iter() {
295            let func = egraph.functions.get(func_name).unwrap();
296            let output_sort_name = func.extraction_output_sort().name();
297            if !costs.contains_key(output_sort_name) {
298                costs.insert(output_sort_name.to_owned(), Default::default());
299                topo_rnk.insert(output_sort_name.to_owned(), Default::default());
300                parent_edge.insert(output_sort_name.to_owned(), Default::default());
301            }
302        }
303
304        let mut extractor = Extractor {
305            rootsorts,
306            funcs,
307            cost_model: options.cost_model,
308            costs,
309            topo_rnk_cnt: 0,
310            topo_rnk,
311            parent_edge,
312        };
313
314        extractor.bellman_ford(egraph);
315
316        extractor
317    }
318
319    /// Compute the cost of a single enode
320    /// Recurse if container
321    /// Returns None if contains an undefined eqsort term (potentially after unfolding)
322    fn compute_cost_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Option<C> {
323        if sort.is_container_sort() {
324            let elements = sort.inner_values(egraph.backend.container_values(), value);
325            let mut ch_costs: Vec<C> = Vec::new();
326            for ch in elements.iter() {
327                ch_costs.push(self.compute_cost_node(egraph, ch.1, &ch.0)?);
328            }
329            Some(
330                self.cost_model
331                    .container_cost(egraph, sort, value, &ch_costs),
332            )
333        } else if sort.is_eq_sort() {
334            self.costs.get(sort.name())?.get(&value).cloned()
335        } else {
336            // Primitive
337            Some(self.cost_model.base_value_cost(egraph, sort, value))
338        }
339    }
340
341    /// A row in a constructor table is a hyperedge from the set of input terms to the constructed output term.
342    fn compute_cost_hyperedge(
343        &self,
344        egraph: &EGraph,
345        row: &egglog_bridge::FunctionRow,
346        func: &Function,
347    ) -> Option<C> {
348        let mut ch_costs: Vec<C> = Vec::new();
349        let sorts = &func.schema.input;
350        let num_children = func.extraction_num_children();
351        for (value, sort) in row.vals.iter().take(num_children).zip(sorts.iter()) {
352            ch_costs.push(self.compute_cost_node(egraph, *value, sort)?);
353        }
354        let head_name = func.extraction_term_name();
355        Some(self.cost_model.fold(
356            head_name,
357            &ch_costs,
358            self.cost_model.enode_cost(egraph, func, row),
359        ))
360    }
361
362    fn compute_topo_rnk_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> usize {
363        if sort.is_container_sort() {
364            sort.inner_values(egraph.backend.container_values(), value)
365                .iter()
366                .fold(0, |ret, (sort, value)| {
367                    usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
368                })
369        } else if sort.is_eq_sort() {
370            if let Some(t) = self.topo_rnk.get(sort.name()) {
371                *t.get(&value).unwrap_or(&usize::MAX)
372            } else {
373                usize::MAX
374            }
375        } else {
376            0
377        }
378    }
379
380    fn compute_topo_rnk_hyperedge(
381        &self,
382        egraph: &EGraph,
383        row: &egglog_bridge::FunctionRow,
384        func: &Function,
385    ) -> usize {
386        let sorts = &func.schema.input;
387        let num_children = func.extraction_num_children();
388        row.vals
389            .iter()
390            .take(num_children)
391            .zip(sorts.iter())
392            .fold(0, |ret, (value, sort)| {
393                usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
394            })
395    }
396
397    /// We use Bellman-Ford to compute the costs of the relevant eq sorts' terms
398    /// [Bellman-Ford](https://en.wikipedia.org/wiki/Bellman%E2%80%93Ford_algorithm) is a shortest path algorithm.
399    /// The version implemented here computes the shortest path from any node in a set of sources to all the reachable nodes.
400    /// Computing the minimum cost for terms is treated as a shortest path problem on a hypergraph here.
401    /// In this hypergraph, the nodes corresponde to eclasses, the distances are the costs to extract a term of those eclasses,
402    /// and each enode is a hyperedge that goes from the set of children eclasses to the enode's eclass.
403    /// The sources are the eclasses with known costs from the cost model.
404    /// Additionally, to avoid cycles in the extraction even when the cost model can assign an equal cost to a term and its subterm.
405    /// It computes a topological rank for each eclass
406    /// and only allows each eclass to have children of classes of strictly smaller ranks in the extraction.
407    fn bellman_ford(&mut self, egraph: &EGraph) {
408        let mut ensure_fixpoint = false;
409
410        let funcs = self.funcs.clone();
411
412        while !ensure_fixpoint {
413            ensure_fixpoint = true;
414
415            for func_name in funcs.iter() {
416                let func = egraph.functions.get(func_name).unwrap();
417                let target_sort = func.extraction_output_sort();
418
419                let output_idx = func.extraction_output_index();
420                let relax_hyperedge = |row: egglog_bridge::FunctionRow| {
421                    if !row.subsumed {
422                        let target = &row.vals[output_idx];
423                        let mut updated = false;
424                        if let Some(new_cost) = self.compute_cost_hyperedge(egraph, &row, func) {
425                            match self
426                                .costs
427                                .get_mut(target_sort.name())
428                                .unwrap()
429                                .entry(*target)
430                            {
431                                HEntry::Vacant(e) => {
432                                    updated = true;
433                                    e.insert(new_cost);
434                                }
435                                HEntry::Occupied(mut e) => {
436                                    if new_cost < *(e.get()) {
437                                        updated = true;
438                                        e.insert(new_cost);
439                                    }
440                                }
441                            }
442                        }
443                        // record the chronological order of the updates
444                        // which serves as a topological order that avoids cycles
445                        // even when a term has a cost equal to its subterms
446                        if updated {
447                            ensure_fixpoint = false;
448                            self.topo_rnk_cnt += 1;
449                            self.topo_rnk
450                                .get_mut(target_sort.name())
451                                .unwrap()
452                                .insert(*target, self.topo_rnk_cnt);
453                        }
454                    }
455                };
456
457                egraph.backend.for_each(func.backend_id, relax_hyperedge);
458            }
459        }
460
461        // Save the edges for reconstruction
462        for func_name in funcs.iter() {
463            let func = egraph.functions.get(func_name).unwrap();
464            let target_sort = func.extraction_output_sort();
465            let output_idx = func.extraction_output_index();
466
467            let save_best_parent_edge = |row: egglog_bridge::FunctionRow| {
468                if !row.subsumed {
469                    let target = &row.vals[output_idx];
470                    if let Some(best_cost) = self.costs.get(target_sort.name()).unwrap().get(target)
471                        && Some(best_cost.clone())
472                            == self.compute_cost_hyperedge(egraph, &row, func)
473                    {
474                        // one of the possible best parent edges
475                        let target_topo_rnk = *self
476                            .topo_rnk
477                            .get(target_sort.name())
478                            .unwrap()
479                            .get(target)
480                            .unwrap();
481                        if target_topo_rnk > self.compute_topo_rnk_hyperedge(egraph, &row, func) {
482                            // one of the parent edges that avoids cycles
483                            if let HEntry::Vacant(e) = self
484                                .parent_edge
485                                .get_mut(target_sort.name())
486                                .unwrap()
487                                .entry(*target)
488                            {
489                                e.insert((func.decl.name.clone(), row.vals.to_vec()));
490                            }
491                        }
492                    }
493                }
494            };
495
496            egraph
497                .backend
498                .for_each(func.backend_id, save_best_parent_edge);
499        }
500    }
501
502    /// This recursively reconstruct the termdag that gives the minimum cost for eclass value.
503    fn reconstruct_termdag_node(
504        &self,
505        egraph: &EGraph,
506        termdag: &mut TermDag,
507        value: Value,
508        sort: &ArcSort,
509    ) -> TermId {
510        self.reconstruct_termdag_node_helper(egraph, termdag, value, sort, &mut Default::default())
511    }
512
513    fn reconstruct_termdag_node_helper(
514        &self,
515        egraph: &EGraph,
516        termdag: &mut TermDag,
517        value: Value,
518        sort: &ArcSort,
519        cache: &mut HashMap<(Value, String), TermId>,
520    ) -> TermId {
521        let key = (value, sort.name().to_owned());
522        if let Some(term) = cache.get(&key) {
523            return *term;
524        }
525
526        let term = if sort.is_container_sort() {
527            let elements = sort.inner_values(egraph.backend.container_values(), value);
528            let mut ch_terms: Vec<TermId> = Vec::new();
529            for ch in elements.iter() {
530                ch_terms.push(
531                    self.reconstruct_termdag_node_helper(egraph, termdag, ch.1, &ch.0, cache),
532                );
533            }
534            sort.reconstruct_termdag_container(
535                egraph.backend.container_values(),
536                value,
537                termdag,
538                ch_terms,
539            )
540        } else if sort.is_eq_sort() {
541            let (func_name, hyperedge) = self
542                .parent_edge
543                .get(sort.name())
544                .unwrap()
545                .get(&value)
546                .unwrap();
547            let func = egraph.functions.get(func_name).unwrap();
548            let ch_sorts = &func.schema.input;
549
550            let num_children = func.extraction_num_children();
551            let output_name = func.extraction_term_name();
552
553            let mut ch_terms: Vec<TermId> = Vec::new();
554            for (value, sort) in hyperedge.iter().take(num_children).zip(ch_sorts.iter()) {
555                ch_terms.push(
556                    self.reconstruct_termdag_node_helper(egraph, termdag, *value, sort, cache),
557                );
558            }
559            termdag.app(output_name.to_string(), ch_terms)
560        } else {
561            // Base value case
562            sort.reconstruct_termdag_base(egraph.backend.base_values(), value, termdag)
563        };
564
565        cache.insert(key, term);
566        term
567    }
568
569    /// Extract the best term of a value from a given sort.
570    ///
571    /// This function expects the sort to be already computed,
572    /// which can be one of the rootsorts, or reachable from rootsorts, or primitives, or containers of computed sorts.
573    pub fn extract_best_with_sort(
574        &self,
575        egraph: &EGraph,
576        termdag: &mut TermDag,
577        value: Value,
578        sort: ArcSort,
579    ) -> Option<(C, TermId)> {
580        // Canonicalize the value using the union-find if available (for term-encoding mode)
581        let canonical_value = self.find_canonical(egraph, value, &sort);
582
583        match self.compute_cost_node(egraph, canonical_value, &sort) {
584            Some(best_cost) => {
585                log::debug!("Best cost for the extract root: {best_cost:?}");
586
587                let term = self.reconstruct_termdag_node(egraph, termdag, canonical_value, &sort);
588
589                Some((best_cost, term))
590            }
591            None => {
592                log::error!("Unextractable root {value:?} with sort {sort:?}",);
593                None
594            }
595        }
596    }
597
598    /// A convenience method for extraction.
599    ///
600    /// This expects the value to be of the unique sort the extractor has been initialized with
601    pub fn extract_best(
602        &self,
603        egraph: &EGraph,
604        termdag: &mut TermDag,
605        value: Value,
606    ) -> Option<(C, TermId)> {
607        assert!(
608            self.rootsorts.len() == 1,
609            "extract_best requires a single rootsort"
610        );
611        self.extract_best_with_sort(
612            egraph,
613            termdag,
614            value,
615            self.rootsorts.first().unwrap().clone(),
616        )
617    }
618
619    /// Find the canonical representative of a value using the union-find table.
620    /// If no UF is registered for this sort, returns the original value.
621    /// The UF table stores (value, canonical) pairs - one hop lookup.
622    fn find_canonical(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Value {
623        // Check if there's a UF registered for this sort
624        let Some(uf_name) = egraph.proof_state.uf_parent.get(sort.name()) else {
625            return value;
626        };
627
628        // Get the UF function
629        let Some(uf_func) = egraph.functions.get(uf_name) else {
630            return value;
631        };
632
633        // Single lookup in UF table - it's guaranteed to be one hop to canonical
634        let mut canonical = value;
635        egraph
636            .backend
637            .for_each(uf_func.backend_id, |row: egglog_bridge::FunctionRow| {
638                // UF table has (child, parent) as inputs
639                if row.vals[0] == value {
640                    canonical = row.vals[1];
641                }
642            });
643
644        canonical
645    }
646
647    /// Extract variants of an e-class.
648    ///
649    /// The variants are selected by first picking `nvairants` e-nodes with the lowest cost from the e-class
650    /// and then extracting a term from each e-node.
651    pub fn extract_variants_with_sort(
652        &self,
653        egraph: &EGraph,
654        termdag: &mut TermDag,
655        value: Value,
656        nvariants: usize,
657        sort: ArcSort,
658    ) -> Vec<(C, TermId)> {
659        debug_assert!(self.rootsorts.iter().any(|s| { s.name() == sort.name() }));
660
661        if sort.is_eq_sort() {
662            // Canonicalize the value using the union-find if available
663            let canonical_value = self.find_canonical(egraph, value, &sort);
664
665            let mut root_variants: Vec<(C, String, Vec<Value>)> = Vec::new();
666
667            let mut root_funcs: Vec<String> = Vec::new();
668
669            for func_name in self.funcs.iter() {
670                // Need an eq on sorts - use extraction_output_sort for view table support
671                if sort.name()
672                    == egraph
673                        .functions
674                        .get(func_name)
675                        .unwrap()
676                        .extraction_output_sort()
677                        .name()
678                {
679                    root_funcs.push(func_name.clone());
680                }
681            }
682
683            for func_name in root_funcs.iter() {
684                let func = egraph.functions.get(func_name).unwrap();
685                let output_idx = func.extraction_output_index();
686
687                let find_root_variants = |row: egglog_bridge::FunctionRow| {
688                    if !row.subsumed {
689                        let target = &row.vals[output_idx];
690                        if *target == canonical_value {
691                            let cost = self.compute_cost_hyperedge(egraph, &row, func).unwrap();
692                            root_variants.push((cost, func_name.clone(), row.vals.to_vec()));
693                        }
694                    }
695                };
696
697                egraph.backend.for_each(func.backend_id, find_root_variants);
698            }
699
700            let mut res: Vec<(C, TermId)> = Vec::new();
701            let mut cache: HashMap<(Value, String), TermId> = Default::default();
702            root_variants.sort();
703            root_variants.truncate(nvariants);
704            for (cost, func_name, hyperedge) in root_variants {
705                let mut ch_terms: Vec<TermId> = Vec::new();
706                let func = egraph.functions.get(&func_name).unwrap();
707                let ch_sorts = &func.schema.input;
708                let num_children = func.extraction_num_children();
709                // For view tables, children are all but the last input (which is the e-class)
710                for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()).take(num_children) {
711                    ch_terms.push(self.reconstruct_termdag_node_helper(
712                        egraph, termdag, *value, sort, &mut cache,
713                    ));
714                }
715                // Use extraction_term_name for view tables (maps to the original constructor)
716                res.push((
717                    cost,
718                    termdag.app(func.extraction_term_name().to_string(), ch_terms),
719                ));
720            }
721
722            res
723        } else {
724            log::warn!(
725                "extracting multiple variants for containers or primitives is not implemented, returning a single variant."
726            );
727            if let Some(res) = self.extract_best_with_sort(egraph, termdag, value, sort) {
728                vec![res]
729            } else {
730                vec![]
731            }
732        }
733    }
734
735    /// A convenience method for extracting variants of a value.
736    ///
737    /// This expects the value to be of the unique sort the extractor has been initialized with.
738    pub fn extract_variants(
739        &self,
740        egraph: &EGraph,
741        termdag: &mut TermDag,
742        value: Value,
743        nvariants: usize,
744    ) -> Vec<(C, TermId)> {
745        assert!(
746            self.rootsorts.len() == 1,
747            "extract_variants requires a single rootsort"
748        );
749        self.extract_variants_with_sort(
750            egraph,
751            termdag,
752            value,
753            nvariants,
754            self.rootsorts.first().unwrap().clone(),
755        )
756    }
757}
758
759impl Function {
760    /// Returns the extraction head cost for this table.
761    /// View tables inherit the cost of their referenced hidden term constructor.
762    pub(crate) fn extraction_head_cost(&self, egraph: &EGraph) -> DefaultCost {
763        if let Some(term_constructor) = &self.decl.term_constructor {
764            egraph
765                .functions
766                .get(term_constructor)
767                .and_then(|func| func.decl.cost)
768                .unwrap_or(DefaultCost::unit())
769        } else {
770            self.decl.cost.unwrap_or(DefaultCost::unit())
771        }
772    }
773
774    /// For view tables (with term_constructor), the effective output sort is the last input column.
775    /// For regular tables, it's the output sort.
776    /// This is used by extraction to determine which sort a table produces values for.
777    pub(crate) fn extraction_output_sort(&self) -> &ArcSort {
778        if self.decl.term_constructor.is_some() {
779            self.schema.input.last().unwrap()
780        } else {
781            &self.schema.output
782        }
783    }
784
785    /// Returns the number of children for extraction purposes.
786    /// For view tables, this excludes the last column (the e-class).
787    pub(crate) fn extraction_num_children(&self) -> usize {
788        if self.decl.term_constructor.is_some() {
789            self.schema.input.len() - 1
790        } else {
791            self.schema.input.len()
792        }
793    }
794
795    /// Returns the name to use when building terms during extraction.
796    /// For view tables, this is the term_constructor name.
797    pub(crate) fn extraction_term_name(&self) -> &str {
798        self.decl
799            .term_constructor
800            .as_ref()
801            .unwrap_or(&self.decl.name)
802    }
803
804    /// Returns the index of the output value in a row for extraction purposes.
805    /// For view tables, the e-class is the last input column (second-to-last in the row).
806    /// For regular tables, it's the last column (the actual output).
807    pub(crate) fn extraction_output_index(&self) -> usize {
808        if self.decl.term_constructor.is_some() {
809            // For view tables: input is [children..., eclass], output is view_sort
810            // Row is [children..., eclass, view_sort]
811            // We want eclass which is at index input.len() - 1
812            self.schema.input.len() - 1
813        } else {
814            // For regular tables: row is [inputs..., output]
815            self.schema.input.len()
816        }
817    }
818}
819
820impl EGraph {
821    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`] using the default cost model.
822    /// See also [`EGraph::extract_value_with_cost_model`] for more control.
823    pub fn extract_value(
824        &self,
825        sort: &ArcSort,
826        value: Value,
827    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
828        self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
829    }
830
831    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`].
832    /// Note that the `TermDag` may contain a superset of the nodes referenced by the returned `TermId`.
833    /// See also [`EGraph::extract_value_to_string`] for convenience.
834    pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
835        &self,
836        sort: &ArcSort,
837        value: Value,
838        cost_model: CM,
839    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
840        let extractor =
841            Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
842        let mut termdag = TermDag::default();
843        let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
844        Ok((termdag, term, cost))
845    }
846
847    /// Extract a value to a string for printing.
848    /// See also [`EGraph::extract_value`] for more control.
849    pub fn extract_value_to_string(
850        &self,
851        sort: &ArcSort,
852        value: Value,
853    ) -> Result<(String, DefaultCost), Error> {
854        let (termdag, term, cost) = self.extract_value(sort, value)?;
855        Ok((termdag.to_string(term), cost))
856    }
857
858    /// For constructors and relations, the output column can be ignored
859    pub fn function_to_dag(
860        &self,
861        sym: &str,
862        n: usize,
863        include_output: bool,
864    ) -> Result<(Vec<TermId>, Option<Vec<TermId>>, TermDag), Error> {
865        let func = self
866            .functions
867            .get(sym)
868            .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
869        let mut rootsorts = func.schema.input.clone();
870        if include_output {
871            rootsorts.push(func.schema.output.clone());
872        }
873        let extractor = Extractor::compute_costs_from_rootsorts(
874            Some(rootsorts),
875            self,
876            TreeAdditiveCostModel::default(),
877        );
878
879        let mut termdag = TermDag::default();
880        let mut inputs: Vec<TermId> = Vec::new();
881        let mut output: Option<Vec<TermId>> = if include_output {
882            Some(Vec::new())
883        } else {
884            None
885        };
886
887        let extract_row = |row: egglog_bridge::FunctionRow| {
888            if inputs.len() < n {
889                // include subsumed rows
890                let mut children: Vec<TermId> = Vec::new();
891                for (value, sort) in row.vals.iter().zip(&func.schema.input) {
892                    let (_, term_id) = extractor
893                        .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
894                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
895                    children.push(term_id);
896                }
897                inputs.push(termdag.app(sym.to_owned(), children));
898                if include_output {
899                    let value = row.vals[func.schema.input.len()];
900                    let sort = &func.schema.output;
901                    let (_, term) = extractor
902                        .extract_best_with_sort(self, &mut termdag, value, sort.clone())
903                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
904                    output.as_mut().unwrap().push(term);
905                }
906                true
907            } else {
908                false
909            }
910        };
911
912        self.backend.for_each_while(func.backend_id, extract_row);
913
914        Ok((inputs, output, termdag))
915    }
916}