1use crate::{util::HashMap, *};
2use core_relations::BaseValuePrinter;
3use ordered_float::NotNan;
4use std::collections::VecDeque;
5
/// Options controlling how much of the e-graph `EGraph::serialize` emits.
pub struct SerializeConfig {
    /// Maximum number of functions that may contribute nodes to the output.
    /// Once this many functions have contributed, every remaining function is
    /// skipped and reported in `SerializeOutput::discarded_functions`.
    /// `None` means no limit.
    pub max_functions: Option<usize>,
    /// Maximum number of rows (calls) serialized per function. A function with
    /// more rows is cut off and reported in
    /// `SerializeOutput::truncated_functions`. `None` means no limit.
    pub max_calls_per_function: Option<usize>,
    /// Presumably controls whether temporary functions are serialized.
    /// NOTE(review): not read by the `serialize` implementation in this file.
    pub include_temporary_functions: bool,
    /// E-classes, each given as a sort plus a value, to mark as roots of the
    /// serialized e-graph.
    pub root_eclasses: Vec<(ArcSort, Value)>,
}
16
/// Result of `EGraph::serialize`: the serialized graph plus a record of what
/// was left out because of the limits in `SerializeConfig`.
pub struct SerializeOutput {
    /// The serialized e-graph.
    pub egraph: egraph_serialize::EGraph,
    /// Functions whose rows exceeded `max_calls_per_function`; only their
    /// first rows were serialized.
    pub truncated_functions: Vec<String>,
    /// Functions skipped entirely because `max_functions` had been reached.
    pub discarded_functions: Vec<String>,
}
26
27impl SerializeOutput {
28 pub fn is_complete(&self) -> bool {
30 self.truncated_functions.is_empty() && self.discarded_functions.is_empty()
31 }
32 pub fn omitted_description(&self) -> String {
34 let mut msg = String::new();
35 if !self.discarded_functions.is_empty() {
36 msg.push_str(&format!(
37 "Omitted: {}\n",
38 self.discarded_functions.join(", ")
39 ));
40 }
41 if !self.truncated_functions.is_empty() {
42 msg.push_str(&format!(
43 "Truncated: {}\n",
44 self.truncated_functions.join(", ")
45 ));
46 }
47 msg
48 }
49}
50
51#[allow(dead_code)]
52struct Serializer {
53 node_ids: NodeIDs,
54 result: egraph_serialize::EGraph,
55 let_bindings: HashMap<egraph_serialize::ClassId, Vec<String>>,
56}
57
58impl Default for SerializeConfig {
60 fn default() -> Self {
61 SerializeConfig {
62 max_functions: None,
63 max_calls_per_function: None,
64 include_temporary_functions: false,
65 root_eclasses: vec![],
66 }
67 }
68}
69
/// Decoded form of a serialized node id; see `EGraph::to_node_id` and
/// `EGraph::from_node_id` for the string encoding.
#[derive(PartialEq, Debug, Clone)]
pub enum SerializedNode {
    /// One row of a function table (encoded as `function-{offset}-{name}`).
    Function {
        /// Name of the function the row belongs to.
        name: String,
        /// Index of the row among the rows serialized for this function.
        offset: usize,
    },
    /// A value of a non-eq sort, i.e. a base value or a container value
    /// (encoded as `primitive-{class id}`).
    Primitive(Value),
    /// Placeholder for an eq-sort e-class that has no serialized function node
    /// (encoded as `dummy-{class id}`).
    Dummy(Value),
    /// Wrapper around another node for ids beginning with `split-`;
    /// presumably produced when e-classes are split — NOTE(review): confirm.
    Split(Box<SerializedNode>),
}
88
89impl SerializedNode {
90 pub fn is_primitive(&self) -> bool {
92 match self {
93 SerializedNode::Primitive(_) => true,
94 SerializedNode::Split(node) => node.is_primitive(),
95 _ => false,
96 }
97 }
98}
99
100impl EGraph {
101 pub fn serialize(&self, config: SerializeConfig) -> SerializeOutput {
126 let mut truncated_functions = Vec::new();
127 let mut discarded_functions = Vec::new();
128 let max_calls_per_function = config.max_calls_per_function.unwrap_or(usize::MAX);
129 let max_functions = config.max_functions.unwrap_or(usize::MAX);
130 let mut all_calls: Vec<(
131 &Function,
132 Vec<Value>, Value, bool, egraph_serialize::ClassId,
136 egraph_serialize::NodeId,
137 )> = Vec::new();
138 let mut functions_kept = 0usize;
139 let mut let_bindings = HashMap::default();
140 for (name, function) in self.functions.iter() {
141 if functions_kept >= max_functions {
142 discarded_functions.push(name.clone());
143 continue;
144 }
145 let mut rows = 0;
146 self.backend.for_each_while(function.backend_id, |row| {
147 if rows >= max_calls_per_function {
148 truncated_functions.push(name.clone());
149 return false;
150 }
151 let (out, inps) = row.vals.split_last().unwrap();
152 let class_id = self.value_to_class_id(&function.schema.output, *out);
153 if function.decl.let_binding {
154 let_bindings
155 .entry(class_id.clone())
156 .or_insert_with(Vec::new)
157 .push(name.clone());
158 } else {
159 all_calls.push((
160 function,
161 inps.to_vec(),
162 *out,
163 row.subsumed,
164 class_id,
165 self.to_node_id(
166 None,
167 SerializedNode::Function {
168 name: name.clone(),
169 offset: rows,
170 },
171 ),
172 ));
173 rows += 1;
174 }
175 true
176 });
177 if rows != 0 {
178 functions_kept += 1;
179 }
180 }
181
182 let node_ids: NodeIDs = all_calls.iter().fold(
187 HashMap::default(),
188 |mut acc, (func, _input, _output, _subsumed, class_id, node_id)| {
189 if func.schema.output.is_eq_sort() {
190 acc.entry(class_id.clone())
191 .or_default()
192 .push_back(node_id.clone());
193 }
194 acc
195 },
196 );
197
198 let mut serializer = Serializer {
199 node_ids,
200 result: egraph_serialize::EGraph::default(),
201 let_bindings,
202 };
203
204 for (func, input, output, subsumed, class_id, node_id) in all_calls {
205 self.serialize_value(&mut serializer, &func.schema.output, output, &class_id);
206
207 assert_eq!(input.len(), func.schema.input.len());
208 let children: Vec<_> = input
209 .iter()
210 .zip(&func.schema.input)
211 .map(|(&v, sort)| {
212 self.serialize_value(&mut serializer, sort, v, &self.value_to_class_id(sort, v))
213 })
214 .collect();
215 serializer.result.nodes.insert(
216 node_id,
217 egraph_serialize::Node {
218 op: func.decl.name.to_string(),
219 eclass: class_id.clone(),
220 cost: NotNan::new(func.decl.cost.unwrap_or(1) as f64).unwrap(),
221 children,
222 subsumed,
223 },
224 );
225 }
226
227 serializer.result.root_eclasses = config
228 .root_eclasses
229 .iter()
230 .map(|(sort, v)| self.value_to_class_id(sort, *v))
231 .collect();
232 SerializeOutput {
233 egraph: serializer.result,
234 truncated_functions,
235 discarded_functions,
236 }
237 }
238
239 pub fn value_to_class_id(&self, sort: &ArcSort, value: Value) -> egraph_serialize::ClassId {
241 let value = self
243 .backend
244 .get_canon_repr(value, sort.column_ty(&self.backend));
245 assert!(
246 !sort.name().to_string().contains('-'),
247 "Tag cannot contain '-' when serializing"
248 );
249 use numeric_id::NumericId;
250 format!("{}-{}", sort.name(), value.rep()).into()
251 }
252
253 pub fn class_id_to_value(&self, eclass_id: &egraph_serialize::ClassId) -> Value {
255 let s = eclass_id.to_string();
256 let (_tag, bits) = s.split_once('-').unwrap();
257 Value::new_const(bits.parse().unwrap())
258 }
259
260 pub fn to_node_id(
262 &self,
263 sort: Option<&ArcSort>,
264 node: SerializedNode,
265 ) -> egraph_serialize::NodeId {
266 match node {
267 SerializedNode::Function { name, offset } => {
268 assert!(sort.is_none());
269 format!("function-{}-{}", offset, name).into()
270 }
271 SerializedNode::Primitive(value) => {
272 format!("primitive-{}", self.value_to_class_id(sort.unwrap(), value)).into()
273 }
274 SerializedNode::Dummy(value) => {
275 format!("dummy-{}", self.value_to_class_id(sort.unwrap(), value)).into()
276 }
277 SerializedNode::Split(node) => format!("split-{}", self.to_node_id(sort, *node)).into(),
278 }
279 }
280
281 pub fn from_node_id(&self, node_id: &egraph_serialize::NodeId) -> SerializedNode {
283 let node_id = node_id.to_string();
284 let (tag, rest) = node_id.split_once('-').unwrap();
285 match tag {
286 "function" => {
287 let (offset, name) = rest.split_once('-').unwrap();
288 SerializedNode::Function {
289 name: name.into(),
290 offset: offset.parse().unwrap(),
291 }
292 }
293 "primitive" => {
294 let class_id: egraph_serialize::ClassId = rest.into();
295 SerializedNode::Primitive(self.class_id_to_value(&class_id))
296 }
297 "dummy" => {
298 let class_id: egraph_serialize::ClassId = rest.into();
299 SerializedNode::Dummy(self.class_id_to_value(&class_id))
300 }
301 "split" => {
302 let (_offset, rest) = rest.split_once('-').unwrap();
303 let node_id: egraph_serialize::NodeId = rest.into();
304 SerializedNode::Split(Box::new(self.from_node_id(&node_id)))
305 }
306 _ => std::panic::panic_any(format!("Unknown node ID: {}-{}", tag, rest)),
307 }
308 }
309
310 fn serialize_value(
315 &self,
316 serializer: &mut Serializer,
317 sort: &ArcSort,
318 value: Value,
319 class_id: &egraph_serialize::ClassId,
320 ) -> egraph_serialize::NodeId {
321 let node_id = if sort.is_eq_sort() {
322 let node_ids = serializer
323 .node_ids
324 .entry(class_id.clone())
325 .or_insert_with(|| {
326 let node_id = self.to_node_id(Some(sort), SerializedNode::Dummy(value));
329 serializer.result.nodes.insert(
330 node_id.clone(),
331 egraph_serialize::Node {
332 op: "[...]".to_string(),
333 eclass: class_id.clone(),
334 cost: NotNan::new(f64::INFINITY).unwrap(),
335 children: vec![],
336 subsumed: false,
337 },
338 );
339 VecDeque::from(vec![node_id])
340 });
341 node_ids.rotate_left(1);
342 node_ids.front().unwrap().clone()
343 } else {
344 let node_id = self.to_node_id(Some(sort), SerializedNode::Primitive(value));
345 {
347 let container_values = self.backend.container_values();
348 let children: Vec<egraph_serialize::NodeId> = sort
350 .inner_values(container_values, value)
351 .into_iter()
352 .map(|(s, v)| {
353 self.serialize_value(serializer, &s, v, &self.value_to_class_id(&s, v))
354 })
355 .collect();
356 let op = if sort.is_container_sort() {
358 sort.serialized_name(container_values, value)
359 } else {
360 let primitive_id = self
361 .backend
362 .base_values()
363 .get_ty_by_id(sort.value_type().unwrap());
364 let formatted_val = BaseValuePrinter {
365 base: self.backend.base_values(),
366 ty: primitive_id,
367 val: value,
368 };
369 format!("{:?}", formatted_val)
370 };
371 serializer.result.nodes.insert(
372 node_id.clone(),
373 egraph_serialize::Node {
374 op,
375 eclass: class_id.clone(),
376 cost: NotNan::new(1.0).unwrap(),
377 children,
378 subsumed: false,
379 },
380 );
381 };
382 node_id
383 };
384 #[allow(clippy::disallowed_types)]
385 let mut extra = std::collections::HashMap::default();
386 if let Some(let_bindings) = serializer.let_bindings.get(class_id) {
387 if !let_bindings.is_empty() {
388 extra.insert("let".to_string(), let_bindings.join(", "));
389 }
390 }
391 serializer.result.class_data.insert(
392 class_id.clone(),
393 egraph_serialize::ClassData {
394 typ: Some(sort.name().to_string()),
395 extra,
396 },
397 );
398 node_id
399 }
400}
401
/// For each eq-sort e-class, the ids of the nodes that may be used to reference
/// it; `serialize_value` rotates the queue on every use to cycle through them.
type NodeIDs = HashMap<egraph_serialize::ClassId, VecDeque<egraph_serialize::NodeId>>;