egglog/
serialize.rs

1use crate::{util::HashMap, *};
2use core_relations::BaseValuePrinter;
3use ordered_float::NotNan;
4use std::collections::VecDeque;
5
6pub struct SerializeConfig {
7    // Maximumum number of functions to include in the serialized graph, any after this will be discarded
8    pub max_functions: Option<usize>,
9    // Maximum number of calls to include per function, any after this will be discarded
10    pub max_calls_per_function: Option<usize>,
11    // Whether to include temporary functions in the serialized graph
12    pub include_temporary_functions: bool,
13    // Root eclasses to include in the output
14    pub root_eclasses: Vec<(ArcSort, Value)>,
15}
16
17/// Output of serializing an e-graph, including values that were omitted if any.
18pub struct SerializeOutput {
19    /// The serialized e-graph.
20    pub egraph: egraph_serialize::EGraph,
21    /// Functions with more calls than max_calls_per_function, so that not all values are included.
22    pub truncated_functions: Vec<String>,
23    /// Functions that were discarded from the output, because more functions were present than max_functions
24    pub discarded_functions: Vec<String>,
25}
26
27impl SerializeOutput {
28    /// Returns true if the serialization is complete and no functions were truncated or discarded.
29    pub fn is_complete(&self) -> bool {
30        self.truncated_functions.is_empty() && self.discarded_functions.is_empty()
31    }
32    /// Description of what was omitted from the e-graph
33    pub fn omitted_description(&self) -> String {
34        let mut msg = String::new();
35        if !self.discarded_functions.is_empty() {
36            msg.push_str(&format!(
37                "Omitted: {}\n",
38                self.discarded_functions.join(", ")
39            ));
40        }
41        if !self.truncated_functions.is_empty() {
42            msg.push_str(&format!(
43                "Truncated: {}\n",
44                self.truncated_functions.join(", ")
45            ));
46        }
47        msg
48    }
49}
50
51#[allow(dead_code)]
52struct Serializer {
53    node_ids: NodeIDs,
54    result: egraph_serialize::EGraph,
55    let_bindings: HashMap<egraph_serialize::ClassId, Vec<String>>,
56}
57
58/// Default is used for exporting JSON and will output all nodes.
59impl Default for SerializeConfig {
60    fn default() -> Self {
61        SerializeConfig {
62            max_functions: None,
63            max_calls_per_function: None,
64            include_temporary_functions: false,
65            root_eclasses: vec![],
66        }
67    }
68}
69
70/// A node in the serialized egraph.
71#[derive(PartialEq, Debug, Clone)]
72pub enum SerializedNode {
73    /// A user defined function call.
74    Function {
75        /// The name of the function.
76        name: String,
77        /// The offset of the index in the table.
78        /// This can be resolved to the output and input values with table.get_index(offset, true).
79        offset: usize,
80    },
81    /// A primitive value.
82    Primitive(Value),
83    /// A dummy node used to represent omitted nodes.
84    Dummy(Value),
85    /// A node that was split into multiple e-classes.
86    Split(Box<SerializedNode>),
87}
88
89impl SerializedNode {
90    /// Returns true if the node is a primitive value.
91    pub fn is_primitive(&self) -> bool {
92        match self {
93            SerializedNode::Primitive(_) => true,
94            SerializedNode::Split(node) => node.is_primitive(),
95            _ => false,
96        }
97    }
98}
99
100impl EGraph {
101    /// Serialize the egraph into a format that can be read by the egraph-serialize crate.
102    ///
103    /// There are multiple different semantically valid ways to do this. This is how this implementation does it:
104    ///
105    /// For node costs:
106    /// - Primitives: 1.0
107    /// - Function without costs: 1.0
108    /// - Function with costs: the cost
109    /// - Omitted nodes: infinite
110    ///
111    /// For node IDs:
112    /// - Functions: Function name + hash of input values
113    /// - Args which are eq sorts: Choose one ID from the e-class, distribute roughly evenly.
114    /// - Args and outputs values which are primitives: Sort name + hash of value
115    ///
116    /// For e-classes IDs:
117    /// - tag and value of canonicalized value
118    ///
119    /// This is to achieve the following properties:
120    /// - Equivalent primitive values will show up once in the e-graph.
121    /// - Functions which return primitive values will be added to the e-class of that value.
122    /// - Nodes will have consistant IDs throughout execution of e-graph (used for animating changes in the visualization)
123    /// - Edges in the visualization will be well distributed (used for animating changes in the visualization)
124    ///   (Note that this will be changed in `<https://github.com/egraphs-good/egglog/pull/158>` so that edges point to exact nodes instead of looking up the e-class)
125    pub fn serialize(&self, config: SerializeConfig) -> SerializeOutput {
126        let mut truncated_functions = Vec::new();
127        let mut discarded_functions = Vec::new();
128        let max_calls_per_function = config.max_calls_per_function.unwrap_or(usize::MAX);
129        let max_functions = config.max_functions.unwrap_or(usize::MAX);
130        let mut all_calls: Vec<(
131            &Function,
132            Vec<Value>, // inputs
133            Value,      // output
134            bool,       // is subsumed
135            egraph_serialize::ClassId,
136            egraph_serialize::NodeId,
137        )> = Vec::new();
138        let mut functions_kept = 0usize;
139        let mut let_bindings = HashMap::default();
140        for (name, function) in self.functions.iter() {
141            if functions_kept >= max_functions {
142                discarded_functions.push(name.clone());
143                continue;
144            }
145            let mut rows = 0;
146            self.backend.for_each_while(function.backend_id, |row| {
147                if rows >= max_calls_per_function {
148                    truncated_functions.push(name.clone());
149                    return false;
150                }
151                let (out, inps) = row.vals.split_last().unwrap();
152                let class_id = self.value_to_class_id(&function.schema.output, *out);
153                if function.decl.let_binding {
154                    let_bindings
155                        .entry(class_id.clone())
156                        .or_insert_with(Vec::new)
157                        .push(name.clone());
158                } else {
159                    all_calls.push((
160                        function,
161                        inps.to_vec(),
162                        *out,
163                        row.subsumed,
164                        class_id,
165                        self.to_node_id(
166                            None,
167                            SerializedNode::Function {
168                                name: name.clone(),
169                                offset: rows,
170                            },
171                        ),
172                    ));
173                    rows += 1;
174                }
175                true
176            });
177            if rows != 0 {
178                functions_kept += 1;
179            }
180        }
181
182        // Then create a mapping from each canonical e-class ID to the set of node IDs in that e-class
183        // Note that this is only for e-classes, primitives have e-classes equal to their node ID
184        // This is for when we need to find what node ID to use for an edge to an e-class, we can rotate them evenly
185        // amoung all possible options.
186        let node_ids: NodeIDs = all_calls.iter().fold(
187            HashMap::default(),
188            |mut acc, (func, _input, _output, _subsumed, class_id, node_id)| {
189                if func.schema.output.is_eq_sort() {
190                    acc.entry(class_id.clone())
191                        .or_default()
192                        .push_back(node_id.clone());
193                }
194                acc
195            },
196        );
197
198        let mut serializer = Serializer {
199            node_ids,
200            result: egraph_serialize::EGraph::default(),
201            let_bindings,
202        };
203
204        for (func, input, output, subsumed, class_id, node_id) in all_calls {
205            self.serialize_value(&mut serializer, &func.schema.output, output, &class_id);
206
207            assert_eq!(input.len(), func.schema.input.len());
208            let children: Vec<_> = input
209                .iter()
210                .zip(&func.schema.input)
211                .map(|(&v, sort)| {
212                    self.serialize_value(&mut serializer, sort, v, &self.value_to_class_id(sort, v))
213                })
214                .collect();
215            serializer.result.nodes.insert(
216                node_id,
217                egraph_serialize::Node {
218                    op: func.decl.name.to_string(),
219                    eclass: class_id.clone(),
220                    cost: NotNan::new(func.decl.cost.unwrap_or(1) as f64).unwrap(),
221                    children,
222                    subsumed,
223                },
224            );
225        }
226
227        serializer.result.root_eclasses = config
228            .root_eclasses
229            .iter()
230            .map(|(sort, v)| self.value_to_class_id(sort, *v))
231            .collect();
232        SerializeOutput {
233            egraph: serializer.result,
234            truncated_functions,
235            discarded_functions,
236        }
237    }
238
239    /// Gets the serialized class ID for a value.
240    pub fn value_to_class_id(&self, sort: &ArcSort, value: Value) -> egraph_serialize::ClassId {
241        // Canonicalize the value first so that we always use the canonical e-class ID
242        let value = self
243            .backend
244            .get_canon_repr(value, sort.column_ty(&self.backend));
245        assert!(
246            !sort.name().to_string().contains('-'),
247            "Tag cannot contain '-' when serializing"
248        );
249        use numeric_id::NumericId;
250        format!("{}-{}", sort.name(), value.rep()).into()
251    }
252
253    /// Gets the value for a serialized class ID.
254    pub fn class_id_to_value(&self, eclass_id: &egraph_serialize::ClassId) -> Value {
255        let s = eclass_id.to_string();
256        let (_tag, bits) = s.split_once('-').unwrap();
257        Value::new_const(bits.parse().unwrap())
258    }
259
260    /// Gets the serialized node ID for the primitive, omitted, or function value.
261    pub fn to_node_id(
262        &self,
263        sort: Option<&ArcSort>,
264        node: SerializedNode,
265    ) -> egraph_serialize::NodeId {
266        match node {
267            SerializedNode::Function { name, offset } => {
268                assert!(sort.is_none());
269                format!("function-{}-{}", offset, name).into()
270            }
271            SerializedNode::Primitive(value) => {
272                format!("primitive-{}", self.value_to_class_id(sort.unwrap(), value)).into()
273            }
274            SerializedNode::Dummy(value) => {
275                format!("dummy-{}", self.value_to_class_id(sort.unwrap(), value)).into()
276            }
277            SerializedNode::Split(node) => format!("split-{}", self.to_node_id(sort, *node)).into(),
278        }
279    }
280
281    /// Gets the serialized node for the node ID.
282    pub fn from_node_id(&self, node_id: &egraph_serialize::NodeId) -> SerializedNode {
283        let node_id = node_id.to_string();
284        let (tag, rest) = node_id.split_once('-').unwrap();
285        match tag {
286            "function" => {
287                let (offset, name) = rest.split_once('-').unwrap();
288                SerializedNode::Function {
289                    name: name.into(),
290                    offset: offset.parse().unwrap(),
291                }
292            }
293            "primitive" => {
294                let class_id: egraph_serialize::ClassId = rest.into();
295                SerializedNode::Primitive(self.class_id_to_value(&class_id))
296            }
297            "dummy" => {
298                let class_id: egraph_serialize::ClassId = rest.into();
299                SerializedNode::Dummy(self.class_id_to_value(&class_id))
300            }
301            "split" => {
302                let (_offset, rest) = rest.split_once('-').unwrap();
303                let node_id: egraph_serialize::NodeId = rest.into();
304                SerializedNode::Split(Box::new(self.from_node_id(&node_id)))
305            }
306            _ => std::panic::panic_any(format!("Unknown node ID: {}-{}", tag, rest)),
307        }
308    }
309
310    /// Serialize the value and return the node ID
311    /// If this is a primitive value, we will add the node to the data, but if it is an eclass, we will not
312    /// When this is called on the output of a node, we only use the e-class to know which e-class its a part of
313    /// When this is called on an input of a node, we only use the node ID to know which node to point to.
314    fn serialize_value(
315        &self,
316        serializer: &mut Serializer,
317        sort: &ArcSort,
318        value: Value,
319        class_id: &egraph_serialize::ClassId,
320    ) -> egraph_serialize::NodeId {
321        let node_id = if sort.is_eq_sort() {
322            let node_ids = serializer
323                .node_ids
324                .entry(class_id.clone())
325                .or_insert_with(|| {
326                    // If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints
327                    // In this case, add a dummy node in this class to represent the missing nodes
328                    let node_id = self.to_node_id(Some(sort), SerializedNode::Dummy(value));
329                    serializer.result.nodes.insert(
330                        node_id.clone(),
331                        egraph_serialize::Node {
332                            op: "[...]".to_string(),
333                            eclass: class_id.clone(),
334                            cost: NotNan::new(f64::INFINITY).unwrap(),
335                            children: vec![],
336                            subsumed: false,
337                        },
338                    );
339                    VecDeque::from(vec![node_id])
340                });
341            node_ids.rotate_left(1);
342            node_ids.front().unwrap().clone()
343        } else {
344            let node_id = self.to_node_id(Some(sort), SerializedNode::Primitive(value));
345            // Add node for value
346            {
347                let container_values = self.backend.container_values();
348                // Children will be empty unless this is a container sort
349                let children: Vec<egraph_serialize::NodeId> = sort
350                    .inner_values(container_values, value)
351                    .into_iter()
352                    .map(|(s, v)| {
353                        self.serialize_value(serializer, &s, v, &self.value_to_class_id(&s, v))
354                    })
355                    .collect();
356                // If this is a container sort, use the name, otherwise use the value
357                let op = if sort.is_container_sort() {
358                    sort.serialized_name(container_values, value)
359                } else {
360                    let primitive_id = self
361                        .backend
362                        .base_values()
363                        .get_ty_by_id(sort.value_type().unwrap());
364                    let formatted_val = BaseValuePrinter {
365                        base: self.backend.base_values(),
366                        ty: primitive_id,
367                        val: value,
368                    };
369                    format!("{:?}", formatted_val)
370                };
371                serializer.result.nodes.insert(
372                    node_id.clone(),
373                    egraph_serialize::Node {
374                        op,
375                        eclass: class_id.clone(),
376                        cost: NotNan::new(1.0).unwrap(),
377                        children,
378                        subsumed: false,
379                    },
380                );
381            };
382            node_id
383        };
384        #[allow(clippy::disallowed_types)]
385        let mut extra = std::collections::HashMap::default();
386        if let Some(let_bindings) = serializer.let_bindings.get(class_id) {
387            if !let_bindings.is_empty() {
388                extra.insert("let".to_string(), let_bindings.join(", "));
389            }
390        }
391        serializer.result.class_data.insert(
392            class_id.clone(),
393            egraph_serialize::ClassData {
394                typ: Some(sort.name().to_string()),
395                extra,
396            },
397        );
398        node_id
399    }
400}
401
402type NodeIDs = HashMap<egraph_serialize::ClassId, VecDeque<egraph_serialize::NodeId>>;