1pub mod ast;
15#[cfg(feature = "bin")]
16mod cli;
17pub mod constraint;
18mod core;
19pub mod extract;
20pub mod prelude;
21pub mod scheduler;
22mod serialize;
23pub mod sort;
24mod termdag;
25mod typechecking;
26pub mod util;
27
28extern crate self as egglog;
31use ast::*;
32#[cfg(feature = "bin")]
33pub use cli::*;
34use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint};
35use core::{AtomTerm, ResolvedAtomTerm, ResolvedCall};
36pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value};
37use core_relations::{ExternalFunctionId, make_external_func};
38use csv::Writer;
39pub use egglog_add_primitive::add_primitive;
40use egglog_ast::generic_ast::{Change, GenericExpr, Literal};
41use egglog_ast::span::Span;
42use egglog_ast::util::ListDisplay;
43pub use egglog_bridge::FunctionRow;
44use egglog_bridge::{ColumnTy, QueryEntry};
45use egglog_core_relations as core_relations;
46use egglog_numeric_id as numeric_id;
47use egglog_reports::{ReportLevel, RunReport};
48use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel};
49use indexmap::map::Entry;
50use log::{Level, log_enabled};
51use numeric_id::DenseIdMap;
52use prelude::*;
53use scheduler::{SchedulerId, SchedulerRecord};
54pub use serialize::{SerializeConfig, SerializeOutput, SerializedNode};
55use sort::*;
56use std::fmt::{Debug, Display, Formatter};
57use std::fs::File;
58use std::hash::Hash;
59use std::io::{Read, Write as _};
60use std::iter::once;
61use std::ops::Deref;
62use std::path::PathBuf;
63use std::str::FromStr;
64use std::sync::Arc;
65pub use termdag::{Term, TermDag, TermId};
66use thiserror::Error;
67pub use typechecking::TypeError;
68use typechecking::TypeInfo;
69use util::*;
70
71use crate::core::{GenericActionsExt, ResolvedRuleExt};
72
/// Shared, reference-counted handle to a dynamically-typed [`Sort`].
pub type ArcSort = Arc<dyn Sort>;
74
/// A primitive operation callable from egglog programs.
pub trait Primitive {
    /// The name under which this primitive is looked up.
    fn name(&self) -> &str;

    /// Builds the type constraints used to typecheck a call to this primitive.
    ///
    /// `span` is the source location of the call; presumably it is attached to any type errors
    /// the constraint reports — confirm against `constraint`.
    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint>;

    /// Applies the primitive to already-evaluated argument values.
    ///
    /// NOTE(review): a `None` result presumably means the primitive does not produce a value for
    /// these arguments (i.e. the call fails/does not match) — confirm with the backend contract.
    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value>;
}
91
/// Output produced by a [`UserDefinedCommand`]: anything that is debuggable, displayable, and
/// safe to share across threads.
pub trait UserDefinedCommandOutput: Debug + std::fmt::Display + Send + Sync {}

// Blanket impl: every type meeting the bounds automatically qualifies as command output.
impl<T> UserDefinedCommandOutput for T where T: Debug + std::fmt::Display + Send + Sync {}
95
#[derive(Clone, Debug)]
/// The result of running a command that produces displayable output.
pub enum CommandOutput {
    /// Number of rows in a single function's table.
    PrintFunctionSize(usize),
    /// `(function name, row count)` for every function; `print_size` sorts these by name.
    PrintAllFunctionsSize(Vec<(String, usize)>),
    /// The best extracted term, its cost, and the term DAG it lives in.
    ExtractBest(TermDag, DefaultCost, Term),
    /// Several extracted variants sharing one term DAG.
    ExtractVariants(TermDag, Vec<Term>),
    /// Statistics accumulated over every schedule run so far.
    OverallStatistics(RunReport),
    /// Rows of a function as `(input application term, output term)` pairs, plus the rendering mode.
    PrintFunction(Function, TermDag, Vec<(Term, Term)>, PrintFunctionMode),
    /// Report for a single schedule run; its `Display` output is empty.
    RunSchedule(RunReport),
    /// Output of a user-defined command.
    UserDefined(Arc<dyn UserDefinedCommandOutput>),
}
116
117impl std::fmt::Display for CommandOutput {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 match self {
121 CommandOutput::PrintFunctionSize(size) => writeln!(f, "{}", size),
122 CommandOutput::PrintAllFunctionsSize(names_and_sizes) => {
123 for name in names_and_sizes {
124 writeln!(f, "{}: {}", name.0, name.1)?;
125 }
126 Ok(())
127 }
128 CommandOutput::ExtractBest(termdag, _cost, term) => {
129 writeln!(f, "{}", termdag.to_string(term))
130 }
131 CommandOutput::ExtractVariants(termdag, terms) => {
132 writeln!(f, "(")?;
133 for expr in terms {
134 writeln!(f, " {}", termdag.to_string(expr))?;
135 }
136 writeln!(f, ")")
137 }
138 CommandOutput::OverallStatistics(run_report) => {
139 write!(f, "Overall statistics:\n{}", run_report)
140 }
141 CommandOutput::PrintFunction(function, termdag, terms_and_outputs, mode) => {
142 let out_is_unit = function.schema.output.name() == UnitSort.name();
143 if *mode == PrintFunctionMode::CSV {
144 let mut wtr = Writer::from_writer(vec![]);
145 for (term, output) in terms_and_outputs {
146 match term {
147 Term::App(name, children) => {
148 let mut values = vec![name.clone()];
149 for child_id in children {
150 values.push(termdag.to_string(termdag.get(*child_id)));
151 }
152
153 if !out_is_unit {
154 values.push(termdag.to_string(output));
155 }
156 wtr.write_record(&values).map_err(|_| std::fmt::Error)?;
157 }
158 _ => panic!("Expect function_to_dag to return a list of apps."),
159 }
160 }
161 let csv_bytes = wtr.into_inner().map_err(|_| std::fmt::Error)?;
162 f.write_str(&String::from_utf8(csv_bytes).map_err(|_| std::fmt::Error)?)
163 } else {
164 writeln!(f, "(")?;
165 for (term, output) in terms_and_outputs.iter() {
166 write!(f, " {}", termdag.to_string(term))?;
167 if !out_is_unit {
168 write!(f, " -> {}", termdag.to_string(output))?;
169 }
170 writeln!(f)?;
171 }
172 writeln!(f, ")")
173 }
174 }
175 CommandOutput::RunSchedule(_report) => Ok(()),
176 CommandOutput::UserDefined(output) => {
177 write!(f, "{}", *output)
178 }
179 }
180 }
181}
182
/// The main e-graph.
///
/// It holds the declared functions, rulesets, and type information, plus the backend database
/// that stores the tables and runs rules.
#[derive(Clone)]
pub struct EGraph {
    // Backend database that owns the function tables and executes rules.
    backend: egglog_bridge::EGraph,
    /// Parser state; it also provides the fresh-symbol generator used while lowering rules and actions.
    pub parser: Parser,
    // NOTE(review): presumably tracks declared names for shadowing checks (`check_shadowing`) — confirm.
    names: check_shadowing::Names,
    // Snapshot saved by `push` and restored by `pop`; older snapshots chain through this same field.
    pushed_egraph: Option<Box<Self>>,
    // Declared functions, keyed by name.
    functions: IndexMap<String, Function>,
    // Rulesets keyed by name; the empty string names the default ruleset.
    rulesets: IndexMap<String, Ruleset>,
    /// Base directory that `input`/`output` file names are resolved against.
    pub fact_directory: Option<PathBuf>,
    /// Whether newly added rules are built in semi-naive mode (defaults to `true`).
    pub seminaive: bool,
    // Sort, function-type, and primitive information used during typechecking and lowering.
    type_info: TypeInfo,
    // Statistics accumulated across all schedules run; deliberately preserved across `pop`.
    overall_run_report: RunReport,
    // NOTE(review): registered schedulers; not used in the visible code — confirm semantics.
    schedulers: DenseIdMap<SchedulerId, SchedulerRecord>,
    // User-defined commands registered via `add_command`, keyed by name.
    commands: IndexMap<String, Arc<dyn UserDefinedCommand>>,
}
214
/// A command added at runtime through [`EGraph::add_command`].
pub trait UserDefinedCommand: Send + Sync {
    /// Runs the command on `egraph` with the command's argument expressions.
    ///
    /// Returns optional output to display, or an error that is propagated to the caller.
    fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<Option<CommandOutput>, Error>;
}
224
/// A declared function (constructor, relation, or custom function) together with its backend table.
#[derive(Clone)]
pub struct Function {
    // The resolved declaration this function was created from.
    decl: ResolvedFunctionDecl,
    // Resolved input/output sorts.
    schema: ResolvedSchema,
    // True for constructors and relations, false for custom functions.
    can_subsume: bool,
    // Identifier of the table that stores this function's rows in the backend.
    backend_id: egglog_bridge::FunctionId,
}
236
237impl Function {
238 pub fn name(&self) -> &str {
240 &self.decl.name
241 }
242
243 pub fn schema(&self) -> &ResolvedSchema {
245 &self.schema
246 }
247
248 pub fn can_subsume(&self) -> bool {
250 self.can_subsume
251 }
252}
253
/// The resolved sorts of a function's columns.
#[derive(Clone, Debug)]
pub struct ResolvedSchema {
    /// Sorts of the input columns, in order.
    pub input: Vec<ArcSort>,
    /// Sort of the output column, which comes right after the inputs.
    pub output: ArcSort,
}
259
260impl ResolvedSchema {
261 pub fn get_by_pos(&self, index: usize) -> Option<&ArcSort> {
263 if self.input.len() == index {
264 Some(&self.output)
265 } else {
266 self.input.get(index)
267 }
268 }
269}
270
271impl Debug for Function {
272 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
273 f.debug_struct("Function")
274 .field("decl", &self.decl)
275 .field("schema", &self.schema)
276 .finish()
277 }
278}
279
impl Default for EGraph {
    /// Creates an e-graph with the built-in setup already registered.
    ///
    /// This includes the built-in base sorts, the container presorts, a few generic primitives,
    /// and an empty default ruleset named `""`.
    ///
    /// # Panics
    /// Panics (via `unwrap`) if registering any built-in sort or presort fails.
    fn default() -> Self {
        let mut eg = Self {
            backend: Default::default(),
            parser: Default::default(),
            names: Default::default(),
            pushed_egraph: Default::default(),
            functions: Default::default(),
            rulesets: Default::default(),
            fact_directory: None,
            // Rules are built semi-naively unless the caller turns this off.
            seminaive: true,
            overall_run_report: Default::default(),
            type_info: Default::default(),
            schedulers: Default::default(),
            commands: Default::default(),
        };

        // Built-in base sorts.
        add_base_sort(&mut eg, UnitSort, span!()).unwrap();
        add_base_sort(&mut eg, StringSort, span!()).unwrap();
        add_base_sort(&mut eg, BoolSort, span!()).unwrap();
        add_base_sort(&mut eg, I64Sort, span!()).unwrap();
        add_base_sort(&mut eg, F64Sort, span!()).unwrap();
        add_base_sort(&mut eg, BigIntSort, span!()).unwrap();
        add_base_sort(&mut eg, BigRatSort, span!()).unwrap();
        // Parameterized container presorts; presumably instantiated by later `sort` declarations — confirm.
        eg.type_info.add_presort::<MapSort>(span!()).unwrap();
        eg.type_info.add_presort::<SetSort>(span!()).unwrap();
        eg.type_info.add_presort::<VecSort>(span!()).unwrap();
        eg.type_info.add_presort::<FunctionSort>(span!()).unwrap();
        eg.type_info.add_presort::<MultiSetSort>(span!()).unwrap();

        // Generic primitives over any sort (`#`). The `-?>` forms return `None` when the
        // condition does not hold (presumably making the call fail). The min/max pair compares
        // the values with `<` / `>`.
        add_primitive!(&mut eg, "!=" = |a: #, b: #| -?> () {
            (a != b).then_some(())
        });
        add_primitive!(&mut eg, "value-eq" = |a: #, b: #| -?> () {
            (a == b).then_some(())
        });
        add_primitive!(&mut eg, "ordering-min" = |a: #, b: #| -> # {
            if a < b { a } else { b }
        });
        add_primitive!(&mut eg, "ordering-max" = |a: #, b: #| -> # {
            if a > b { a } else { b }
        });

        // The default ruleset is named by the empty string.
        eg.rulesets
            .insert("".into(), Ruleset::Rules(Default::default()));

        eg
    }
}
329
/// Error for a failed lookup; it displays as `Not found: <payload>`.
#[derive(Debug, Error)]
#[error("Not found: {0}")]
pub struct NotFoundError(String);
333
334impl EGraph {
335 pub fn add_command(
337 &mut self,
338 name: String,
339 command: Arc<dyn UserDefinedCommand>,
340 ) -> Result<(), Error> {
341 if self.commands.contains_key(&name)
342 || self.functions.contains_key(&name)
343 || self.type_info.get_prims(&name).is_some()
344 {
345 return Err(Error::CommandAlreadyExists(name, span!()));
346 }
347 self.commands.insert(name.clone(), command);
348 self.parser.add_user_defined(name)?;
349 Ok(())
350 }
351
352 pub fn push(&mut self) {
356 let prev_prev: Option<Box<Self>> = self.pushed_egraph.take();
357 let mut prev = self.clone();
358 prev.pushed_egraph = prev_prev;
359 self.pushed_egraph = Some(Box::new(prev));
360 }
361
362 pub fn pop(&mut self) -> Result<(), Error> {
367 match self.pushed_egraph.take() {
368 Some(e) => {
369 let overall_run_report = self.overall_run_report.clone();
371 *self = *e;
372 self.overall_run_report = overall_run_report;
373 Ok(())
374 }
375 None => Err(Error::Pop(span!())),
376 }
377 }
378
379 fn translate_expr_to_mergefn(
380 &self,
381 expr: &ResolvedExpr,
382 ) -> Result<egglog_bridge::MergeFn, Error> {
383 match expr {
384 GenericExpr::Lit(_, literal) => {
385 let val = literal_to_value(&self.backend, literal);
386 Ok(egglog_bridge::MergeFn::Const(val))
387 }
388 GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() {
389 "old" => Ok(egglog_bridge::MergeFn::Old),
390 "new" => Ok(egglog_bridge::MergeFn::New),
391 _ => Err(TypeError::Unbound(resolved_var.name.clone(), span.clone()).into()),
393 },
394 GenericExpr::Call(_, ResolvedCall::Func(f), args) => {
395 let translated_args = args
396 .iter()
397 .map(|arg| self.translate_expr_to_mergefn(arg))
398 .collect::<Result<Vec<_>, _>>()?;
399 Ok(egglog_bridge::MergeFn::Function(
400 self.functions[&f.name].backend_id,
401 translated_args,
402 ))
403 }
404 GenericExpr::Call(_, ResolvedCall::Primitive(p), args) => {
405 let translated_args = args
406 .iter()
407 .map(|arg| self.translate_expr_to_mergefn(arg))
408 .collect::<Result<Vec<_>, _>>()?;
409 Ok(egglog_bridge::MergeFn::Primitive(
410 p.primitive.1,
411 translated_args,
412 ))
413 }
414 }
415 }
416
417 fn declare_function(&mut self, decl: &ResolvedFunctionDecl) -> Result<(), Error> {
418 let get_sort = |name: &String| match self.type_info.get_sort_by_name(name) {
419 Some(sort) => Ok(sort.clone()),
420 None => Err(Error::TypeError(TypeError::UndefinedSort(
421 name.to_owned(),
422 decl.span.clone(),
423 ))),
424 };
425
426 let input = decl
427 .schema
428 .input
429 .iter()
430 .map(get_sort)
431 .collect::<Result<Vec<_>, _>>()?;
432 let output = get_sort(&decl.schema.output)?;
433
434 let can_subsume = match decl.subtype {
435 FunctionSubtype::Constructor => true,
436 FunctionSubtype::Relation => true,
437 FunctionSubtype::Custom => false,
438 };
439
440 use egglog_bridge::{DefaultVal, MergeFn};
441 let backend_id = self.backend.add_table(egglog_bridge::FunctionConfig {
442 schema: input
443 .iter()
444 .chain([&output])
445 .map(|sort| sort.column_ty(&self.backend))
446 .collect(),
447 default: match decl.subtype {
448 FunctionSubtype::Constructor => DefaultVal::FreshId,
449 FunctionSubtype::Custom => DefaultVal::Fail,
450 FunctionSubtype::Relation => DefaultVal::Const(self.backend.base_values().get(())),
451 },
452 merge: match decl.subtype {
453 FunctionSubtype::Constructor => MergeFn::UnionId,
454 FunctionSubtype::Relation => MergeFn::AssertEq,
455 FunctionSubtype::Custom => match &decl.merge {
456 None => MergeFn::AssertEq,
457 Some(expr) => self.translate_expr_to_mergefn(expr)?,
458 },
459 },
460 name: decl.name.to_string(),
461 can_subsume,
462 });
463
464 let function = Function {
465 decl: decl.clone(),
466 schema: ResolvedSchema { input, output },
467 can_subsume,
468 backend_id,
469 };
470
471 let old = self.functions.insert(decl.name.clone(), function);
472 if old.is_some() {
473 panic!(
474 "Typechecking should have caught function already bound: {}",
475 decl.name
476 );
477 }
478
479 Ok(())
480 }
481
    /// Extracts up to `n` rows of function `sym` as terms that share one [`TermDag`].
    ///
    /// For each row, the input values are extracted and wrapped in an application term
    /// `(sym child...)`. When `include_output` is set, each row's output value is also extracted
    /// and returned in the second tuple element; otherwise that element is `None`. Any value with
    /// no extractable term becomes the variable `Unextractable`.
    ///
    /// # Errors
    /// Returns an unbound-function type error if `sym` is not declared.
    pub fn function_to_dag(
        &self,
        sym: &str,
        n: usize,
        include_output: bool,
    ) -> Result<(Vec<Term>, Option<Vec<Term>>, TermDag), Error> {
        let func = self
            .functions
            .get(sym)
            .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
        // Extraction costs are only needed for the sorts that appear in this function's columns.
        let mut rootsorts = func.schema.input.clone();
        if include_output {
            rootsorts.push(func.schema.output.clone());
        }
        let extractor = Extractor::compute_costs_from_rootsorts(
            Some(rootsorts),
            self,
            TreeAdditiveCostModel::default(),
        );

        let mut termdag = TermDag::default();
        let mut inputs: Vec<Term> = Vec::new();
        let mut output: Option<Vec<Term>> = if include_output {
            Some(Vec::new())
        } else {
            None
        };

        // Called once per row. Returns `false` once `n` rows have been collected, which
        // presumably makes `for_each_while` stop iterating.
        let extract_row = |row: egglog_bridge::FunctionRow| {
            if inputs.len() < n {
                let mut children: Vec<Term> = Vec::new();
                // `zip` stops at the input columns, so the output value (if any) is skipped here.
                for (value, sort) in row.vals.iter().zip(&func.schema.input) {
                    let (_, term) = extractor
                        .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
                    children.push(term);
                }
                inputs.push(termdag.app(sym.to_owned(), children));
                if include_output {
                    // The output value sits right after the input columns.
                    let value = row.vals[func.schema.input.len()];
                    let sort = &func.schema.output;
                    let (_, term) = extractor
                        .extract_best_with_sort(self, &mut termdag, value, sort.clone())
                        .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
                    output.as_mut().unwrap().push(term);
                }
                true
            } else {
                false
            }
        };

        self.backend.for_each_while(func.backend_id, extract_row);

        Ok((inputs, output, termdag))
    }
543
544 pub fn print_function(
547 &mut self,
548 sym: &str,
549 n: Option<usize>,
550 file: Option<File>,
551 mode: PrintFunctionMode,
552 ) -> Result<Option<CommandOutput>, Error> {
553 let n = match n {
554 Some(n) => {
555 log::info!("Printing up to {n} tuples of function {sym} as {mode}");
556 n
557 }
558 None => {
559 log::info!("Printing all tuples of function {sym} as {mode}");
560 usize::MAX
561 }
562 };
563
564 let (terms, outputs, termdag) = self.function_to_dag(sym, n, true)?;
565 let f = self
566 .functions
567 .get(sym)
568 .unwrap();
570 let terms_and_outputs: Vec<_> = terms.into_iter().zip(outputs.unwrap()).collect();
571 let output = CommandOutput::PrintFunction(f.clone(), termdag, terms_and_outputs, mode);
572 match file {
573 Some(mut file) => {
574 log::info!("Writing output to file");
575 file.write_all(output.to_string().as_bytes())
576 .expect("Error writing to file");
577 Ok(None)
578 }
579 None => Ok(Some(output)),
580 }
581 }
582
583 pub fn print_size(&mut self, sym: Option<&str>) -> Result<CommandOutput, Error> {
586 if let Some(sym) = sym {
587 let f = self
588 .functions
589 .get(sym)
590 .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
591 let size = self.backend.table_size(f.backend_id);
592 log::info!("Function {} has size {}", sym, size);
593 Ok(CommandOutput::PrintFunctionSize(size))
594 } else {
595 let mut lens = self
597 .functions
598 .iter()
599 .map(|(sym, f)| (sym.clone(), self.backend.table_size(f.backend_id)))
600 .collect::<Vec<_>>();
601
602 lens.sort_by_key(|(name, _)| name.clone());
604 if log_enabled!(Level::Info) {
605 for (sym, len) in &lens {
606 log::info!("Function {} has size {}", sym, len);
607 }
608 }
609 Ok(CommandOutput::PrintAllFunctionsSize(lens))
610 }
611 }
612
613 fn run_schedule(&mut self, sched: &ResolvedSchedule) -> Result<RunReport, Error> {
615 match sched {
616 ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
617 ResolvedSchedule::Repeat(_span, limit, sched) => {
618 let mut report = RunReport::default();
619 for _i in 0..*limit {
620 let rec = self.run_schedule(sched)?;
621 let updated = rec.updated;
622 report.union(rec);
623 if !updated {
624 break;
625 }
626 }
627 Ok(report)
628 }
629 ResolvedSchedule::Saturate(_span, sched) => {
630 let mut report = RunReport::default();
631 loop {
632 let rec = self.run_schedule(sched)?;
633 let updated = rec.updated;
634 report.union(rec);
635 if !updated {
636 break;
637 }
638 }
639 Ok(report)
640 }
641 ResolvedSchedule::Sequence(_span, scheds) => {
642 let mut report = RunReport::default();
643 for sched in scheds {
644 report.union(self.run_schedule(sched)?);
645 }
646 Ok(report)
647 }
648 }
649 }
650
651 pub fn extract_value(
654 &self,
655 sort: &ArcSort,
656 value: Value,
657 ) -> Result<(TermDag, Term, DefaultCost), Error> {
658 self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
659 }
660
661 pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
665 &self,
666 sort: &ArcSort,
667 value: Value,
668 cost_model: CM,
669 ) -> Result<(TermDag, Term, DefaultCost), Error> {
670 let extractor =
671 Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
672 let mut termdag = TermDag::default();
673 let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
674 Ok((termdag, term, cost))
675 }
676
677 pub fn extract_value_to_string(
680 &self,
681 sort: &ArcSort,
682 value: Value,
683 ) -> Result<(String, DefaultCost), Error> {
684 let (termdag, term, cost) = self.extract_value(sort, value)?;
685 Ok((termdag.to_string(&term), cost))
686 }
687
688 fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> Result<RunReport, Error> {
689 let mut report: RunReport = Default::default();
690
691 let GenericRunConfig { ruleset, until } = config;
692
693 if let Some(facts) = until {
694 if self.check_facts(span, facts).is_ok() {
695 log::info!(
696 "Breaking early because of facts:\n {}!",
697 ListDisplay(facts, "\n")
698 );
699 return Ok(report);
700 }
701 }
702
703 let subreport = self.step_rules(ruleset)?;
704 report.union(subreport);
705
706 if log_enabled!(Level::Debug) {
707 log::debug!("database size: {}", self.num_tuples());
708 }
709
710 Ok(report)
711 }
712
713 pub fn step_rules(&mut self, ruleset: &str) -> Result<RunReport, Error> {
720 fn collect_rule_ids(
721 ruleset: &str,
722 rulesets: &IndexMap<String, Ruleset>,
723 ids: &mut Vec<egglog_bridge::RuleId>,
724 ) {
725 match &rulesets[ruleset] {
726 Ruleset::Rules(rules) => {
727 for (_, id) in rules.values() {
728 ids.push(*id);
729 }
730 }
731 Ruleset::Combined(sub_rulesets) => {
732 for sub_ruleset in sub_rulesets {
733 collect_rule_ids(sub_ruleset, rulesets, ids);
734 }
735 }
736 }
737 }
738
739 let mut rule_ids = Vec::new();
740 collect_rule_ids(ruleset, &self.rulesets, &mut rule_ids);
741
742 let iteration_report = self
743 .backend
744 .run_rules(&rule_ids)
745 .map_err(|e| Error::BackendError(e.to_string()))?;
746
747 Ok(RunReport::singleton(ruleset, iteration_report))
748 }
749
750 fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result<String, Error> {
751 let core_rule =
752 rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?;
753 let (query, actions) = (&core_rule.body, &core_rule.head);
754
755 let rule_id = {
756 let mut translator = BackendRule::new(
757 self.backend.new_rule(&rule.name, self.seminaive),
758 &self.functions,
759 &self.type_info,
760 );
761 translator.query(query, false);
762 translator.actions(actions)?;
763 translator.build()
764 };
765
766 if let Some(rules) = self.rulesets.get_mut(&rule.ruleset) {
767 match rules {
768 Ruleset::Rules(rules) => {
769 match rules.entry(rule.name.clone()) {
770 indexmap::map::Entry::Occupied(_) => {
771 let name = rule.name;
772 panic!("Rule '{name}' was already present")
773 }
774 indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)),
775 };
776 Ok(rule.name)
777 }
778 Ruleset::Combined(_) => Err(Error::CombinedRulesetError(rule.ruleset, rule.span)),
779 }
780 } else {
781 Err(Error::NoSuchRuleset(rule.ruleset, rule.span))
782 }
783 }
784
785 fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
786 let (actions, _) = actions.to_core_actions(
787 &self.type_info,
788 &mut Default::default(),
789 &mut self.parser.symbol_gen,
790 )?;
791
792 let mut translator = BackendRule::new(
793 self.backend.new_rule("eval_actions", false),
794 &self.functions,
795 &self.type_info,
796 );
797 translator.actions(&actions)?;
798 let id = translator.build();
799 let result = self.backend.run_rules(&[id]);
800 self.backend.free_rule(id);
801
802 match result {
803 Ok(_) => Ok(()),
804 Err(e) => Err(Error::BackendError(e.to_string())),
805 }
806 }
807
808 pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
810 let span = expr.span();
811 let command = Command::Action(Action::Expr(span.clone(), expr.clone()));
812 let resolved_commands = self.process_command(command)?;
813 assert_eq!(resolved_commands.len(), 1);
814 let resolved_command = resolved_commands.into_iter().next().unwrap();
815 let resolved_expr = match resolved_command {
816 ResolvedNCommand::CoreAction(ResolvedAction::Expr(_, resolved_expr)) => resolved_expr,
817 _ => unreachable!(),
818 };
819 let sort = resolved_expr.output_type();
820 let value = self.eval_resolved_expr(span, &resolved_expr)?;
821 Ok((sort, value))
822 }
823
824 fn eval_resolved_expr(&mut self, span: Span, expr: &ResolvedExpr) -> Result<Value, Error> {
825 let unit_id = self.backend.base_values().get_ty::<()>();
826 let unit_val = self.backend.base_values().get(());
827
828 let result: egglog_bridge::SideChannel<Value> = Default::default();
829 let result_ref = result.clone();
830 let ext_id = self
831 .backend
832 .register_external_func(make_external_func(move |_es, vals| {
833 debug_assert!(vals.len() == 1);
834 *result_ref.lock().unwrap() = Some(vals[0]);
835 Some(unit_val)
836 }));
837
838 let mut translator = BackendRule::new(
839 self.backend.new_rule("eval_resolved_expr", false),
840 &self.functions,
841 &self.type_info,
842 );
843
844 let result_var = ResolvedVar {
845 name: self.parser.symbol_gen.fresh("eval_resolved_expr"),
846 sort: expr.output_type(),
847 is_global_ref: false,
848 };
849 let actions = ResolvedActions::singleton(ResolvedAction::Let(
850 span.clone(),
851 result_var.clone(),
852 expr.clone(),
853 ));
854 let actions = actions
855 .to_core_actions(
856 &self.type_info,
857 &mut Default::default(),
858 &mut self.parser.symbol_gen,
859 )?
860 .0;
861 translator.actions(&actions)?;
862
863 let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var));
864 translator.rb.call_external_func(
865 ext_id,
866 &[arg],
867 egglog_bridge::ColumnTy::Base(unit_id),
868 || "this function will never panic".to_string(),
869 );
870
871 let id = translator.build();
872 let rule_result = self.backend.run_rules(&[id]);
873 self.backend.free_rule(id);
874 self.backend.free_external_func(ext_id);
875 let _ = rule_result.map_err(|e| {
876 Error::BackendError(format!("Failed to evaluate expression '{}': {}", expr, e))
877 })?;
878
879 let result = result.lock().unwrap().unwrap();
880 Ok(result)
881 }
882
883 fn add_combined_ruleset(&mut self, name: String, rulesets: Vec<String>) {
884 match self.rulesets.entry(name.clone()) {
885 Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
886 Entry::Vacant(e) => e.insert(Ruleset::Combined(rulesets)),
887 };
888 }
889
890 fn add_ruleset(&mut self, name: String) {
891 match self.rulesets.entry(name.clone()) {
892 Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
893 Entry::Vacant(e) => e.insert(Ruleset::Rules(Default::default())),
894 };
895 }
896
897 fn check_facts(&mut self, span: &Span, facts: &[ResolvedFact]) -> Result<(), Error> {
898 let fresh_name = self.parser.symbol_gen.fresh("check_facts");
899 let fresh_ruleset = self.parser.symbol_gen.fresh("check_facts_ruleset");
900 let rule = ast::ResolvedRule {
901 span: span.clone(),
902 head: ResolvedActions::default(),
903 body: facts.to_vec(),
904 name: fresh_name.clone(),
905 ruleset: fresh_ruleset.clone(),
906 };
907 let core_rule =
908 rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?;
909 let query = core_rule.body;
910
911 let ext_sc = egglog_bridge::SideChannel::default();
912 let ext_sc_ref = ext_sc.clone();
913 let ext_id = self
914 .backend
915 .register_external_func(make_external_func(move |_, _| {
916 *ext_sc_ref.lock().unwrap() = Some(());
917 Some(Value::new_const(0))
918 }));
919
920 let mut translator = BackendRule::new(
921 self.backend.new_rule("check_facts", false),
922 &self.functions,
923 &self.type_info,
924 );
925 translator.query(&query, true);
926 translator
927 .rb
928 .call_external_func(ext_id, &[], egglog_bridge::ColumnTy::Id, || {
929 "this function will never panic".to_string()
930 });
931 let id = translator.build();
932 let _ = self.backend.run_rules(&[id]).unwrap();
933 self.backend.free_rule(id);
934
935 self.backend.free_external_func(ext_id);
936
937 let ext_sc_val = ext_sc.lock().unwrap().take();
938 let matched = matches!(ext_sc_val, Some(()));
939
940 if !matched {
941 Err(Error::CheckError(
942 facts.iter().map(|f| f.clone().make_unresolved()).collect(),
943 span.clone(),
944 ))
945 } else {
946 Ok(())
947 }
948 }
949
950 fn run_command(&mut self, command: ResolvedNCommand) -> Result<Option<CommandOutput>, Error> {
951 match command {
952 ResolvedNCommand::Sort(_span, name, _presort_and_args) => {
954 log::info!("Declared sort {}.", name)
955 }
956 ResolvedNCommand::Function(fdecl) => {
957 self.declare_function(&fdecl)?;
958 log::info!("Declared {} {}.", fdecl.subtype, fdecl.name)
959 }
960 ResolvedNCommand::AddRuleset(_span, name) => {
961 self.add_ruleset(name.clone());
962 log::info!("Declared ruleset {name}.");
963 }
964 ResolvedNCommand::UnstableCombinedRuleset(_span, name, others) => {
965 self.add_combined_ruleset(name.clone(), others);
966 log::info!("Declared ruleset {name}.");
967 }
968 ResolvedNCommand::NormRule { rule } => {
969 let name = rule.name.clone();
970 self.add_rule(rule)?;
971 log::info!("Declared rule {name}.")
972 }
973 ResolvedNCommand::RunSchedule(sched) => {
974 let report = self.run_schedule(&sched)?;
975 log::info!("Ran schedule {}.", sched);
976 log::info!("Report: {}", report);
977 self.overall_run_report.union(report.clone());
978 return Ok(Some(CommandOutput::RunSchedule(report)));
979 }
980 ResolvedNCommand::PrintOverallStatistics(span, file) => match file {
981 None => {
982 log::info!("Printed overall statistics");
983 return Ok(Some(CommandOutput::OverallStatistics(
984 self.overall_run_report.clone(),
985 )));
986 }
987 Some(path) => {
988 let mut file = std::fs::File::create(&path)
989 .map_err(|e| Error::IoError(path.clone().into(), e, span.clone()))?;
990 log::info!("Printed overall statistics to json file {}", path);
991
992 serde_json::to_writer(&mut file, &self.overall_run_report)
993 .expect("error serializing to json");
994 }
995 },
996 ResolvedNCommand::Check(span, facts) => {
997 self.check_facts(&span, &facts)?;
998 log::info!("Checked fact {:?}.", facts);
999 }
1000 ResolvedNCommand::CoreAction(action) => match &action {
1001 ResolvedAction::Let(_, name, contents) => {
1002 panic!("Globals should have been desugared away: {name} = {contents}")
1003 }
1004 _ => {
1005 self.eval_actions(&ResolvedActions::new(vec![action.clone()]))?;
1006 }
1007 },
1008 ResolvedNCommand::Extract(span, expr, variants) => {
1009 let sort = expr.output_type();
1010
1011 let x = self.eval_resolved_expr(span.clone(), &expr)?;
1012 let n = self.eval_resolved_expr(span, &variants)?;
1013 let n: i64 = self.backend.base_values().unwrap(n);
1014
1015 let mut termdag = TermDag::default();
1016
1017 let extractor = Extractor::compute_costs_from_rootsorts(
1018 Some(vec![sort]),
1019 self,
1020 TreeAdditiveCostModel::default(),
1021 );
1022 return if n == 0 {
1023 if let Some((cost, term)) = extractor.extract_best(self, &mut termdag, x) {
1024 if log_enabled!(Level::Info) {
1026 log::info!("extracted with cost {cost}: {}", termdag.to_string(&term));
1027 }
1028 Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
1029 } else {
1030 Err(Error::ExtractError(
1031 "Unable to find any valid extraction (likely due to subsume or delete)"
1032 .to_string(),
1033 ))
1034 }
1035 } else {
1036 if n < 0 {
1037 panic!("Cannot extract negative number of variants");
1038 }
1039 let terms: Vec<Term> = extractor
1040 .extract_variants(self, &mut termdag, x, n as usize)
1041 .iter()
1042 .map(|e| e.1.clone())
1043 .collect();
1044 if log_enabled!(Level::Info) {
1045 let expr_str = expr.to_string();
1046 log::info!("extracted {} variants for {expr_str}", terms.len());
1047 }
1048 Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
1049 };
1050 }
1051 ResolvedNCommand::Push(n) => {
1052 (0..n).for_each(|_| self.push());
1053 log::info!("Pushed {n} levels.")
1054 }
1055 ResolvedNCommand::Pop(span, n) => {
1056 for _ in 0..n {
1057 self.pop().map_err(|err| {
1058 if let Error::Pop(_) = err {
1059 Error::Pop(span.clone())
1060 } else {
1061 err
1062 }
1063 })?;
1064 }
1065 log::info!("Popped {n} levels.")
1066 }
1067 ResolvedNCommand::PrintFunction(span, f, n, file, mode) => {
1068 let file = file
1069 .map(|file| {
1070 std::fs::File::create(&file)
1071 .map_err(|e| Error::IoError(file.into(), e, span.clone()))
1072 })
1073 .transpose()?;
1074 return self.print_function(&f, n, file, mode).map_err(|e| match e {
1075 Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1076 Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1077 }
1078 _ => e,
1080 });
1081 }
1082 ResolvedNCommand::PrintSize(span, f) => {
1083 let res = self.print_size(f.as_deref()).map_err(|e| match e {
1084 Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1085 Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1086 }
1087 _ => e,
1089 })?;
1090 return Ok(Some(res));
1091 }
1092 ResolvedNCommand::Fail(span, c) => {
1093 let result = self.run_command(*c);
1094 if let Err(e) = result {
1095 log::info!("Command failed as expected: {e}");
1096 } else {
1097 return Err(Error::ExpectFail(span));
1098 }
1099 }
1100 ResolvedNCommand::Input {
1101 span: _,
1102 name,
1103 file,
1104 } => {
1105 self.input_file(&name, file)?;
1106 }
1107 ResolvedNCommand::Output { span, file, exprs } => {
1108 let mut filename = self.fact_directory.clone().unwrap_or_default();
1109 filename.push(file.as_str());
1110 let mut f = File::options()
1112 .append(true)
1113 .create(true)
1114 .open(&filename)
1115 .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1116
1117 let extractor = Extractor::compute_costs_from_rootsorts(
1118 None,
1119 self,
1120 TreeAdditiveCostModel::default(),
1121 );
1122 let mut termdag: TermDag = Default::default();
1123
1124 use std::io::Write;
1125 for expr in exprs {
1126 let value = self.eval_resolved_expr(span.clone(), &expr)?;
1127 let expr_type = expr.output_type();
1128
1129 let term = extractor
1130 .extract_best_with_sort(self, &mut termdag, value, expr_type)
1131 .unwrap()
1132 .1;
1133 writeln!(f, "{}", termdag.to_string(&term))
1134 .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1135 }
1136
1137 log::info!("Output to '{filename:?}'.")
1138 }
1139 ResolvedNCommand::UserDefined(_span, name, exprs) => {
1140 let command = self.commands.swap_remove(&name).unwrap_or_else(|| {
1141 panic!("Unrecognized user-defined command: {}", name);
1142 });
1143 let res = command.update(self, &exprs);
1144 self.commands.insert(name, command);
1145 return res;
1146 }
1147 };
1148
1149 Ok(None)
1150 }
1151
1152 fn input_file(&mut self, func_name: &str, file: String) -> Result<(), Error> {
1153 let function_type = self
1154 .type_info
1155 .get_func_type(func_name)
1156 .unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
1157 let func = self.functions.get_mut(func_name).unwrap();
1158
1159 let mut filename = self.fact_directory.clone().unwrap_or_default();
1160 filename.push(file.as_str());
1161
1162 for t in &func.schema.input {
1165 match t.name() {
1166 "i64" | "f64" | "String" => {}
1167 s => panic!("Unsupported type {} for input", s),
1168 }
1169 }
1170
1171 if function_type.subtype != FunctionSubtype::Constructor {
1172 match func.schema.output.name() {
1173 "i64" | "String" | "Unit" => {}
1174 s => panic!("Unsupported type {} for input", s),
1175 }
1176 }
1177
1178 log::info!("Opening file '{:?}'...", filename);
1179 let mut f = File::open(filename).unwrap();
1180 let mut contents = String::new();
1181 f.read_to_string(&mut contents).unwrap();
1182
1183 let mut parsed_contents: Vec<Vec<Value>> = Vec::with_capacity(contents.lines().count());
1185
1186 let mut row_schema = func.schema.input.clone();
1187 if function_type.subtype == FunctionSubtype::Custom {
1188 row_schema.push(func.schema.output.clone());
1189 }
1190
1191 log::debug!("{:?}", row_schema);
1192
1193 let unit_val = self.backend.base_values().get(());
1194
1195 for line in contents.lines() {
1196 let mut it = line.split('\t').map(|s| s.trim());
1197
1198 let mut row: Vec<Value> = Vec::with_capacity(row_schema.len());
1199
1200 for sort in row_schema.iter() {
1201 if let Some(raw) = it.next() {
1202 let val = match sort.name() {
1203 "i64" => {
1204 if let Ok(i) = raw.parse::<i64>() {
1205 self.backend.base_values().get(i)
1206 } else {
1207 return Err(Error::InputFileFormatError(file));
1208 }
1209 }
1210 "f64" => {
1211 if let Ok(f) = raw.parse::<f64>() {
1212 self.backend
1213 .base_values()
1214 .get::<F>(core_relations::Boxed::new(f.into()))
1215 } else {
1216 return Err(Error::InputFileFormatError(file));
1217 }
1218 }
1219 "String" => self.backend.base_values().get::<S>(raw.to_string().into()),
1220 "Unit" => unit_val,
1221 _ => panic!("Unreachable"),
1222 };
1223 row.push(val);
1224 } else {
1225 break;
1226 }
1227 }
1228
1229 if row.is_empty() {
1230 continue;
1231 }
1232
1233 if row.len() != row_schema.len() || it.next().is_some() {
1234 return Err(Error::InputFileFormatError(file));
1235 }
1236
1237 parsed_contents.push(row);
1238 }
1239
1240 log::debug!("Successfully loaded file.");
1241
1242 let num_facts = parsed_contents.len();
1243
1244 let mut table_action = egglog_bridge::TableAction::new(&self.backend, func.backend_id);
1245
1246 if function_type.subtype != FunctionSubtype::Constructor {
1247 self.backend.with_execution_state(|es| {
1248 for row in parsed_contents.iter() {
1249 table_action.insert(es, row.iter().copied());
1250 }
1251 Some(unit_val)
1252 });
1253 } else {
1254 self.backend.with_execution_state(|es| {
1255 for row in parsed_contents.iter() {
1256 table_action.lookup(es, row);
1257 }
1258 Some(unit_val)
1259 });
1260 }
1261
1262 self.backend.flush_updates();
1263
1264 log::info!("Read {num_facts} facts into {func_name} from '{file}'.");
1265 Ok(())
1266 }
1267
1268 fn process_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
1269 let mut program = self.resolve_command(command)?;
1270
1271 program = remove_globals::remove_globals(program, &mut self.parser.symbol_gen);
1272 for command in &program {
1273 self.names.check_shadowing(command)?;
1274 }
1275
1276 Ok(program)
1277 }
1278
1279 fn resolve_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
1280 let program = desugar::desugar_program(vec![command], &mut self.parser, self.seminaive)?;
1281 Ok(self.typecheck_program(&program)?)
1282 }
1283
1284 pub fn run_program(&mut self, program: Vec<Command>) -> Result<Vec<CommandOutput>, Error> {
1287 let mut outputs = Vec::new();
1288 for command in program {
1289 for processed in self.process_command(command)? {
1292 let result = self.run_command(processed)?;
1293 if let Some(output) = result {
1294 outputs.push(output);
1295 }
1296 }
1297 }
1298
1299 Ok(outputs)
1300 }
1301
1302 pub fn resugar_program(
1303 &mut self,
1304 filename: Option<String>,
1305 input: &str,
1306 ) -> Result<Vec<String>, Error> {
1307 let parsed = self.parser.get_program_from_string(filename, input)?;
1308 let mut outputs = Vec::new();
1309 for command in parsed {
1310 for processed in self.resolve_command(command)? {
1311 if let GenericNCommand::Push(..) | GenericNCommand::Pop(..) = &processed {
1314 self.run_command(processed.clone())?;
1315 }
1316 outputs.push(processed.to_command().to_string());
1317 }
1318 }
1319 Ok(outputs)
1320 }
1321
1322 pub fn parse_and_run_program(
1328 &mut self,
1329 filename: Option<String>,
1330 input: &str,
1331 ) -> Result<Vec<CommandOutput>, Error> {
1332 let parsed = self.parser.get_program_from_string(filename, input)?;
1333 self.run_program(parsed)
1334 }
1335
1336 pub fn num_tuples(&self) -> usize {
1339 self.functions
1340 .values()
1341 .map(|f| self.backend.table_size(f.backend_id))
1342 .sum()
1343 }
1344
1345 pub fn get_sort<S: Sort>(&self) -> Arc<S> {
1347 self.type_info.get_sort()
1348 }
1349
1350 pub fn get_sort_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
1352 self.type_info.get_sort_by(f)
1353 }
1354
1355 pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
1357 self.type_info.get_sorts()
1358 }
1359
1360 pub fn get_sorts_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
1362 self.type_info.get_sorts_by(f)
1363 }
1364
1365 pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
1367 self.type_info.get_arcsort_by(f)
1368 }
1369
1370 pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
1372 self.type_info.get_arcsorts_by(f)
1373 }
1374
1375 pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
1377 self.type_info.get_sort_by_name(sym)
1378 }
1379
    /// Returns the run report stored in `overall_run_report`.
    ///
    /// Judging by the name, this is presumably accumulated across all runs so far — confirm
    /// where it is updated.
    pub fn get_overall_run_report(&self) -> &RunReport {
        &self.overall_run_report
    }
1384
1385 pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
1387 self.backend.base_values().unwrap::<T>(x)
1388 }
1389
1390 pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
1392 self.backend.base_values().get::<T>(x)
1393 }
1394
1395 pub fn value_to_container<T: ContainerValue>(
1402 &self,
1403 x: Value,
1404 ) -> Option<impl Deref<Target = T>> {
1405 self.backend.container_values().get_val::<T>(x)
1406 }
1407
1408 pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
1410 self.backend.with_execution_state(|state| {
1411 self.backend.container_values().register_val::<T>(x, state)
1412 })
1413 }
1414
1415 pub fn get_size(&self, func: &str) -> usize {
1419 let function_id = self.functions.get(func).unwrap().backend_id;
1420 self.backend.table_size(function_id)
1421 }
1422
1423 pub fn lookup_function(&self, name: &str, key: &[Value]) -> Option<Value> {
1428 let func = self.functions.get(name).unwrap().backend_id;
1429 self.backend.lookup_id(func, key)
1430 }
1431
    /// Returns the function named `name`, or `None` if no such function is registered.
    pub fn get_function(&self, name: &str) -> Option<&Function> {
        self.functions.get(name)
    }
1438
    /// Sets how detailed the backend's run reports are, by forwarding `level` to the
    /// backend.
    pub fn set_report_level(&mut self, level: ReportLevel) {
        self.backend.set_report_level(level);
    }
1442
    /// Asks the backend to dump its debug information.
    ///
    /// Where the output goes, and what it contains, is decided by the backend's
    /// `dump_debug_info`.
    pub fn dump_debug_info(&self) {
        self.backend.dump_debug_info();
    }
1449
1450 pub fn get_canonical_value(&self, val: Value, sort: &ArcSort) -> Value {
1452 self.backend
1453 .get_canon_repr(val, sort.column_ty(&self.backend))
1454 }
1455}
1456
/// Translates a resolved core rule (query atoms and actions) into a backend rule through
/// an [`egglog_bridge::RuleBuilder`].
struct BackendRule<'a> {
    /// Builder the query and actions are emitted into.
    rb: egglog_bridge::RuleBuilder<'a>,
    /// Cache from atom terms to the backend query entries created for them, so repeated
    /// occurrences of the same term reuse one entry.
    entries: HashMap<core::ResolvedAtomTerm, QueryEntry>,
    /// Front-end functions, used to find each function's backend table id.
    functions: &'a IndexMap<String, Function>,
    /// Type information, used to resolve the target of `unstable-fn`.
    type_info: &'a TypeInfo,
}
1463
impl<'a> BackendRule<'a> {
    /// Creates a translator around `rb` with an empty entry cache.
    fn new(
        rb: egglog_bridge::RuleBuilder<'a>,
        functions: &'a IndexMap<String, Function>,
        type_info: &'a TypeInfo,
    ) -> BackendRule<'a> {
        BackendRule {
            rb,
            functions,
            type_info,
            entries: Default::default(),
        }
    }

    /// Returns the backend query entry for the atom term `x`, creating and caching it on
    /// first use.
    ///
    /// Variables become fresh named backend variables of the variable's column type, and
    /// literals become constants.
    ///
    /// # Panics
    ///
    /// Panics on globals, which must have been removed before rules are translated.
    fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry {
        self.entries
            .entry(x.clone())
            .or_insert_with(|| match x {
                core::GenericAtomTerm::Var(_, v) => self
                    .rb
                    .new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name),
                core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l),
                core::GenericAtomTerm::Global(..) => {
                    panic!("Globals should have been desugared")
                }
            })
            .clone()
    }

    /// Returns the backend table id of the front-end function `f`.
    ///
    /// Indexing panics if `f` is not in `functions`.
    fn func(&self, f: &typechecking::FuncType) -> egglog_bridge::FunctionId {
        self.functions[&f.name].backend_id
    }

    /// Prepares a call to the specialized primitive `prim` with arguments `args`.
    ///
    /// Returns the external function id, the argument entries and the output column type.
    ///
    /// For `unstable-fn`, the first argument must be a string literal naming the target. That
    /// entry is replaced by a constant `ResolvedFunction`, which refers to either:
    /// * a table lookup, if the name is a function, or
    /// * the single primitive overload that accepts the remaining primitive inputs followed
    ///   by the inputs and output of some registered `FunctionSort`.
    ///
    /// # Panics
    ///
    /// Panics if the first argument is not a string literal, if the name resolves to
    /// nothing, or if the primitive overloads do not narrow down to exactly one.
    fn prim(
        &mut self,
        prim: &core::SpecializedPrimitive,
        args: &[core::ResolvedAtomTerm],
    ) -> (ExternalFunctionId, Vec<QueryEntry>, ColumnTy) {
        let mut qe_args = self.args(args);

        if prim.primitive.0.name() == "unstable-fn" {
            let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else {
                panic!("expected string literal after `unstable-fn`")
            };
            let id = if let Some(f) = self.type_info.get_func_type(name) {
                // The name is a table function: refer to it through a lookup action.
                ResolvedFunctionId::Lookup(egglog_bridge::TableAction::new(
                    self.rb.egraph(),
                    self.func(f),
                ))
            } else if let Some(possible) = self.type_info.get_prims(name) {
                // The name is a primitive. Keep only the overloads that accept the
                // captured arguments (inputs after the name) followed by the inputs and
                // output of some registered function sort.
                let mut ps: Vec<_> = possible.iter().collect();
                ps.retain(|p| {
                    self.type_info
                        .get_sorts::<FunctionSort>()
                        .into_iter()
                        .any(|f| {
                            let types: Vec<_> = prim
                                .input
                                .iter()
                                .skip(1)
                                .chain(f.inputs())
                                .chain([&f.output()])
                                .cloned()
                                .collect();
                            p.accept(&types, self.type_info)
                        })
                });
                assert!(ps.len() == 1, "options for {name}: {ps:?}");
                ResolvedFunctionId::Prim(ps.into_iter().next().unwrap().1)
            } else {
                panic!("no callable for {name}");
            };
            // One flag per captured argument: whether its sort is an eq sort or an eq
            // container sort (presumably so those arguments are canonicalized on rebuild).
            let do_rebuild = prim
                .input
                .iter()
                .skip(1)
                .map(|s| s.is_eq_sort() || s.is_eq_container_sort())
                .collect();

            // Replace the string-literal name with the resolved function constant.
            qe_args[0] = self.rb.egraph().base_value_constant(ResolvedFunction {
                id,
                do_rebuild,
                name: name.clone(),
            });
        }

        (
            prim.primitive.1,
            qe_args,
            prim.output.column_ty(self.rb.egraph()),
        )
    }

    /// Returns the query entries for `args`, in order (see [`Self::entry`]).
    fn args<'b>(
        &mut self,
        args: impl IntoIterator<Item = &'b core::ResolvedAtomTerm>,
    ) -> Vec<QueryEntry> {
        args.into_iter().map(|x| self.entry(x)).collect()
    }

    /// Adds every atom of `query` to the rule body.
    ///
    /// Function atoms query the function's table and primitive atoms become primitive
    /// queries. When `include_subsumed` is false, table queries pass `Some(false)` as the
    /// subsumption filter. Otherwise they pass `None` (presumably meaning "no filter").
    ///
    /// # Panics
    ///
    /// Panics if the builder rejects an atom.
    fn query(&mut self, query: &core::Query<ResolvedCall, ResolvedVar>, include_subsumed: bool) {
        for atom in &query.atoms {
            match &atom.head {
                ResolvedCall::Func(f) => {
                    let f = self.func(f);
                    let args = self.args(&atom.args);
                    let is_subsumed = match include_subsumed {
                        true => None,
                        false => Some(false),
                    };
                    self.rb.query_table(f, &args, is_subsumed).unwrap();
                }
                ResolvedCall::Primitive(p) => {
                    let (p, args, ty) = self.prim(p, &atom.args);
                    self.rb.query_prim(p, &args, ty).unwrap()
                }
            }
        }
    }

    /// Emits every core action in `actions` into the rule head, in order.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SubsumeMergeError`] when `subsume` targets a function whose
    /// `can_subsume` flag is false.
    ///
    /// # Panics
    ///
    /// Panics on `set` or `change` actions whose target is a primitive.
    fn actions(&mut self, actions: &core::ResolvedCoreActions) -> Result<(), Error> {
        for action in &actions.0 {
            match action {
                core::GenericCoreAction::Let(span, v, f, args) => {
                    // Bind `v` to a table lookup or primitive call. The closures build the
                    // message reported if that lookup or call fails.
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let y = match f {
                        ResolvedCall::Func(f) => {
                            let name = f.name.clone();
                            let f = self.func(f);
                            let args = self.args(args);
                            let span = span.clone();
                            self.rb.lookup(f, &args, move || {
                                format!("{span}: lookup of function {name} failed")
                            })
                        }
                        ResolvedCall::Primitive(p) => {
                            let name = p.primitive.0.name().to_owned();
                            let (p, args, ty) = self.prim(p, args);
                            let span = span.clone();
                            self.rb.call_external_func(p, &args, ty, move || {
                                format!("{span}: call of primitive {name} failed")
                            })
                        }
                    };
                    // Later occurrences of `v` resolve to this result.
                    self.entries.insert(v, y.into());
                }
                core::GenericCoreAction::LetAtomTerm(span, v, x) => {
                    // Alias `v` to the entry of an existing atom term.
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let x = self.entry(x);
                    self.entries.insert(v, x);
                }
                core::GenericCoreAction::Set(_, f, xs, y) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive set!"),
                    ResolvedCall::Func(f) => {
                        // The row is the key columns followed by the value.
                        let f = self.func(f);
                        let args = self.args(xs.iter().chain([y]));
                        self.rb.set(f, &args)
                    }
                },
                core::GenericCoreAction::Change(span, change, f, args) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive change!"),
                    ResolvedCall::Func(f) => {
                        let name = f.name.clone();
                        let can_subsume = self.functions[&f.name].can_subsume;
                        let f = self.func(f);
                        let args = self.args(args);
                        match change {
                            Change::Delete => self.rb.remove(f, &args),
                            Change::Subsume if can_subsume => self.rb.subsume(f, &args),
                            Change::Subsume => {
                                return Err(Error::SubsumeMergeError(name, span.clone()));
                            }
                        }
                    }
                },
                core::GenericCoreAction::Union(_, x, y) => {
                    let x = self.entry(x);
                    let y = self.entry(y);
                    self.rb.union(x, y)
                }
                core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()),
            }
        }
        Ok(())
    }

    /// Finalizes the underlying builder and returns the resulting rule id.
    fn build(self) -> egglog_bridge::RuleId {
        self.rb.build()
    }
}
1654
1655fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry {
1656 match l {
1657 Literal::Int(x) => egraph.base_value_constant::<i64>(*x),
1658 Literal::Float(x) => egraph.base_value_constant::<sort::F>(x.into()),
1659 Literal::String(x) => egraph.base_value_constant::<sort::S>(sort::S::new(x.clone())),
1660 Literal::Bool(x) => egraph.base_value_constant::<bool>(*x),
1661 Literal::Unit => egraph.base_value_constant::<()>(()),
1662 }
1663}
1664
1665fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value {
1666 match l {
1667 Literal::Int(x) => egraph.base_values().get::<i64>(*x),
1668 Literal::Float(x) => egraph.base_values().get::<sort::F>(x.into()),
1669 Literal::String(x) => egraph.base_values().get::<sort::S>(sort::S::new(x.clone())),
1670 Literal::Bool(x) => egraph.base_values().get::<bool>(*x),
1671 Literal::Unit => egraph.base_values().get::<()>(()),
1672 }
1673}
1674
/// Errors produced while parsing, type-checking or running egglog commands.
#[derive(Debug, Error)]
pub enum Error {
    /// A [`ParseError`], displayed as-is.
    #[error(transparent)]
    ParseError(#[from] ParseError),
    /// A [`NotFoundError`], displayed as-is.
    #[error(transparent)]
    NotFoundError(#[from] NotFoundError),
    /// A single [`TypeError`], displayed as-is.
    #[error(transparent)]
    TypeError(#[from] TypeError),
    /// Several type errors, displayed one per line after an `Errors:` header.
    #[error("Errors:\n{}", ListDisplay(.0, "\n"))]
    TypeErrors(Vec<TypeError>),
    /// A check failed; carries the checked facts and the span of the command.
    #[error("{}\nCheck failed: \n{}", .1, ListDisplay(.0, "\n"))]
    CheckError(Vec<Fact>, Span),
    /// A command referred to a ruleset name that does not exist.
    #[error("{1}\nNo such ruleset: {0}")]
    NoSuchRuleset(String, Span),
    /// A rule was added to a combined ruleset, which may only depend on other rulesets.
    #[error(
        "{1}\nAttempted to add a rule to combined ruleset {0}. Combined rulesets may only depend on other rulesets."
    )]
    CombinedRulesetError(String, Span),
    /// A free-form error message, displayed verbatim.
    #[error("{0}")]
    BackendError(String),
    /// A `pop` asked to undo more than can be popped.
    #[error("{0}\nTried to pop too much")]
    Pop(Span),
    /// A command that was expected to fail succeeded.
    #[error("{0}\nCommand should have failed.")]
    ExpectFail(Span),
    /// An I/O operation on the given path failed.
    #[error("{2}\nIO error: {0}: {1}")]
    IoError(PathBuf, std::io::Error, Span),
    /// `subsume` was used on a function that cannot be subsumed.
    #[error("{1}\nCannot subsume function with merge: {0}")]
    SubsumeMergeError(String, Span),
    /// Extraction failed. The message is formatted with `{:?}`, so it is shown quoted and
    /// escaped.
    #[error("extraction failure: {:?}", .0)]
    ExtractError(String),
    /// A name was defined again; both spans are reported.
    #[error("{1}\n{2}\nShadowing is not allowed, but found {0}")]
    Shadowing(String, Span, Span),
    /// A user-defined command with this name is already registered.
    #[error("{1}\nCommand already exists: {0}")]
    CommandAlreadyExists(String, Span),
    /// An input file contained a malformed line; carries the file name as given.
    #[error("Incorrect format in file '{0}'.")]
    InputFileFormatError(String),
}
1712
#[cfg(test)]
mod tests {
    use crate::constraint::SimpleTypeConstraint;
    use crate::sort::*;
    use crate::*;

    /// Test primitive `inner-product`: the dot product of two vectors of sort `vec`.
    #[derive(Clone)]
    struct InnerProduct {
        /// Sort of both vector arguments.
        vec: ArcSort,
    }

    impl Primitive for InnerProduct {
        fn name(&self) -> &str {
            "inner-product"
        }

        // Sorts: two `vec` arguments followed by `i64` (presumably the output sort).
        fn get_type_constraints(&self, span: &Span) -> Box<dyn crate::constraint::TypeConstraint> {
            SimpleTypeConstraint::new(
                self.name(),
                vec![self.vec.clone(), self.vec.clone(), I64Sort.to_arcsort()],
                span.clone(),
            )
            .into_box()
        }

        // Sums the pairwise products of the elements (read as i64).
        // Panics if either argument is not a `VecContainer` or if the lengths differ.
        fn apply(&self, exec_state: &mut ExecutionState<'_>, args: &[Value]) -> Option<Value> {
            let mut sum = 0;
            let vec1 = exec_state
                .container_values()
                .get_val::<VecContainer>(args[0])
                .unwrap();
            let vec2 = exec_state
                .container_values()
                .get_val::<VecContainer>(args[1])
                .unwrap();
            assert_eq!(vec1.data.len(), vec2.data.len());
            for (a, b) in vec1.data.iter().zip(vec2.data.iter()) {
                let a = exec_state.base_values().unwrap::<i64>(*a);
                let b = exec_state.base_values().unwrap::<i64>(*b);
                sum += a * b;
            }
            Some(exec_state.base_values().get::<i64>(sum))
        }
    }

    // A primitive registered from Rust can be called from an egglog program:
    // 1*6 + 2*5 + 3*4 + 4*3 + 5*2 + 6*1 = 56.
    #[test]
    fn test_user_defined_primitive() {
        let mut egraph = EGraph::default();
        egraph
            .parse_and_run_program(None, "(sort IntVec (Vec i64))")
            .unwrap();

        // Find the `Vec i64` sort declared above.
        let int_vec_sort = egraph.get_arcsort_by(|s| {
            s.value_type() == Some(std::any::TypeId::of::<VecContainer>())
                && s.inner_sorts()[0].name() == I64Sort.name()
        });

        egraph.add_primitive(InnerProduct { vec: int_vec_sort });

        egraph
            .parse_and_run_program(
                None,
                "
                (let a (vec-of 1 2 3 4 5 6))
                (let b (vec-of 6 5 4 3 2 1))
                (check (= (inner-product a b) 56))
            ",
            )
            .unwrap();
    }

    // Compile-time check that `EGraph` is both `Send` and `Sync`.
    #[test]
    fn test_egraph_send_sync() {
        fn is_send<T: Send>(_t: &T) -> bool {
            true
        }
        fn is_sync<T: Sync>(_t: &T) -> bool {
            true
        }
        let egraph = EGraph::default();
        assert!(is_send(&egraph) && is_sync(&egraph));
    }

    // Test helper: clones the function named `name`; panics if it does not exist.
    fn get_function(egraph: &EGraph, name: &str) -> Function {
        egraph.functions.get(name).unwrap().clone()
    }

    // Test helper: first column of the last row visited in `name`'s backend table.
    // Intended for single-row tables; panics if the table is empty.
    fn get_value(egraph: &EGraph, name: &str) -> Value {
        let mut out = None;
        let id = get_function(egraph, name).backend_id;
        egraph.backend.for_each(id, |row| out = Some(row.vals[0]));
        out.unwrap()
    }

    // A subsumed term must stay unextractable after its argument's e-class id changes
    // through a union.
    #[test]
    fn test_subsumed_unextractable_rebuild_arg() {
        let mut egraph = EGraph::default();

        egraph
            .parse_and_run_program(
                None,
                r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (constructor cheap-1 () Math)
                ; we make the container cheap so that it will be extracted if possible, but then we mark it as subsumed
                ; so the (exp) expr should be extracted instead
                (let res (container (cheap)))
                (union res (exp))
                (cheap)
                (cheap-1)
                (subsume (container (cheap)))
                "#,
            ).unwrap();
        // `cheap` and `cheap-1` start out in different e-classes.
        let orig_cheap_value = get_value(&egraph, "cheap");
        let orig_cheap_1_value = get_value(&egraph, "cheap-1");
        assert_ne!(orig_cheap_value, orig_cheap_1_value);
        // Merge them so that at least one id changes.
        egraph
            .parse_and_run_program(
                None,
                r#"
                (union (cheap-1) (cheap))
                "#,
            )
            .unwrap();
        let new_cheap_value = get_value(&egraph, "cheap");
        let new_cheap_1_value = get_value(&egraph, "cheap-1");
        assert_eq!(new_cheap_value, new_cheap_1_value);
        assert!(new_cheap_value != orig_cheap_value || new_cheap_1_value != orig_cheap_1_value);
        // The subsumed `container` term must still not be extracted.
        let outputs = egraph
            .parse_and_run_program(
                None,
                r#"
                (extract res)
                "#,
            )
            .unwrap();
        assert_eq!(outputs[0].to_string(), "(exp)\n");
    }

    // A subsumed term must stay unextractable after its own e-class id changes through a
    // union.
    #[test]
    fn test_subsumed_unextractable_rebuild_self() {
        let mut egraph = EGraph::default();

        egraph
            .parse_and_run_program(
                None,
                r#"
                (datatype Math)
                (constructor container (Math) Math)
                (constructor exp () Math :cost 100)
                (constructor cheap () Math)
                (exp)
                (let x (cheap))
                (subsume (cheap))
                "#,
            )
            .unwrap();

        let orig_cheap_value = get_value(&egraph, "cheap");
        // Merge the subsumed term's class with `exp` so that its id changes.
        egraph
            .parse_and_run_program(
                None,
                r#"
                (union (exp) x)
                "#,
            )
            .unwrap();
        let new_cheap_value = get_value(&egraph, "cheap");
        assert_ne!(new_cheap_value, orig_cheap_value);

        // Extraction must pick `exp`, not the subsumed `cheap`.
        let res = egraph
            .parse_and_run_program(
                None,
                r#"
                (extract x)
                "#,
            )
            .unwrap();
        assert_eq!(res[0].to_string(), "(exp)\n");
    }
}