egglog/
extract.rs

1use crate::ast::FunctionSubtype;
2use crate::termdag::{TermDag, TermId};
3use crate::util::{HashMap, HashSet};
4use crate::*;
5use std::collections::VecDeque;
6
7/// An interface for custom cost model.
8///
9/// To use it with the default extractor, the cost type must also satisfy `Ord + Eq + Clone + Debug`.
10/// Additionally, the cost model should guarantee that a term has a no-smaller cost
11/// than its subterms to avoid cycles in the extracted terms for common case usages.
12/// For more niche usages, a term can have a cost less than its subterms.
13/// As long as there is no negative cost cycle,
14/// the default extractor is guaranteed to terminate in computing the costs.
15/// However, the user needs to be careful to guarantee acyclicity in the extracted terms.
16pub trait CostModel<C: Cost> {
17    /// The total cost of a term given the cost of the root e-node and its immediate children's total costs.
18    fn fold(&self, head: &str, children_cost: &[C], head_cost: C) -> C;
19
20    /// The cost of an enode (without the cost of children)
21    fn enode_cost(&self, egraph: &EGraph, func: &Function, row: &egglog_bridge::FunctionRow) -> C;
22
23    /// The cost of a container value given the costs of its elements.
24    ///
25    /// The default cost for containers is just the sum of all the elements inside
26    fn container_cost(
27        &self,
28        egraph: &EGraph,
29        sort: &ArcSort,
30        value: Value,
31        element_costs: &[C],
32    ) -> C {
33        let _egraph = egraph;
34        let _sort = sort;
35        let _value = value;
36        element_costs
37            .iter()
38            .fold(C::identity(), |s, c| s.combine(c))
39    }
40
41    /// Compute the cost of a (non-container) primitive value.
42    ///
43    /// The default cost for base values is the constant one
44    fn base_value_cost(&self, egraph: &EGraph, sort: &ArcSort, value: Value) -> C {
45        let _egraph = egraph;
46        let _sort = sort;
47        let _value = value;
48        C::unit()
49    }
50}
51
/// The operations a type must provide to serve as the cost type of a [`CostModel`].
pub trait Cost {
    /// Combines two costs, typically by addition.
    ///
    /// Implementations must not overflow or panic, even for very large values.
    fn combine(self, other: &Self) -> Self;

    /// The neutral element of [`Cost::combine`], typically zero.
    fn identity() -> Self;

    /// The default cost of a node without children, typically one.
    fn unit() -> Self;
}
64
65macro_rules! cost_impl_int {
66    ($($cost:ty),*) => {$(
67        impl Cost for $cost {
68            fn identity() -> Self { 0 }
69            fn unit()     -> Self { 1 }
70            fn combine(self, other: &Self) -> Self {
71                self.saturating_add(*other)
72            }
73        }
74    )*};
75}
76cost_impl_int!(u8, u16, u32, u64, u128, usize);
77cost_impl_int!(i8, i16, i32, i64, i128, isize);
78
79macro_rules! cost_impl_num {
80    ($($cost:ty),*) => {$(
81        impl Cost for $cost {
82            fn identity() -> Self {
83                use num::Zero;
84                Self::zero()
85            }
86            fn unit() -> Self {
87                use num::One;
88                Self::one()
89            }
90            fn combine(self, other: &Self) -> Self {
91                self + other
92            }
93        }
94    )*};
95}
96cost_impl_num!(num::BigInt, num::BigRational);
97use ordered_float::OrderedFloat;
98cost_impl_num!(f32, f64, OrderedFloat<f32>, OrderedFloat<f64>);
99
100pub type DefaultCost = u64;
101
102/// A cost model that computes the cost by summing the cost of each node.
103#[derive(Default, Clone)]
104pub struct TreeAdditiveCostModel {}
105
106impl CostModel<DefaultCost> for TreeAdditiveCostModel {
107    fn fold(
108        &self,
109        _head: &str,
110        children_cost: &[DefaultCost],
111        head_cost: DefaultCost,
112    ) -> DefaultCost {
113        children_cost.iter().fold(head_cost, |s, c| s.combine(c))
114    }
115
116    fn enode_cost(
117        &self,
118        _egraph: &EGraph,
119        func: &Function,
120        _row: &egglog_bridge::FunctionRow,
121    ) -> DefaultCost {
122        func.decl.cost.unwrap_or(DefaultCost::unit())
123    }
124}
125
126/// The default, Bellman-Ford like extractor. This extractor is optimal for [`CostModel`].
127///
128/// Note that this assumes optimal substructure in the cost model, that is, a lower-cost
129/// subterm should always lead to a non-worse superterm, to guarantee the extracted term
130/// being optimal under the given cost model.
131/// If this is not followed, the extractor may panic on reconstruction
132pub struct Extractor<C: Cost + Ord + Eq + Clone + Debug> {
133    rootsorts: Vec<ArcSort>,
134    funcs: Vec<String>,
135    cost_model: Box<dyn CostModel<C>>,
136    costs: HashMap<String, HashMap<Value, C>>,
137    topo_rnk_cnt: usize,
138    topo_rnk: HashMap<String, HashMap<Value, usize>>,
139    parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>>,
140}
141
142/// Options for configuring extraction behavior.
143struct ExtractionOptions<C: Cost> {
144    /// The cost model to use for extraction.
145    cost_model: Box<dyn CostModel<C>>,
146    /// Root sorts to extract from. If None, all extractable root sorts are used.
147    rootsorts: Option<Vec<ArcSort>>,
148    /// Whether to respect the unextractable flag on constructors.
149    /// When true, constructors marked as unextractable will not be used during extraction.
150    respect_unextractable: bool,
151    /// Whether to skip view tables (those with term_constructor annotations).
152    /// When true, view tables are skipped, which is useful for proof extraction
153    /// where we need to extract from the original term tables with their original names.
154    skip_view_tables: bool,
155    /// Whether to respect the hidden flag on constructors.
156    /// When true, constructors marked as hidden will not be used during extraction.
157    respect_hidden: bool,
158}
159
160impl<C: Cost + Ord + Eq + Clone + Debug> Extractor<C> {
161    /// Bulk of the computation happens at initialization time.
162    /// The later extractions only reuses saved results.
163    /// This means a new extractor must be created if the egraph changes.
164    /// Holding a reference to the egraph would enforce this but prevents the extractor being reused.
165    ///
166    /// For convenience, if the rootsorts is `None`, it defaults to extract all extractable rootsorts.
167    pub fn compute_costs_from_rootsorts(
168        rootsorts: Option<Vec<ArcSort>>,
169        egraph: &EGraph,
170        cost_model: impl CostModel<C> + 'static,
171    ) -> Self {
172        // For user extraction: respect unextractable and hidden, but use view tables (they have better names)
173        Self::compute_costs_from_rootsorts_internal(
174            egraph,
175            ExtractionOptions {
176                cost_model: Box::new(cost_model),
177                rootsorts,
178                respect_unextractable: true,
179                skip_view_tables: false,
180                respect_hidden: true,
181            },
182        )
183    }
184
185    /// Like `compute_costs_from_rootsorts`, but ignores the unextractable and hidden flags.
186    /// This is used for proof extraction where we need to extract proofs even
187    /// from terms that are marked unextractable (like global let bindings).
188    /// Also skips view tables (those with term_constructor) since proofs need
189    /// to extract from the original term tables with their original names.
190    pub(crate) fn compute_costs_from_rootsorts_allow_unextractable(
191        rootsorts: Option<Vec<ArcSort>>,
192        egraph: &EGraph,
193        cost_model: impl CostModel<C> + 'static,
194    ) -> Self {
195        Self::compute_costs_from_rootsorts_internal(
196            egraph,
197            ExtractionOptions {
198                cost_model: Box::new(cost_model),
199                rootsorts,
200                respect_unextractable: false,
201                skip_view_tables: true,
202                respect_hidden: false,
203            },
204        )
205    }
206
207    fn compute_costs_from_rootsorts_internal(
208        egraph: &EGraph,
209        options: ExtractionOptions<C>,
210    ) -> Self {
211        // We filter out tables unreachable from the root sorts
212        let extract_all_sorts = options.rootsorts.is_none();
213
214        let mut rootsorts = options.rootsorts.unwrap_or_default();
215
216        // Built a reverse index from output sort to function head symbols
217        // Only include constructors (not regular functions) and respect unextractable flag
218        let mut rev_index: HashMap<String, Vec<String>> = Default::default();
219        for func in egraph.functions.iter() {
220            let unextractable = func.1.decl.unextractable && options.respect_unextractable;
221            let should_skip_view =
222                options.skip_view_tables && func.1.decl.term_constructor.is_some();
223            let hidden = func.1.decl.internal_hidden && options.respect_hidden;
224
225            // only extract constructors, skip view tables when requested for proof extraction, and respect unextractable/hidden flag
226            if !unextractable
227                && !should_skip_view
228                && !hidden
229                && func.1.decl.subtype == FunctionSubtype::Constructor
230            {
231                let func_name = func.0.clone();
232                // For view tables (with term_constructor in proof mode), the e-class is the last input column
233                let output_sort_name = func.1.extraction_output_sort().name();
234                if let Some(v) = rev_index.get_mut(output_sort_name) {
235                    v.push(func_name);
236                } else {
237                    rev_index.insert(output_sort_name.to_owned(), vec![func_name]);
238                    if extract_all_sorts {
239                        rootsorts.push(func.1.extraction_output_sort().clone());
240                    }
241                }
242            }
243        }
244
245        // Do a BFS to find reachable tables
246        let mut q: VecDeque<ArcSort> = VecDeque::new();
247        let mut seen: HashSet<String> = Default::default();
248        for rootsort in rootsorts.iter() {
249            q.push_back(rootsort.clone());
250            seen.insert(rootsort.name().to_owned());
251        }
252
253        let mut funcs_set: HashSet<String> = Default::default();
254        let mut funcs: Vec<String> = Vec::new();
255        while !q.is_empty() {
256            let sort = q.pop_front().unwrap();
257            if sort.is_container_sort() {
258                let inner_sorts = sort.inner_sorts();
259                for s in inner_sorts {
260                    if !seen.contains(s.name()) {
261                        q.push_back(s.clone());
262                        seen.insert(s.name().to_owned());
263                    }
264                }
265            } else if sort.is_eq_sort() {
266                if let Some(head_symbols) = rev_index.get(sort.name()) {
267                    for h in head_symbols {
268                        if !funcs_set.contains(h) {
269                            let func = egraph.functions.get(h).unwrap();
270                            // For view tables, children are all but the last input (which is the e-class)
271                            let num_children = func.extraction_num_children();
272                            for ch in func.schema.input.iter().take(num_children) {
273                                let ch_name = ch.name();
274                                if !seen.contains(ch_name) {
275                                    q.push_back(ch.clone());
276                                    seen.insert(ch_name.to_owned());
277                                }
278                            }
279                            funcs_set.insert(h.clone());
280                            funcs.push(h.clone());
281                        }
282                    }
283                }
284            }
285        }
286
287        // Initialize the tables to have the reachable entries
288        let mut costs: HashMap<String, HashMap<Value, C>> = Default::default();
289        let mut topo_rnk: HashMap<String, HashMap<Value, usize>> = Default::default();
290        let mut parent_edge: HashMap<String, HashMap<Value, (String, Vec<Value>)>> =
291            Default::default();
292
293        for func_name in funcs.iter() {
294            let func = egraph.functions.get(func_name).unwrap();
295            let output_sort_name = func.extraction_output_sort().name();
296            if !costs.contains_key(output_sort_name) {
297                costs.insert(output_sort_name.to_owned(), Default::default());
298                topo_rnk.insert(output_sort_name.to_owned(), Default::default());
299                parent_edge.insert(output_sort_name.to_owned(), Default::default());
300            }
301        }
302
303        let mut extractor = Extractor {
304            rootsorts,
305            funcs,
306            cost_model: options.cost_model,
307            costs,
308            topo_rnk_cnt: 0,
309            topo_rnk,
310            parent_edge,
311        };
312
313        extractor.bellman_ford(egraph);
314
315        extractor
316    }
317
318    /// Compute the cost of a single enode
319    /// Recurse if container
320    /// Returns None if contains an undefined eqsort term (potentially after unfolding)
321    fn compute_cost_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Option<C> {
322        if sort.is_container_sort() {
323            let elements = sort.inner_values(egraph.backend.container_values(), value);
324            let mut ch_costs: Vec<C> = Vec::new();
325            for ch in elements.iter() {
326                ch_costs.push(self.compute_cost_node(egraph, ch.1, &ch.0)?);
327            }
328            Some(
329                self.cost_model
330                    .container_cost(egraph, sort, value, &ch_costs),
331            )
332        } else if sort.is_eq_sort() {
333            self.costs.get(sort.name())?.get(&value).cloned()
334        } else {
335            // Primitive
336            Some(self.cost_model.base_value_cost(egraph, sort, value))
337        }
338    }
339
340    /// A row in a constructor table is a hyperedge from the set of input terms to the constructed output term.
341    fn compute_cost_hyperedge(
342        &self,
343        egraph: &EGraph,
344        row: &egglog_bridge::FunctionRow,
345        func: &Function,
346    ) -> Option<C> {
347        let mut ch_costs: Vec<C> = Vec::new();
348        let sorts = &func.schema.input;
349        let num_children = func.extraction_num_children();
350        for (value, sort) in row.vals.iter().take(num_children).zip(sorts.iter()) {
351            ch_costs.push(self.compute_cost_node(egraph, *value, sort)?);
352        }
353        let head_name = func.extraction_term_name();
354        Some(self.cost_model.fold(
355            head_name,
356            &ch_costs,
357            self.cost_model.enode_cost(egraph, func, row),
358        ))
359    }
360
361    fn compute_topo_rnk_node(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> usize {
362        if sort.is_container_sort() {
363            sort.inner_values(egraph.backend.container_values(), value)
364                .iter()
365                .fold(0, |ret, (sort, value)| {
366                    usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
367                })
368        } else if sort.is_eq_sort() {
369            if let Some(t) = self.topo_rnk.get(sort.name()) {
370                *t.get(&value).unwrap_or(&usize::MAX)
371            } else {
372                usize::MAX
373            }
374        } else {
375            0
376        }
377    }
378
379    fn compute_topo_rnk_hyperedge(
380        &self,
381        egraph: &EGraph,
382        row: &egglog_bridge::FunctionRow,
383        func: &Function,
384    ) -> usize {
385        let sorts = &func.schema.input;
386        let num_children = func.extraction_num_children();
387        row.vals
388            .iter()
389            .take(num_children)
390            .zip(sorts.iter())
391            .fold(0, |ret, (value, sort)| {
392                usize::max(ret, self.compute_topo_rnk_node(egraph, *value, sort))
393            })
394    }
395
396    /// We use Bellman-Ford to compute the costs of the relevant eq sorts' terms
397    /// [Bellman-Ford](https://en.wikipedia.org/wiki/Bellman%E2%80%93Ford_algorithm) is a shortest path algorithm.
398    /// The version implemented here computes the shortest path from any node in a set of sources to all the reachable nodes.
399    /// Computing the minimum cost for terms is treated as a shortest path problem on a hypergraph here.
400    /// In this hypergraph, the nodes corresponde to eclasses, the distances are the costs to extract a term of those eclasses,
401    /// and each enode is a hyperedge that goes from the set of children eclasses to the enode's eclass.
402    /// The sources are the eclasses with known costs from the cost model.
403    /// Additionally, to avoid cycles in the extraction even when the cost model can assign an equal cost to a term and its subterm.
404    /// It computes a topological rank for each eclass
405    /// and only allows each eclass to have children of classes of strictly smaller ranks in the extraction.
406    fn bellman_ford(&mut self, egraph: &EGraph) {
407        let mut ensure_fixpoint = false;
408
409        let funcs = self.funcs.clone();
410
411        while !ensure_fixpoint {
412            ensure_fixpoint = true;
413
414            for func_name in funcs.iter() {
415                let func = egraph.functions.get(func_name).unwrap();
416                let target_sort = func.extraction_output_sort();
417
418                let output_idx = func.extraction_output_index();
419                let relax_hyperedge = |row: egglog_bridge::FunctionRow| {
420                    if !row.subsumed {
421                        let target = &row.vals[output_idx];
422                        let mut updated = false;
423                        if let Some(new_cost) = self.compute_cost_hyperedge(egraph, &row, func) {
424                            match self
425                                .costs
426                                .get_mut(target_sort.name())
427                                .unwrap()
428                                .entry(*target)
429                            {
430                                HEntry::Vacant(e) => {
431                                    updated = true;
432                                    e.insert(new_cost);
433                                }
434                                HEntry::Occupied(mut e) => {
435                                    if new_cost < *(e.get()) {
436                                        updated = true;
437                                        e.insert(new_cost);
438                                    }
439                                }
440                            }
441                        }
442                        // record the chronological order of the updates
443                        // which serves as a topological order that avoids cycles
444                        // even when a term has a cost equal to its subterms
445                        if updated {
446                            ensure_fixpoint = false;
447                            self.topo_rnk_cnt += 1;
448                            self.topo_rnk
449                                .get_mut(target_sort.name())
450                                .unwrap()
451                                .insert(*target, self.topo_rnk_cnt);
452                        }
453                    }
454                };
455
456                egraph.backend.for_each(func.backend_id, relax_hyperedge);
457            }
458        }
459
460        // Save the edges for reconstruction
461        for func_name in funcs.iter() {
462            let func = egraph.functions.get(func_name).unwrap();
463            let target_sort = func.extraction_output_sort();
464            let output_idx = func.extraction_output_index();
465
466            let save_best_parent_edge = |row: egglog_bridge::FunctionRow| {
467                if !row.subsumed {
468                    let target = &row.vals[output_idx];
469                    if let Some(best_cost) = self.costs.get(target_sort.name()).unwrap().get(target)
470                    {
471                        if Some(best_cost.clone())
472                            == self.compute_cost_hyperedge(egraph, &row, func)
473                        {
474                            // one of the possible best parent edges
475                            let target_topo_rnk = *self
476                                .topo_rnk
477                                .get(target_sort.name())
478                                .unwrap()
479                                .get(target)
480                                .unwrap();
481                            if target_topo_rnk > self.compute_topo_rnk_hyperedge(egraph, &row, func)
482                            {
483                                // one of the parent edges that avoids cycles
484                                if let HEntry::Vacant(e) = self
485                                    .parent_edge
486                                    .get_mut(target_sort.name())
487                                    .unwrap()
488                                    .entry(*target)
489                                {
490                                    e.insert((func.decl.name.clone(), row.vals.to_vec()));
491                                }
492                            }
493                        }
494                    }
495                }
496            };
497
498            egraph
499                .backend
500                .for_each(func.backend_id, save_best_parent_edge);
501        }
502    }
503
504    /// This recursively reconstruct the termdag that gives the minimum cost for eclass value.
505    fn reconstruct_termdag_node(
506        &self,
507        egraph: &EGraph,
508        termdag: &mut TermDag,
509        value: Value,
510        sort: &ArcSort,
511    ) -> TermId {
512        self.reconstruct_termdag_node_helper(egraph, termdag, value, sort, &mut Default::default())
513    }
514
515    fn reconstruct_termdag_node_helper(
516        &self,
517        egraph: &EGraph,
518        termdag: &mut TermDag,
519        value: Value,
520        sort: &ArcSort,
521        cache: &mut HashMap<(Value, String), TermId>,
522    ) -> TermId {
523        let key = (value, sort.name().to_owned());
524        if let Some(term) = cache.get(&key) {
525            return *term;
526        }
527
528        let term = if sort.is_container_sort() {
529            let elements = sort.inner_values(egraph.backend.container_values(), value);
530            let mut ch_terms: Vec<TermId> = Vec::new();
531            for ch in elements.iter() {
532                ch_terms.push(
533                    self.reconstruct_termdag_node_helper(egraph, termdag, ch.1, &ch.0, cache),
534                );
535            }
536            sort.reconstruct_termdag_container(
537                egraph.backend.container_values(),
538                value,
539                termdag,
540                ch_terms,
541            )
542        } else if sort.is_eq_sort() {
543            let (func_name, hyperedge) = self
544                .parent_edge
545                .get(sort.name())
546                .unwrap()
547                .get(&value)
548                .unwrap();
549            let func = egraph.functions.get(func_name).unwrap();
550            let ch_sorts = &func.schema.input;
551
552            let num_children = func.extraction_num_children();
553            let output_name = func.extraction_term_name();
554
555            let mut ch_terms: Vec<TermId> = Vec::new();
556            for (value, sort) in hyperedge.iter().take(num_children).zip(ch_sorts.iter()) {
557                ch_terms.push(
558                    self.reconstruct_termdag_node_helper(egraph, termdag, *value, sort, cache),
559                );
560            }
561            termdag.app(output_name.to_string(), ch_terms)
562        } else {
563            // Base value case
564            sort.reconstruct_termdag_base(egraph.backend.base_values(), value, termdag)
565        };
566
567        cache.insert(key, term);
568        term
569    }
570
571    /// Extract the best term of a value from a given sort.
572    ///
573    /// This function expects the sort to be already computed,
574    /// which can be one of the rootsorts, or reachable from rootsorts, or primitives, or containers of computed sorts.
575    pub fn extract_best_with_sort(
576        &self,
577        egraph: &EGraph,
578        termdag: &mut TermDag,
579        value: Value,
580        sort: ArcSort,
581    ) -> Option<(C, TermId)> {
582        // Canonicalize the value using the union-find if available (for term-encoding mode)
583        let canonical_value = self.find_canonical(egraph, value, &sort);
584
585        match self.compute_cost_node(egraph, canonical_value, &sort) {
586            Some(best_cost) => {
587                log::debug!("Best cost for the extract root: {best_cost:?}");
588
589                let term = self.reconstruct_termdag_node(egraph, termdag, canonical_value, &sort);
590
591                Some((best_cost, term))
592            }
593            None => {
594                log::error!("Unextractable root {value:?} with sort {sort:?}",);
595                None
596            }
597        }
598    }
599
600    /// A convenience method for extraction.
601    ///
602    /// This expects the value to be of the unique sort the extractor has been initialized with
603    pub fn extract_best(
604        &self,
605        egraph: &EGraph,
606        termdag: &mut TermDag,
607        value: Value,
608    ) -> Option<(C, TermId)> {
609        assert!(
610            self.rootsorts.len() == 1,
611            "extract_best requires a single rootsort"
612        );
613        self.extract_best_with_sort(
614            egraph,
615            termdag,
616            value,
617            self.rootsorts.first().unwrap().clone(),
618        )
619    }
620
621    /// Find the canonical representative of a value using the union-find table.
622    /// If no UF is registered for this sort, returns the original value.
623    /// The UF table stores (value, canonical) pairs - one hop lookup.
624    fn find_canonical(&self, egraph: &EGraph, value: Value, sort: &ArcSort) -> Value {
625        // Check if there's a UF registered for this sort
626        let Some(uf_name) = egraph.proof_state.uf_parent.get(sort.name()) else {
627            return value;
628        };
629
630        // Get the UF function
631        let Some(uf_func) = egraph.functions.get(uf_name) else {
632            return value;
633        };
634
635        // Single lookup in UF table - it's guaranteed to be one hop to canonical
636        let mut canonical = value;
637        egraph
638            .backend
639            .for_each(uf_func.backend_id, |row: egglog_bridge::FunctionRow| {
640                // UF table has (child, parent) as inputs
641                if row.vals[0] == value {
642                    canonical = row.vals[1];
643                }
644            });
645
646        canonical
647    }
648
649    /// Extract variants of an e-class.
650    ///
651    /// The variants are selected by first picking `nvairants` e-nodes with the lowest cost from the e-class
652    /// and then extracting a term from each e-node.
653    pub fn extract_variants_with_sort(
654        &self,
655        egraph: &EGraph,
656        termdag: &mut TermDag,
657        value: Value,
658        nvariants: usize,
659        sort: ArcSort,
660    ) -> Vec<(C, TermId)> {
661        debug_assert!(self.rootsorts.iter().any(|s| { s.name() == sort.name() }));
662
663        if sort.is_eq_sort() {
664            // Canonicalize the value using the union-find if available
665            let canonical_value = self.find_canonical(egraph, value, &sort);
666
667            let mut root_variants: Vec<(C, String, Vec<Value>)> = Vec::new();
668
669            let mut root_funcs: Vec<String> = Vec::new();
670
671            for func_name in self.funcs.iter() {
672                // Need an eq on sorts - use extraction_output_sort for view table support
673                if sort.name()
674                    == egraph
675                        .functions
676                        .get(func_name)
677                        .unwrap()
678                        .extraction_output_sort()
679                        .name()
680                {
681                    root_funcs.push(func_name.clone());
682                }
683            }
684
685            for func_name in root_funcs.iter() {
686                let func = egraph.functions.get(func_name).unwrap();
687                let output_idx = func.extraction_output_index();
688
689                let find_root_variants = |row: egglog_bridge::FunctionRow| {
690                    if !row.subsumed {
691                        let target = &row.vals[output_idx];
692                        if *target == canonical_value {
693                            let cost = self.compute_cost_hyperedge(egraph, &row, func).unwrap();
694                            root_variants.push((cost, func_name.clone(), row.vals.to_vec()));
695                        }
696                    }
697                };
698
699                egraph.backend.for_each(func.backend_id, find_root_variants);
700            }
701
702            let mut res: Vec<(C, TermId)> = Vec::new();
703            let mut cache: HashMap<(Value, String), TermId> = Default::default();
704            root_variants.sort();
705            root_variants.truncate(nvariants);
706            for (cost, func_name, hyperedge) in root_variants {
707                let mut ch_terms: Vec<TermId> = Vec::new();
708                let func = egraph.functions.get(&func_name).unwrap();
709                let ch_sorts = &func.schema.input;
710                let num_children = func.extraction_num_children();
711                // For view tables, children are all but the last input (which is the e-class)
712                for (value, sort) in hyperedge.iter().zip(ch_sorts.iter()).take(num_children) {
713                    ch_terms.push(self.reconstruct_termdag_node_helper(
714                        egraph, termdag, *value, sort, &mut cache,
715                    ));
716                }
717                // Use extraction_term_name for view tables (maps to the original constructor)
718                res.push((
719                    cost,
720                    termdag.app(func.extraction_term_name().to_string(), ch_terms),
721                ));
722            }
723
724            res
725        } else {
726            log::warn!(
727                "extracting multiple variants for containers or primitives is not implemented, returning a single variant."
728            );
729            if let Some(res) = self.extract_best_with_sort(egraph, termdag, value, sort) {
730                vec![res]
731            } else {
732                vec![]
733            }
734        }
735    }
736
737    /// A convenience method for extracting variants of a value.
738    ///
739    /// This expects the value to be of the unique sort the extractor has been initialized with.
740    pub fn extract_variants(
741        &self,
742        egraph: &EGraph,
743        termdag: &mut TermDag,
744        value: Value,
745        nvariants: usize,
746    ) -> Vec<(C, TermId)> {
747        assert!(
748            self.rootsorts.len() == 1,
749            "extract_variants requires a single rootsort"
750        );
751        self.extract_variants_with_sort(
752            egraph,
753            termdag,
754            value,
755            nvariants,
756            self.rootsorts.first().unwrap().clone(),
757        )
758    }
759}
760
761impl Function {
762    /// For view tables (with term_constructor), the effective output sort is the last input column.
763    /// For regular tables, it's the output sort.
764    /// This is used by extraction to determine which sort a table produces values for.
765    pub(crate) fn extraction_output_sort(&self) -> &ArcSort {
766        if self.decl.term_constructor.is_some() {
767            self.schema.input.last().unwrap()
768        } else {
769            &self.schema.output
770        }
771    }
772
773    /// Returns the number of children for extraction purposes.
774    /// For view tables, this excludes the last column (the e-class).
775    pub(crate) fn extraction_num_children(&self) -> usize {
776        if self.decl.term_constructor.is_some() {
777            self.schema.input.len() - 1
778        } else {
779            self.schema.input.len()
780        }
781    }
782
783    /// Returns the name to use when building terms during extraction.
784    /// For view tables, this is the term_constructor name.
785    pub(crate) fn extraction_term_name(&self) -> &str {
786        self.decl
787            .term_constructor
788            .as_ref()
789            .unwrap_or(&self.decl.name)
790    }
791
792    /// Returns the index of the output value in a row for extraction purposes.
793    /// For view tables, the e-class is the last input column (second-to-last in the row).
794    /// For regular tables, it's the last column (the actual output).
795    pub(crate) fn extraction_output_index(&self) -> usize {
796        if self.decl.term_constructor.is_some() {
797            // For view tables: input is [children..., eclass], output is view_sort
798            // Row is [children..., eclass, view_sort]
799            // We want eclass which is at index input.len() - 1
800            self.schema.input.len() - 1
801        } else {
802            // For regular tables: row is [inputs..., output]
803            self.schema.input.len()
804        }
805    }
806}
807
808impl EGraph {
809    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`] using the default cost model.
810    /// See also [`EGraph::extract_value_with_cost_model`] for more control.
811    pub fn extract_value(
812        &self,
813        sort: &ArcSort,
814        value: Value,
815    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
816        self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
817    }
818
819    /// Extract a value to a [`TermDag`] and [`TermId`] in the [`TermDag`].
820    /// Note that the `TermDag` may contain a superset of the nodes referenced by the returned `TermId`.
821    /// See also [`EGraph::extract_value_to_string`] for convenience.
822    pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
823        &self,
824        sort: &ArcSort,
825        value: Value,
826        cost_model: CM,
827    ) -> Result<(TermDag, TermId, DefaultCost), Error> {
828        let extractor =
829            Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
830        let mut termdag = TermDag::default();
831        let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
832        Ok((termdag, term, cost))
833    }
834
835    /// Extract a value to a string for printing.
836    /// See also [`EGraph::extract_value`] for more control.
837    pub fn extract_value_to_string(
838        &self,
839        sort: &ArcSort,
840        value: Value,
841    ) -> Result<(String, DefaultCost), Error> {
842        let (termdag, term, cost) = self.extract_value(sort, value)?;
843        Ok((termdag.to_string(term), cost))
844    }
845
846    /// For constructors and relations, the output column can be ignored
847    pub fn function_to_dag(
848        &self,
849        sym: &str,
850        n: usize,
851        include_output: bool,
852    ) -> Result<(Vec<TermId>, Option<Vec<TermId>>, TermDag), Error> {
853        let func = self
854            .functions
855            .get(sym)
856            .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
857        let mut rootsorts = func.schema.input.clone();
858        if include_output {
859            rootsorts.push(func.schema.output.clone());
860        }
861        let extractor = Extractor::compute_costs_from_rootsorts(
862            Some(rootsorts),
863            self,
864            TreeAdditiveCostModel::default(),
865        );
866
867        let mut termdag = TermDag::default();
868        let mut inputs: Vec<TermId> = Vec::new();
869        let mut output: Option<Vec<TermId>> = if include_output {
870            Some(Vec::new())
871        } else {
872            None
873        };
874
875        let extract_row = |row: egglog_bridge::FunctionRow| {
876            if inputs.len() < n {
877                // include subsumed rows
878                let mut children: Vec<TermId> = Vec::new();
879                for (value, sort) in row.vals.iter().zip(&func.schema.input) {
880                    let (_, term_id) = extractor
881                        .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
882                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
883                    children.push(term_id);
884                }
885                inputs.push(termdag.app(sym.to_owned(), children));
886                if include_output {
887                    let value = row.vals[func.schema.input.len()];
888                    let sort = &func.schema.output;
889                    let (_, term) = extractor
890                        .extract_best_with_sort(self, &mut termdag, value, sort.clone())
891                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
892                    output.as_mut().unwrap().push(term);
893                }
894                true
895            } else {
896                false
897            }
898        };
899
900        self.backend.for_each_while(func.backend_id, extract_row);
901
902        Ok((inputs, output, termdag))
903    }
904}