1pub mod ast;
16#[cfg(feature = "bin")]
17mod cli;
18mod command_macro;
19pub mod constraint;
20mod core;
21pub mod extract;
22pub mod prelude;
23mod proofs;
24
25pub mod scheduler;
26mod serialize;
27pub mod sort;
28mod termdag;
29mod typechecking;
30pub mod util;
31pub use command_macro::{CommandMacro, CommandMacroRegistry};
32
33extern crate self as egglog;
36pub use ast::{ResolvedExpr, ResolvedFact, ResolvedVar};
37#[cfg(feature = "bin")]
38pub use cli::*;
39use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint};
40pub use core::{Atom, AtomTerm};
41use core::{CoreActionContext, ResolvedAtomTerm};
42pub use core::{ResolvedCall, SpecializedPrimitive};
43pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value};
44use core_relations::{ExternalFunctionId, make_external_func};
45use csv::Writer;
46pub use egglog_add_primitive::add_literal_prim;
47pub use egglog_add_primitive::add_primitive;
48pub use egglog_add_primitive::add_primitive_with_validator;
49use egglog_ast::generic_ast::{Change, GenericExpr, Literal};
50use egglog_ast::span::Span;
51use egglog_ast::util::ListDisplay;
52pub use egglog_bridge::FunctionRow;
53use egglog_bridge::{ColumnTy, QueryEntry};
54use egglog_core_relations as core_relations;
55use egglog_numeric_id as numeric_id;
56use egglog_reports::{ReportLevel, RunReport};
57use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel};
58use indexmap::map::Entry;
59use log::{Level, log_enabled};
60use numeric_id::DenseIdMap;
61use prelude::*;
62pub use proofs::proof_encoding_helpers::{file_supports_proofs, program_supports_proofs};
63use scheduler::{SchedulerId, SchedulerRecord};
64pub use serialize::{SerializeConfig, SerializeOutput, SerializedNode};
65use sort::*;
66use std::fmt::{Debug, Display, Formatter};
67use std::fs::File;
68use std::hash::Hash;
69use std::io::{Read, Write as _};
70use std::iter::once;
71use std::ops::Deref;
72use std::path::PathBuf;
73use std::str::FromStr;
74use std::sync::Arc;
75pub use termdag::{Term, TermDag, TermId};
76use thiserror::Error;
77pub use typechecking::PrimitiveValidator;
78pub use typechecking::TypeError;
79pub use typechecking::TypeInfo;
80use util::*;
81
82use crate::ast::desugar::desugar_command;
83use crate::ast::*;
84use crate::core::{GenericActionsExt, ResolvedRuleExt};
85use crate::proofs::proof_encoding::{EncodingState, ProofInstrumentor};
86use crate::proofs::proof_encoding_helpers::command_supports_proof_encoding;
87use crate::proofs::proof_extraction::ProveExistsError;
88use crate::proofs::proof_format::{ProofId, ProofStore};
89use crate::proofs::proof_normal_form::proof_form;
90
/// Prefix that global variable names are expected to start with (e.g. `$x`).
///
/// A global without it is reported by `EGraph::ensure_global_name_prefix`
/// (a warning, or an error in strict mode); `EGraph::warn_prefixed_non_globals`
/// handles the opposite case of a non-global that does carry it.
pub const GLOBAL_NAME_PREFIX: &str = "$";

/// Shared, dynamically typed handle to a sort.
pub type ArcSort = Arc<dyn Sort>;
94
/// A primitive operation that can be registered with an e-graph.
pub trait Primitive {
    /// The name this primitive is registered and called under.
    fn name(&self) -> &str;

    /// Builds the type constraint for a use of this primitive.
    ///
    /// `span` is the source location of that use (presumably used to localize
    /// type errors — confirm against `TypeConstraint` implementations).
    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint>;

    /// Applies the primitive to already-evaluated `args`.
    ///
    /// Returns `Some(value)` on success and `None` when the primitive produces no
    /// value for these arguments (NOTE(review): presumably treated as a failed
    /// application by the caller — confirm).
    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value>;
}
111
/// Marker trait for whatever a user-defined command returns.
///
/// Any value that can be debug-printed, displayed, and shared across threads
/// qualifies automatically through the blanket implementation that follows.
pub trait UserDefinedCommandOutput: Display + Debug + Send + Sync {}

impl<T: Display + Debug + Send + Sync> UserDefinedCommandOutput for T {}
115
/// The result of running a single egglog command.
///
/// Its `Display` implementation produces the text printed for the command.
#[derive(Clone, Debug)]
// Variant sizes differ a lot (e.g. `PrintFunction` carries a whole `TermDag`
// next to tiny variants such as `PrintFunctionSize`); boxing is deliberately not done.
#[allow(clippy::large_enum_variant)]
pub enum CommandOutput {
    /// Number of rows in one function's table (see `EGraph::print_size` with a name).
    PrintFunctionSize(usize),
    /// `(name, size)` for every declared function, sorted by name.
    PrintAllFunctionsSize(Vec<(String, usize)>),
    /// The best extracted term: the DAG holding it, its cost, and the root term.
    ExtractBest(TermDag, DefaultCost, TermId),
    /// Several extracted variants, as root terms in one shared DAG.
    ExtractVariants(TermDag, Vec<TermId>),
    /// A proof, identified by `proof_id` inside `proof_store`.
    ProveExists {
        /// Store that owns the proof.
        proof_store: ProofStore,
        /// Identifier of the proof to render.
        proof_id: ProofId,
    },
    /// The cumulative run report over all executed schedules.
    OverallStatistics(RunReport),
    /// Rows of a function as terms: the function, the DAG holding the terms,
    /// `(input term, output term)` pairs, and the requested print mode.
    PrintFunction(Function, TermDag, Vec<(TermId, TermId)>, PrintFunctionMode),
    /// Report of one `run-schedule`; displays as nothing.
    RunSchedule(RunReport),
    /// Output produced by a user-defined command.
    UserDefined(Arc<dyn UserDefinedCommandOutput>),
}
142
143impl std::fmt::Display for CommandOutput {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 CommandOutput::PrintFunctionSize(size) => writeln!(f, "{}", size),
148 CommandOutput::PrintAllFunctionsSize(names_and_sizes) => {
149 for name in names_and_sizes {
150 writeln!(f, "{}: {}", name.0, name.1)?;
151 }
152 Ok(())
153 }
154 CommandOutput::ExtractBest(termdag, _cost, term) => {
155 writeln!(f, "{}", termdag.to_string(*term))
156 }
157 CommandOutput::ExtractVariants(termdag, terms) => {
158 writeln!(f, "(")?;
159 for expr in terms {
160 writeln!(f, " {}", termdag.to_string(*expr))?;
161 }
162 writeln!(f, ")")
163 }
164 CommandOutput::ProveExists {
165 proof_store,
166 proof_id,
167 } => write!(f, "{}", proof_store.proof_to_string(*proof_id)),
168 CommandOutput::OverallStatistics(run_report) => {
169 write!(f, "Overall statistics:\n{}", run_report)
170 }
171 CommandOutput::PrintFunction(function, termdag, terms_and_outputs, mode) => {
172 let out_is_unit = function.schema.output.name() == UnitSort.name();
173 if *mode == PrintFunctionMode::CSV {
174 let mut wtr = Writer::from_writer(vec![]);
175 for (term_id, output) in terms_and_outputs {
176 let term = termdag.get(*term_id);
177 match term {
178 Term::App(name, children) => {
179 let mut values = vec![name.clone()];
180 for child_id in children {
181 values.push(termdag.to_string(*child_id));
182 }
183
184 if !out_is_unit {
185 values.push(termdag.to_string(*output));
186 }
187 wtr.write_record(&values).map_err(|_| std::fmt::Error)?;
188 }
189 _ => panic!("Expect function_to_dag to return a list of apps."),
190 }
191 }
192 let csv_bytes = wtr.into_inner().map_err(|_| std::fmt::Error)?;
193 f.write_str(&String::from_utf8(csv_bytes).map_err(|_| std::fmt::Error)?)
194 } else {
195 writeln!(f, "(")?;
196 for (term, output) in terms_and_outputs.iter() {
197 write!(f, " {}", termdag.to_string(*term))?;
198 if !out_is_unit {
199 write!(f, " -> {}", termdag.to_string(*output))?;
200 }
201 writeln!(f)?;
202 }
203 writeln!(f, ")")
204 }
205 }
206 CommandOutput::RunSchedule(_report) => Ok(()),
207 CommandOutput::UserDefined(output) => {
208 write!(f, "{}", *output)
209 }
210 }
211 }
212}
213
/// An egglog e-graph: the backend database together with the front-end state
/// (parser, declarations, rulesets, type information) needed to run commands.
///
/// `push` takes a snapshot by cloning the whole value and `pop` restores it.
#[derive(Clone)]
pub struct EGraph {
    /// Backend database: holds the function tables and executes rules.
    backend: egglog_bridge::EGraph,
    /// Command parser; its `symbol_gen` also supplies fresh names for generated
    /// rules, variables and rulesets.
    pub parser: Parser,
    /// NOTE(review): presumably the names declared so far, for shadowing checks —
    /// not used in this section.
    names: check_shadowing::Names,
    /// Snapshot saved by `push` and restored by `pop`; each snapshot links to the
    /// one saved before it, forming a stack.
    pushed_egraph: Option<Box<Self>>,
    /// Declared functions, keyed by name.
    functions: IndexMap<String, Function>,
    /// Declared rulesets, keyed by name; the default ruleset is named `""`.
    rulesets: IndexMap<String, Ruleset>,
    /// NOTE(review): presumably the directory relative input/fact file paths are
    /// resolved against — confirm; not used in this section.
    pub fact_directory: Option<PathBuf>,
    /// Passed to the backend when rules are added (`true` by default).
    pub seminaive: bool,
    /// Sorts, presorts and primitives known to the type checker.
    type_info: TypeInfo,
    /// Union of the reports of every executed schedule; survives `pop`.
    overall_run_report: RunReport,
    /// NOTE(review): registered schedulers — not used in this section.
    schedulers: DenseIdMap<SchedulerId, SchedulerRecord>,
    /// User-defined commands registered through `add_command`, keyed by name.
    commands: IndexMap<String, Arc<dyn UserDefinedCommand>>,
    /// When set, global-name prefix problems are errors instead of warnings.
    strict_mode: bool,
    /// Set once a global-prefix warning has been logged, so later ones are
    /// suppressed. Shared by both kinds of prefix warning.
    warned_about_global_prefix: bool,
    /// Command macro registry, exposed via `command_macros` / `command_macros_mut`.
    command_macros: CommandMacroRegistry,
    /// Term-encoding / proof state: the typechecking snapshot and proof flags.
    proof_state: EncodingState,
    /// NOTE(review): presumably the resolved commands processed so far — not used
    /// in this section.
    desugared_commands: Vec<ResolvedNCommand>,
}
252
/// A command registered at runtime through `EGraph::add_command`.
pub trait UserDefinedCommand: Send + Sync {
    /// Runs the command against `egraph` with the argument expressions it was
    /// invoked with.
    ///
    /// NOTE(review): a returned `Some(output)` is presumably surfaced like the
    /// output of a built-in command — confirm against the caller.
    fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<Option<CommandOutput>, Error>;
}
262
/// A declared egglog function (constructor or custom function) together with
/// the backend table that stores its rows.
#[derive(Clone)]
pub struct Function {
    /// The resolved declaration this function was created from.
    decl: ResolvedFunctionDecl,
    /// Input and output sorts resolved from the declaration's sort names.
    schema: ResolvedSchema,
    /// Whether rows may be subsumed; `declare_function` sets this for constructors.
    can_subsume: bool,
    /// Identifier of the backing table in the backend e-graph.
    backend_id: egglog_bridge::FunctionId,
}
274
275impl Function {
276 pub fn name(&self) -> &str {
278 &self.decl.name
279 }
280
281 pub fn schema(&self) -> &ResolvedSchema {
283 &self.schema
284 }
285
286 pub fn can_subsume(&self) -> bool {
288 self.can_subsume
289 }
290}
291
/// A function's resolved sorts: one per input column, followed by the output sort.
#[derive(Clone, Debug)]
pub struct ResolvedSchema {
    /// Sorts of the input columns, in column order.
    pub input: Vec<ArcSort>,
    /// Sort of the output column, which comes right after the inputs.
    pub output: ArcSort,
}
297
298impl ResolvedSchema {
299 pub fn get_by_pos(&self, index: usize) -> Option<&ArcSort> {
301 if self.input.len() == index {
302 Some(&self.output)
303 } else {
304 self.input.get(index)
305 }
306 }
307}
308
309impl Debug for Function {
310 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
311 f.debug_struct("Function")
312 .field("decl", &self.decl)
313 .field("schema", &self.schema)
314 .finish()
315 }
316}
317
impl Default for EGraph {
    /// Creates an e-graph with the built-in sorts and primitives registered.
    ///
    /// Registers the base sorts (unit, string, bool, i64, f64, big integer,
    /// big rational), the presorts `MapSort`, `SetSort`, `VecSort`,
    /// `FunctionSort` and `MultiSetSort`, the primitives `!=`, `value-eq`,
    /// `ordering-min` and `ordering-max`, and an empty ruleset named `""`.
    /// Semi-naive evaluation starts enabled and strict mode disabled.
    fn default() -> Self {
        let mut parser = Parser::default();
        // The encoding state is built from the parser's symbol generator (borrowed
        // mutably), so the parser has to exist first.
        let proof_state = EncodingState::new(&mut parser.symbol_gen);
        let mut eg = Self {
            backend: Default::default(),
            parser,
            names: Default::default(),
            pushed_egraph: Default::default(),
            functions: Default::default(),
            rulesets: Default::default(),
            fact_directory: None,
            seminaive: true,
            overall_run_report: Default::default(),
            type_info: Default::default(),
            schedulers: Default::default(),
            commands: Default::default(),
            strict_mode: false,
            warned_about_global_prefix: false,
            command_macros: Default::default(),
            proof_state,
            desugared_commands: vec![],
        };

        // Registering the fixed built-in sorts on a fresh e-graph is not expected to
        // fail, hence the `unwrap`s.
        add_base_sort(&mut eg, UnitSort, span!()).unwrap();
        add_base_sort(&mut eg, StringSort, span!()).unwrap();
        add_base_sort(&mut eg, BoolSort, span!()).unwrap();
        add_base_sort(&mut eg, I64Sort, span!()).unwrap();
        add_base_sort(&mut eg, F64Sort, span!()).unwrap();
        add_base_sort(&mut eg, BigIntSort, span!()).unwrap();
        add_base_sort(&mut eg, BigRatSort, span!()).unwrap();
        // Container presorts. NOTE(review): presumably sort constructors that are
        // instantiated with concrete element sorts later — confirm in `TypeInfo`.
        eg.type_info.add_presort::<MapSort>(span!()).unwrap();
        eg.type_info.add_presort::<SetSort>(span!()).unwrap();
        eg.type_info.add_presort::<VecSort>(span!()).unwrap();
        eg.type_info.add_presort::<FunctionSort>(span!()).unwrap();
        eg.type_info.add_presort::<MultiSetSort>(span!()).unwrap();

        // Validator attached to `!=`: yields the unit literal when there are exactly
        // two argument terms with different term ids, and `None` otherwise.
        let neq_validator = |termdag: &mut TermDag, args: &[TermId]| -> Option<TermId> {
            if args.len() == 2 && args[0] != args[1] {
                Some(termdag.lit(Literal::Unit))
            } else {
                None
            }
        };
        // `!=` produces unit only when its two arguments differ. NOTE(review): `#`
        // presumably marks a polymorphic argument/return in the macro DSL, and `-?>`
        // a fallible result (the bodies return `Option`) — confirm.
        add_primitive_with_validator!(
            &mut eg,
            "!=" = |a: #, b: #| -?> () {
                (a != b).then_some(())
            },
            neq_validator
        );

        // `value-eq` produces unit only when its two arguments are equal.
        add_primitive!(&mut eg, "value-eq" = |a: #, b: #| -?> () {
            (a == b).then_some(())
        });
        // `ordering-min` / `ordering-max` return the smaller / larger argument.
        add_primitive!(&mut eg, "ordering-min" = |a: #, b: #| -> # {
            if a < b { a } else { b }
        });
        add_primitive!(&mut eg, "ordering-max" = |a: #, b: #| -> # {
            if a > b { a } else { b }
        });

        // The default ruleset, named by the empty string.
        eg.rulesets
            .insert("".into(), Ruleset::Rules(Default::default()));

        eg
    }
}
388
/// Error carrying a description of something that could not be found.
///
/// Displays as `Not found: <description>`.
#[derive(Debug, Error)]
#[error("Not found: {0}")]
pub struct NotFoundError(String);
392
393impl EGraph {
394 pub fn new_with_term_encoding() -> Self {
403 let mut egraph = EGraph::default();
404 egraph.proof_state.original_typechecking = Some(Box::new(egraph.clone()));
405 egraph
406 }
407
408 pub fn new_with_proofs() -> Self {
410 let mut egraph = EGraph::new_with_term_encoding();
411 egraph.proof_state.proofs_enabled = true;
412 egraph
413 }
414
415 pub(crate) fn with_term_encoding_enabled(mut self) -> Self {
419 self.proof_state.original_typechecking = Some(Box::new(self.clone()));
420 self
421 }
422
423 pub(crate) fn with_proofs_enabled(mut self) -> Self {
427 self = self.with_term_encoding_enabled();
428 self.proof_state.proofs_enabled = true;
429 self
430 }
431
432 pub fn with_proof_testing(mut self) -> Self {
434 self.proof_state.proof_testing = true;
435 self
436 }
437
438 pub fn type_info(&mut self) -> &mut TypeInfo {
441 &mut self.type_info
442 }
443
444 pub fn command_macros(&self) -> &CommandMacroRegistry {
446 &self.command_macros
447 }
448
449 pub fn command_macros_mut(&mut self) -> &mut CommandMacroRegistry {
451 &mut self.command_macros
452 }
453
454 pub fn add_command(
455 &mut self,
456 name: String,
457 command: Arc<dyn UserDefinedCommand>,
458 ) -> Result<(), Error> {
459 if self.commands.contains_key(&name)
460 || self.functions.contains_key(&name)
461 || self.type_info.get_prims(&name).is_some()
462 {
463 return Err(Error::CommandAlreadyExists(name, span!()));
464 }
465 self.commands.insert(name.clone(), command);
466 self.parser.add_user_defined(name)?;
467 Ok(())
468 }
469
470 pub fn set_strict_mode(&mut self, strict_mode: bool) {
472 self.strict_mode = strict_mode;
473 }
474
475 pub fn strict_mode(&self) -> bool {
477 self.strict_mode
478 }
479
480 fn ensure_global_name_prefix(&mut self, span: &Span, name: &str) -> Result<(), TypeError> {
481 if name.starts_with(GLOBAL_NAME_PREFIX) {
482 return Ok(());
483 }
484 if self.strict_mode {
485 Err(TypeError::GlobalMissingPrefix {
486 name: name.to_owned(),
487 span: span.clone(),
488 })
489 } else {
490 self.warn_missing_global_prefix(span, name)?;
491 Ok(())
492 }
493 }
494
495 fn warn_missing_global_prefix(
496 &mut self,
497 span: &Span,
498 canonical_name: &str,
499 ) -> Result<(), TypeError> {
500 if self.strict_mode {
501 return Err(TypeError::GlobalMissingPrefix {
502 name: format!("{}{}", GLOBAL_NAME_PREFIX, canonical_name),
503 span: span.clone(),
504 });
505 }
506 if self.warned_about_global_prefix {
507 return Ok(());
508 }
509 self.warned_about_global_prefix = true;
510 log::warn!(
511 "{}\nGlobal `{}` should start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
512 span,
513 canonical_name,
514 GLOBAL_NAME_PREFIX
515 );
516 Ok(())
517 }
518
519 fn warn_prefixed_non_globals(
520 &mut self,
521 span: &Span,
522 canonical_name: &str,
523 ) -> Result<(), TypeError> {
524 if self.strict_mode {
525 return Err(TypeError::NonGlobalPrefixed {
526 name: canonical_name.to_string(),
527 span: span.clone(),
528 });
529 }
530 if self.warned_about_global_prefix {
531 return Ok(());
532 }
533 self.warned_about_global_prefix = true;
534 log::warn!(
535 "{}\nNon-global `{}` should not start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
536 span,
537 canonical_name,
538 GLOBAL_NAME_PREFIX
539 );
540 Ok(())
541 }
542
543 pub fn push(&mut self) {
547 let prev_prev: Option<Box<Self>> = self.pushed_egraph.take();
548 let mut prev = self.clone();
549 prev.pushed_egraph = prev_prev;
550 self.pushed_egraph = Some(Box::new(prev));
551 }
552
553 pub fn pop(&mut self) -> Result<(), Error> {
558 match self.pushed_egraph.take() {
559 Some(e) => {
560 let overall_run_report = self.overall_run_report.clone();
562 *self = *e;
563 self.overall_run_report = overall_run_report;
564 Ok(())
565 }
566 None => Err(Error::Pop(span!())),
567 }
568 }
569
570 fn translate_expr_to_mergefn(
571 &self,
572 expr: &ResolvedExpr,
573 ) -> Result<egglog_bridge::MergeFn, Error> {
574 match expr {
575 GenericExpr::Lit(_, literal) => {
576 let val = literal_to_value(&self.backend, literal);
577 Ok(egglog_bridge::MergeFn::Const(val))
578 }
579 GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() {
580 "old" => Ok(egglog_bridge::MergeFn::Old),
581 "new" => Ok(egglog_bridge::MergeFn::New),
582 _ => Err(TypeError::Unbound(resolved_var.name.clone(), span.clone()).into()),
584 },
585 GenericExpr::Call(_, ResolvedCall::Func(f), args) => {
586 let translated_args = args
587 .iter()
588 .map(|arg| self.translate_expr_to_mergefn(arg))
589 .collect::<Result<Vec<_>, _>>()?;
590 Ok(egglog_bridge::MergeFn::Function(
591 self.functions[&f.name].backend_id,
592 translated_args,
593 ))
594 }
595 GenericExpr::Call(_, ResolvedCall::Primitive(p), args) => {
596 let translated_args = args
597 .iter()
598 .map(|arg| self.translate_expr_to_mergefn(arg))
599 .collect::<Result<Vec<_>, _>>()?;
600 Ok(egglog_bridge::MergeFn::Primitive(
601 p.external_id(),
602 translated_args,
603 ))
604 }
605 }
606 }
607
608 fn declare_function(&mut self, decl: &ResolvedFunctionDecl) -> Result<(), Error> {
609 let get_sort = |name: &String| match self.type_info.get_sort_by_name(name) {
610 Some(sort) => Ok(sort.clone()),
611 None => Err(Error::TypeError(TypeError::UndefinedSort(
612 name.to_owned(),
613 decl.span.clone(),
614 ))),
615 };
616
617 let input = decl
618 .schema
619 .input
620 .iter()
621 .map(get_sort)
622 .collect::<Result<Vec<_>, _>>()?;
623 let output = get_sort(&decl.schema.output)?;
624
625 let can_subsume = match decl.subtype {
626 FunctionSubtype::Constructor => true,
627 FunctionSubtype::Custom => false,
628 };
629
630 use egglog_bridge::{DefaultVal, MergeFn};
631 let backend_id = self.backend.add_table(egglog_bridge::FunctionConfig {
632 schema: input
633 .iter()
634 .chain([&output])
635 .map(|sort| sort.column_ty(&self.backend))
636 .collect(),
637 default: match decl.subtype {
638 FunctionSubtype::Constructor => DefaultVal::FreshId,
639 FunctionSubtype::Custom => DefaultVal::Fail,
640 },
641 merge: match decl.subtype {
642 FunctionSubtype::Constructor => MergeFn::UnionId,
643 FunctionSubtype::Custom => match &decl.merge {
644 None => MergeFn::AssertEq,
645 Some(expr) => self.translate_expr_to_mergefn(expr)?,
646 },
647 },
648 name: decl.name.to_string(),
649 can_subsume,
650 });
651
652 let function = Function {
653 decl: decl.clone(),
654 schema: ResolvedSchema { input, output },
655 can_subsume,
656 backend_id,
657 };
658
659 let old = self.functions.insert(decl.name.clone(), function);
660 if old.is_some() {
661 panic!(
662 "Typechecking should have caught function already bound: {}",
663 decl.name
664 );
665 }
666
667 Ok(())
668 }
669
670 pub fn function_to_dag(
675 &self,
676 sym: &str,
677 n: usize,
678 include_output: bool,
679 ) -> Result<(Vec<TermId>, Option<Vec<TermId>>, TermDag), Error> {
680 let func = self
681 .functions
682 .get(sym)
683 .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
684 let mut rootsorts = func.schema.input.clone();
685 if include_output {
686 rootsorts.push(func.schema.output.clone());
687 }
688 let extractor = Extractor::compute_costs_from_rootsorts(
689 Some(rootsorts),
690 self,
691 TreeAdditiveCostModel::default(),
692 );
693
694 let mut termdag = TermDag::default();
695 let mut inputs: Vec<TermId> = Vec::new();
696 let mut output: Option<Vec<TermId>> = if include_output {
697 Some(Vec::new())
698 } else {
699 None
700 };
701
702 let extract_row = |row: egglog_bridge::FunctionRow| {
703 if inputs.len() < n {
704 let mut children: Vec<TermId> = Vec::new();
706 for (value, sort) in row.vals.iter().zip(&func.schema.input) {
707 let (_, term_id) = extractor
708 .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
709 .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
710 children.push(term_id);
711 }
712 inputs.push(termdag.app(sym.to_owned(), children));
713 if include_output {
714 let value = row.vals[func.schema.input.len()];
715 let sort = &func.schema.output;
716 let (_, term) = extractor
717 .extract_best_with_sort(self, &mut termdag, value, sort.clone())
718 .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
719 output.as_mut().unwrap().push(term);
720 }
721 true
722 } else {
723 false
724 }
725 };
726
727 self.backend.for_each_while(func.backend_id, extract_row);
728
729 Ok((inputs, output, termdag))
730 }
731
732 pub fn print_function(
735 &mut self,
736 sym: &str,
737 n: Option<usize>,
738 file: Option<File>,
739 mode: PrintFunctionMode,
740 ) -> Result<Option<CommandOutput>, Error> {
741 let n = match n {
742 Some(n) => {
743 log::info!("Printing up to {n} tuples of function {sym} as {mode}");
744 n
745 }
746 None => {
747 log::info!("Printing all tuples of function {sym} as {mode}");
748 usize::MAX
749 }
750 };
751
752 let (terms, outputs, termdag) = self.function_to_dag(sym, n, true)?;
753 let f = self
754 .functions
755 .get(sym)
756 .unwrap();
758 let terms_and_outputs: Vec<_> = terms.into_iter().zip(outputs.unwrap()).collect();
759 let output = CommandOutput::PrintFunction(f.clone(), termdag, terms_and_outputs, mode);
760 match file {
761 Some(mut file) => {
762 log::info!("Writing output to file");
763 file.write_all(output.to_string().as_bytes())
764 .expect("Error writing to file");
765 Ok(None)
766 }
767 None => Ok(Some(output)),
768 }
769 }
770
771 pub fn print_size(&self, sym: Option<&str>) -> Result<CommandOutput, Error> {
774 if let Some(sym) = sym {
775 let f = self
776 .functions
777 .get(sym)
778 .ok_or(TypeError::UnboundFunction(sym.to_owned(), span!()))?;
779 let size = self.backend.table_size(f.backend_id);
780 log::info!("Function {} has size {}", sym, size);
781 Ok(CommandOutput::PrintFunctionSize(size))
782 } else {
783 let mut lens = self
785 .functions
786 .iter()
787 .map(|(sym, f)| (sym.clone(), self.backend.table_size(f.backend_id)))
788 .collect::<Vec<_>>();
789
790 lens.sort_by_key(|(name, _)| name.clone());
792 if log_enabled!(Level::Info) {
793 for (sym, len) in &lens {
794 log::info!("Function {} has size {}", sym, len);
795 }
796 }
797 Ok(CommandOutput::PrintAllFunctionsSize(lens))
798 }
799 }
800
801 fn run_schedule(&mut self, sched: &ResolvedSchedule) -> Result<RunReport, Error> {
803 match sched {
804 ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
805 ResolvedSchedule::Repeat(_span, limit, sched) => {
806 let mut report = RunReport::default();
807 for _i in 0..*limit {
808 let rec = self.run_schedule(sched)?;
809 let updated = rec.updated;
810 report.union(rec);
811 if !updated {
812 break;
813 }
814 }
815 Ok(report)
816 }
817 ResolvedSchedule::Saturate(_span, sched) => {
818 let mut report = RunReport::default();
819 loop {
820 let rec = self.run_schedule(sched)?;
821 let updated = rec.updated;
822 report.union(rec);
823 if !updated {
824 break;
825 }
826 }
827 Ok(report)
828 }
829 ResolvedSchedule::Sequence(_span, scheds) => {
830 let mut report = RunReport::default();
831 for sched in scheds {
832 report.union(self.run_schedule(sched)?);
833 }
834 Ok(report)
835 }
836 }
837 }
838
839 pub fn extract_value(
842 &self,
843 sort: &ArcSort,
844 value: Value,
845 ) -> Result<(TermDag, TermId, DefaultCost), Error> {
846 self.extract_value_with_cost_model(sort, value, TreeAdditiveCostModel::default())
847 }
848
849 pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
853 &self,
854 sort: &ArcSort,
855 value: Value,
856 cost_model: CM,
857 ) -> Result<(TermDag, TermId, DefaultCost), Error> {
858 let extractor =
859 Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
860 let mut termdag = TermDag::default();
861 let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
862 Ok((termdag, term, cost))
863 }
864
865 pub fn extract_value_to_string(
868 &self,
869 sort: &ArcSort,
870 value: Value,
871 ) -> Result<(String, DefaultCost), Error> {
872 let (termdag, term, cost) = self.extract_value(sort, value)?;
873 Ok((termdag.to_string(term), cost))
874 }
875
876 fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> Result<RunReport, Error> {
877 log::debug!("Running ruleset: {}", config.ruleset);
878 let mut report: RunReport = Default::default();
879
880 let GenericRunConfig { ruleset, until } = config;
881
882 if let Some(facts) = until {
883 if self.check_facts(span, facts).is_ok() {
884 log::info!(
885 "Breaking early because of facts:\n {}!",
886 ListDisplay(facts, "\n")
887 );
888 return Ok(report);
889 }
890 }
891
892 let subreport = self.step_rules(ruleset)?;
893 report.union(subreport);
894
895 if log_enabled!(Level::Debug) {
896 log::debug!("database size: {}", self.num_tuples());
897 }
898
899 Ok(report)
900 }
901
902 pub fn step_rules(&mut self, ruleset: &str) -> Result<RunReport, Error> {
909 fn collect_rule_ids(
910 ruleset: &str,
911 rulesets: &IndexMap<String, Ruleset>,
912 ids: &mut Vec<egglog_bridge::RuleId>,
913 ) {
914 match &rulesets[ruleset] {
915 Ruleset::Rules(rules) => {
916 for (_, id) in rules.values() {
917 ids.push(*id);
918 }
919 }
920 Ruleset::Combined(sub_rulesets) => {
921 for sub_ruleset in sub_rulesets {
922 collect_rule_ids(sub_ruleset, rulesets, ids);
923 }
924 }
925 }
926 }
927
928 let mut rule_ids = Vec::new();
929 collect_rule_ids(ruleset, &self.rulesets, &mut rule_ids);
930
931 let iteration_report = self
932 .backend
933 .run_rules(&rule_ids)
934 .map_err(|e| Error::BackendError(e.to_string()))?;
935
936 Ok(RunReport::singleton(ruleset, iteration_report))
937 }
938
939 fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result<String, Error> {
940 let core_rule = rule.to_canonicalized_core_rule(
943 &self.type_info,
944 &mut self.parser.symbol_gen,
945 self.proof_state.original_typechecking.is_none(),
946 )?;
947 let (query, actions) = (&core_rule.body, &core_rule.head);
948
949 let rule_id = {
950 let mut translator = BackendRule::new(
951 self.backend.new_rule(&rule.name, self.seminaive),
952 &self.functions,
953 &self.type_info,
954 );
955 translator.query(query, false);
956 translator.actions(actions)?;
957 translator.build()
958 };
959
960 if let Some(rules) = self.rulesets.get_mut(&rule.ruleset) {
961 match rules {
962 Ruleset::Rules(rules) => {
963 match rules.entry(rule.name.clone()) {
964 indexmap::map::Entry::Occupied(_) => {
965 let name = rule.name;
966 panic!("Rule '{name}' was already present")
967 }
968 indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)),
969 };
970 Ok(rule.name)
971 }
972 Ruleset::Combined(_) => Err(Error::CombinedRulesetError(rule.ruleset, rule.span)),
973 }
974 } else {
975 Err(Error::NoSuchRuleset(rule.ruleset, rule.span))
976 }
977 }
978
979 fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
980 let mut binding = IndexSet::default();
981 let mut ctx = CoreActionContext::new(
982 &self.type_info,
983 &mut binding,
984 &mut self.parser.symbol_gen,
985 self.proof_state.original_typechecking.is_none(),
986 );
987 let (actions, _) = actions.to_core_actions(&mut ctx)?;
988
989 let mut translator = BackendRule::new(
990 self.backend.new_rule("eval_actions", false),
991 &self.functions,
992 &self.type_info,
993 );
994 translator.actions(&actions)?;
995 let id = translator.build();
996 let result = self.backend.run_rules(&[id]);
997 self.backend.free_rule(id);
998
999 match result {
1000 Ok(_) => Ok(()),
1001 Err(e) => Err(Error::BackendError(e.to_string())),
1002 }
1003 }
1004
1005 pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
1007 let span = expr.span();
1008 let command = Command::Action(Action::Expr(span.clone(), expr.clone()));
1009 let resolved_commands = self.resolve_command(command)?;
1010 assert_eq!(resolved_commands.len(), 1);
1011 let resolved_command = resolved_commands.into_iter().next().unwrap();
1012 let resolved_expr = match resolved_command {
1013 ResolvedNCommand::CoreAction(ResolvedAction::Expr(_, resolved_expr)) => resolved_expr,
1014 _ => unreachable!(),
1015 };
1016 let sort = resolved_expr.output_type();
1017 let value = self.eval_resolved_expr(span, &resolved_expr)?;
1018 Ok((sort, value))
1019 }
1020
1021 fn eval_resolved_expr(&mut self, span: Span, expr: &ResolvedExpr) -> Result<Value, Error> {
1022 let unit_id = self.backend.base_values().get_ty::<()>();
1023 let unit_val = self.backend.base_values().get(());
1024
1025 let result: egglog_bridge::SideChannel<Value> = Default::default();
1026 let result_ref = result.clone();
1027 let ext_id = self
1028 .backend
1029 .register_external_func(Box::new(make_external_func(move |_es, vals| {
1030 debug_assert!(vals.len() == 1);
1031 *result_ref.lock().unwrap() = Some(vals[0]);
1032 Some(unit_val)
1033 })));
1034
1035 let mut translator = BackendRule::new(
1036 self.backend.new_rule("eval_resolved_expr", false),
1037 &self.functions,
1038 &self.type_info,
1039 );
1040
1041 let result_var = ResolvedVar {
1042 name: self.parser.symbol_gen.fresh("eval_resolved_expr"),
1043 sort: expr.output_type(),
1044 is_global_ref: false,
1045 };
1046 let actions = ResolvedActions::singleton(ResolvedAction::Let(
1047 span.clone(),
1048 result_var.clone(),
1049 expr.clone(),
1050 ));
1051 let mut binding = IndexSet::default();
1052 let mut ctx = CoreActionContext::new(
1053 &self.type_info,
1054 &mut binding,
1055 &mut self.parser.symbol_gen,
1056 self.proof_state.original_typechecking.is_none(),
1057 );
1058 let actions = actions.to_core_actions(&mut ctx)?.0;
1059 translator.actions(&actions)?;
1060
1061 let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var));
1062 translator.rb.call_external_func(
1063 ext_id,
1064 &[arg],
1065 egglog_bridge::ColumnTy::Base(unit_id),
1066 || "this function will never panic".to_string(),
1067 );
1068
1069 let id = translator.build();
1070 let rule_result = self.backend.run_rules(&[id]);
1071 self.backend.free_rule(id);
1072 self.backend.free_external_func(ext_id);
1073 let _ = rule_result.map_err(|e| {
1074 Error::BackendError(format!("Failed to evaluate expression '{}': {}", expr, e))
1075 })?;
1076
1077 let result = result.lock().unwrap().unwrap();
1078 Ok(result)
1079 }
1080
1081 fn add_combined_ruleset(&mut self, name: String, rulesets: Vec<String>) {
1082 match self.rulesets.entry(name.clone()) {
1083 Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
1084 Entry::Vacant(e) => e.insert(Ruleset::Combined(rulesets)),
1085 };
1086 }
1087
1088 fn add_ruleset(&mut self, name: String) {
1089 match self.rulesets.entry(name.clone()) {
1090 Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"),
1091 Entry::Vacant(e) => e.insert(Ruleset::Rules(Default::default())),
1092 };
1093 }
1094
1095 fn check_facts(&mut self, span: &Span, facts: &[ResolvedFact]) -> Result<(), Error> {
1096 let fresh_name = self.parser.symbol_gen.fresh("check_facts");
1097 let fresh_ruleset = self.parser.symbol_gen.fresh("check_facts_ruleset");
1098 let rule = ast::ResolvedRule {
1099 span: span.clone(),
1100 head: ResolvedActions::default(),
1101 body: facts.to_vec(),
1102 name: fresh_name.clone(),
1103 ruleset: fresh_ruleset.clone(),
1104 };
1105 let core_rule = rule.to_canonicalized_core_rule(
1106 &self.type_info,
1107 &mut self.parser.symbol_gen,
1108 self.proof_state.original_typechecking.is_none(),
1109 )?;
1110 let query = core_rule.body;
1111
1112 let ext_sc = egglog_bridge::SideChannel::default();
1113 let ext_sc_ref = ext_sc.clone();
1114 let ext_id = self
1115 .backend
1116 .register_external_func(Box::new(make_external_func(move |_, _| {
1117 *ext_sc_ref.lock().unwrap() = Some(());
1118 Some(Value::new_const(0))
1119 })));
1120
1121 let mut translator = BackendRule::new(
1122 self.backend.new_rule("check_facts", false),
1123 &self.functions,
1124 &self.type_info,
1125 );
1126 translator.query(&query, true);
1127 translator
1128 .rb
1129 .call_external_func(ext_id, &[], egglog_bridge::ColumnTy::Id, || {
1130 "this function will never panic".to_string()
1131 });
1132 let id = translator.build();
1133 let _ = self.backend.run_rules(&[id]).unwrap();
1134 self.backend.free_rule(id);
1135
1136 self.backend.free_external_func(ext_id);
1137
1138 let ext_sc_val = ext_sc.lock().unwrap().take();
1139 let matched = matches!(ext_sc_val, Some(()));
1140
1141 if !matched {
1142 Err(Error::CheckError(
1143 facts.iter().map(|f| f.clone().make_unresolved()).collect(),
1144 span.clone(),
1145 ))
1146 } else {
1147 Ok(())
1148 }
1149 }
1150
1151 fn run_command(&mut self, command: ResolvedNCommand) -> Result<Option<CommandOutput>, Error> {
1152 match command {
1153 ResolvedNCommand::Sort(_span, name, _presort_and_args) => {
1155 log::info!("Declared sort {}.", name)
1156 }
1157 ResolvedNCommand::Function(fdecl) => {
1158 self.declare_function(&fdecl)?;
1159 log::info!("Declared {} {}.", fdecl.subtype, fdecl.name)
1160 }
1161 ResolvedNCommand::AddRuleset(_span, name) => {
1162 self.add_ruleset(name.clone());
1163 log::info!("Declared ruleset {name}.");
1164 }
1165 ResolvedNCommand::UnstableCombinedRuleset(_span, name, others) => {
1166 self.add_combined_ruleset(name.clone(), others);
1167 log::info!("Declared ruleset {name}.");
1168 }
1169 ResolvedNCommand::NormRule { rule } => {
1170 let name = rule.name.clone();
1171 self.add_rule(rule)?;
1172 log::info!("Declared rule {name}.")
1173 }
1174 ResolvedNCommand::RunSchedule(sched) => {
1175 let report = self.run_schedule(&sched)?;
1176 log::info!("Ran schedule {}.", sched);
1177 log::info!("Report: {}", report);
1178 self.overall_run_report.union(report.clone());
1179 return Ok(Some(CommandOutput::RunSchedule(report)));
1180 }
1181 ResolvedNCommand::PrintOverallStatistics(span, file) => match file {
1182 None => {
1183 log::info!("Printed overall statistics");
1184 return Ok(Some(CommandOutput::OverallStatistics(
1185 self.overall_run_report.clone(),
1186 )));
1187 }
1188 Some(path) => {
1189 let mut file = std::fs::File::create(&path)
1190 .map_err(|e| Error::IoError(path.clone().into(), e, span.clone()))?;
1191 log::info!("Printed overall statistics to json file {}", path);
1192
1193 serde_json::to_writer(&mut file, &self.overall_run_report)
1194 .expect("error serializing to json");
1195 }
1196 },
1197 ResolvedNCommand::Check(span, facts) => {
1198 self.check_facts(&span, &facts)?;
1199 log::info!("Checked fact {:?}.", facts);
1200 }
1201 ResolvedNCommand::CoreAction(action) => match &action {
1202 ResolvedAction::Let(_, name, contents) => {
1203 panic!("Globals should have been desugared away: {name} = {contents}")
1204 }
1205 _ => {
1206 self.eval_actions(&ResolvedActions::new(vec![action.clone()]))?;
1207 }
1208 },
1209 ResolvedNCommand::Extract(span, expr, variants) => {
1210 let sort = expr.output_type();
1211
1212 let x = self.eval_resolved_expr(span.clone(), &expr)?;
1213 let n = self.eval_resolved_expr(span, &variants)?;
1214 let n: i64 = self.backend.base_values().unwrap(n);
1215
1216 let mut termdag = TermDag::default();
1217
1218 let extractor = Extractor::compute_costs_from_rootsorts(
1219 Some(vec![sort]),
1220 self,
1221 TreeAdditiveCostModel::default(),
1222 );
1223 return if n == 0 {
1224 if let Some((cost, term)) = extractor.extract_best(self, &mut termdag, x) {
1225 if log_enabled!(Level::Info) {
1227 log::info!("extracted with cost {cost}: {}", termdag.to_string(term));
1228 }
1229 Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
1230 } else {
1231 Err(Error::ExtractError(
1232 "Unable to find any valid extraction (likely due to subsume or delete)"
1233 .to_string(),
1234 ))
1235 }
1236 } else {
1237 if n < 0 {
1238 panic!("Cannot extract negative number of variants");
1239 }
1240 let terms: Vec<TermId> = extractor
1241 .extract_variants(self, &mut termdag, x, n as usize)
1242 .iter()
1243 .map(|e| e.1)
1244 .collect();
1245 if log_enabled!(Level::Info) {
1246 let expr_str = expr.to_string();
1247 log::info!("extracted {} variants for {expr_str}", terms.len());
1248 }
1249 Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
1250 };
1251 }
1252 ResolvedNCommand::Push(n) => {
1253 (0..n).for_each(|_| self.push());
1254 log::info!("Pushed {n} levels.")
1255 }
1256 ResolvedNCommand::Pop(span, n) => {
1257 for _ in 0..n {
1258 self.pop().map_err(|err| {
1259 if let Error::Pop(_) = err {
1260 Error::Pop(span.clone())
1261 } else {
1262 err
1263 }
1264 })?;
1265 }
1266 log::info!("Popped {n} levels.")
1267 }
1268 ResolvedNCommand::PrintFunction(span, f, n, file, mode) => {
1269 let file = file
1270 .map(|file| {
1271 std::fs::File::create(&file)
1272 .map_err(|e| Error::IoError(file.into(), e, span.clone()))
1273 })
1274 .transpose()?;
1275 return self.print_function(&f, n, file, mode).map_err(|e| match e {
1276 Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1277 Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1278 }
1279 _ => e,
1281 });
1282 }
1283 ResolvedNCommand::PrintSize(span, f) => {
1284 let res = self.print_size(f.as_deref()).map_err(|e| match e {
1285 Error::TypeError(TypeError::UnboundFunction(f, _)) => {
1286 Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
1287 }
1288 _ => e,
1290 })?;
1291 return Ok(Some(res));
1292 }
1293 ResolvedNCommand::Fail(span, c) => {
1294 let result = self.run_command(*c);
1295 if let Err(e) = result {
1296 log::info!("Command failed as expected: {e}");
1297 } else {
1298 return Err(Error::ExpectFail(span));
1299 }
1300 }
1301 ResolvedNCommand::Input {
1302 span: _,
1303 name,
1304 file,
1305 } => {
1306 self.input_file(&name, file)?;
1307 }
1308 ResolvedNCommand::Output { span, file, exprs } => {
1309 let mut filename = self.fact_directory.clone().unwrap_or_default();
1310 filename.push(file.as_str());
1311 let mut f = File::options()
1313 .append(true)
1314 .create(true)
1315 .open(&filename)
1316 .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1317
1318 let extractor = Extractor::compute_costs_from_rootsorts(
1319 None,
1320 self,
1321 TreeAdditiveCostModel::default(),
1322 );
1323 let mut termdag: TermDag = Default::default();
1324
1325 use std::io::Write;
1326 for expr in exprs {
1327 let value = self.eval_resolved_expr(span.clone(), &expr)?;
1328 let expr_type = expr.output_type();
1329
1330 let term = extractor
1331 .extract_best_with_sort(self, &mut termdag, value, expr_type)
1332 .unwrap()
1333 .1;
1334 writeln!(f, "{}", termdag.to_string(term))
1335 .map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
1336 }
1337
1338 log::info!("Output to '{filename:?}'.")
1339 }
1340 ResolvedNCommand::UserDefined(_span, name, exprs) => {
1341 let command = self.commands.swap_remove(&name).unwrap_or_else(|| {
1342 panic!("Unrecognized user-defined command: {}", name);
1343 });
1344 let res = command.update(self, &exprs);
1345 self.commands.insert(name, command);
1346 return res;
1347 }
1348 ResolvedNCommand::ProveExists(span, resolved_call) => {
1349 let mut instrument = ProofInstrumentor { egraph: self };
1350 let (proof_store, proof_id) =
1351 instrument
1352 .prove_exists(&resolved_call)
1353 .map_err(|error| Error::ProofError {
1354 span: span.clone(),
1355 error,
1356 })?;
1357 return Ok(Some(CommandOutput::ProveExists {
1358 proof_store,
1359 proof_id,
1360 }));
1361 }
1362 };
1363
1364 Ok(None)
1365 }
1366
1367 fn input_file(&mut self, func_name: &str, file: String) -> Result<(), Error> {
1368 let function_type = self
1369 .type_info
1370 .get_func_type(func_name)
1371 .unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
1372 let func = self.functions.get_mut(func_name).unwrap();
1373
1374 let mut filename = self.fact_directory.clone().unwrap_or_default();
1375 filename.push(file.as_str());
1376
1377 for t in &func.schema.input {
1380 match t.name() {
1381 "i64" | "f64" | "String" => {}
1382 s => panic!("Unsupported type {} for input", s),
1383 }
1384 }
1385
1386 if function_type.subtype != FunctionSubtype::Constructor {
1387 match func.schema.output.name() {
1388 "i64" | "String" | "Unit" => {}
1389 s => panic!("Unsupported type {} for input", s),
1390 }
1391 }
1392
1393 log::info!("Opening file '{:?}'...", filename);
1394 let mut f = File::open(filename).unwrap();
1395 let mut contents = String::new();
1396 f.read_to_string(&mut contents).unwrap();
1397
1398 let mut parsed_contents: Vec<Vec<Value>> = Vec::with_capacity(contents.lines().count());
1400
1401 let mut row_schema = func.schema.input.clone();
1402 if function_type.subtype == FunctionSubtype::Custom {
1403 row_schema.push(func.schema.output.clone());
1404 }
1405
1406 log::debug!("{:?}", row_schema);
1407
1408 let unit_val = self.backend.base_values().get(());
1409
1410 for line in contents.lines() {
1411 let mut it = line.split('\t').map(|s| s.trim());
1412
1413 let mut row: Vec<Value> = Vec::with_capacity(row_schema.len());
1414
1415 for sort in row_schema.iter() {
1416 if let Some(raw) = it.next() {
1417 let val = match sort.name() {
1418 "i64" => {
1419 if let Ok(i) = raw.parse::<i64>() {
1420 self.backend.base_values().get(i)
1421 } else {
1422 return Err(Error::InputFileFormatError(file));
1423 }
1424 }
1425 "f64" => {
1426 if let Ok(f) = raw.parse::<f64>() {
1427 self.backend
1428 .base_values()
1429 .get::<F>(core_relations::Boxed::new(f.into()))
1430 } else {
1431 return Err(Error::InputFileFormatError(file));
1432 }
1433 }
1434 "String" => self.backend.base_values().get::<S>(raw.to_string().into()),
1435 "Unit" => unit_val,
1436 _ => panic!("Unreachable"),
1437 };
1438 row.push(val);
1439 } else {
1440 break;
1441 }
1442 }
1443
1444 if row.is_empty() {
1445 continue;
1446 }
1447
1448 if row.len() != row_schema.len() || it.next().is_some() {
1449 return Err(Error::InputFileFormatError(file));
1450 }
1451
1452 parsed_contents.push(row);
1453 }
1454
1455 log::debug!("Successfully loaded file.");
1456
1457 let num_facts = parsed_contents.len();
1458
1459 let mut table_action = egglog_bridge::TableAction::new(&self.backend, func.backend_id);
1460
1461 if function_type.subtype != FunctionSubtype::Constructor {
1462 self.backend.with_execution_state(|es| {
1463 for row in parsed_contents.iter() {
1464 table_action.insert(es, row.iter().copied());
1465 }
1466 Some(unit_val)
1467 });
1468 } else {
1469 self.backend.with_execution_state(|es| {
1470 for row in parsed_contents.iter() {
1471 table_action.lookup(es, row);
1472 }
1473 Some(unit_val)
1474 });
1475 }
1476
1477 self.backend.flush_updates();
1478
1479 log::info!("Read {num_facts} facts into {func_name} from '{file}'.");
1480 Ok(())
1481 }
1482
1483 fn resolve_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
1486 let desugared = desugar_command(command, &mut self.parser, self.proof_state.proof_testing)?;
1487
1488 if let Some(original_typechecking) = self.proof_state.original_typechecking.as_mut() {
1490 let typechecked = original_typechecking.typecheck_program(&desugared)?;
1493
1494 for command in &typechecked {
1495 if !command_supports_proof_encoding(&command.to_command(), &self.type_info) {
1496 let command_text = format!("{}", command.to_command());
1497 return Err(Error::UnsupportedProofCommand {
1498 command: command_text,
1499 });
1500 }
1501 }
1502
1503 let normalized = proof_form(typechecked, &mut self.parser.symbol_gen);
1504
1505 self.desugared_commands.extend_from_slice(&normalized);
1507
1508 let typechecked_no_globals =
1510 proof_global_remover::remove_globals(normalized, &mut self.parser.symbol_gen);
1511 for command in &typechecked_no_globals {
1512 self.names.check_shadowing(command)?;
1513 }
1514
1515 let term_encoding_added =
1516 ProofInstrumentor::add_term_encoding(self, typechecked_no_globals);
1517 let mut new_typechecked = vec![];
1518 for new_cmd in term_encoding_added {
1519 let desugared =
1520 desugar_command(new_cmd, &mut self.parser, self.proof_state.proof_testing)?;
1521 for cmd in &desugared {
1522 log::debug!("Desugared term encoding: {}", cmd.to_command());
1523 }
1524
1525 let desugared_typechecked = self.typecheck_program(&desugared)?;
1527 let desugared_typechecked = remove_globals::remove_globals(
1529 desugared_typechecked,
1530 &mut self.parser.symbol_gen,
1531 );
1532
1533 new_typechecked.extend(desugared_typechecked);
1534 }
1535 Ok(new_typechecked)
1536 } else {
1537 let mut typechecked = self.typecheck_program(&desugared)?;
1538
1539 typechecked = remove_globals::remove_globals(typechecked, &mut self.parser.symbol_gen);
1540 for command in &typechecked {
1541 self.names.check_shadowing(command)?;
1542 }
1543 Ok(typechecked)
1544 }
1545 }
1546
1547 fn process_program_internal(
1550 &mut self,
1551 program: Vec<Command>,
1552 run_commands: bool,
1553 ) -> Result<(Vec<CommandOutput>, Vec<ResolvedCommand>), Error> {
1554 let mut outputs = Vec::new();
1555 let mut desugared_commands = Vec::new();
1556
1557 for before_expanded_command in program {
1558 let macro_expanded = self.command_macros.apply(
1561 before_expanded_command,
1562 &mut self.parser.symbol_gen,
1563 &self.type_info,
1564 )?;
1565
1566 for command in macro_expanded {
1567 if let Command::Include(span, file) = &command {
1569 let s = std::fs::read_to_string(file)
1570 .unwrap_or_else(|_| panic!("{span} Failed to read file {file}"));
1571 let included_program = self
1572 .parser
1573 .get_program_from_string(Some(file.clone()), &s)?;
1574 let (included_outputs, included_desugared) =
1576 self.process_program_internal(included_program, run_commands)?;
1577 outputs.extend(included_outputs);
1578 desugared_commands.extend(included_desugared);
1579 } else {
1580 for processed in self.resolve_command(command)? {
1581 desugared_commands.push(processed.to_command());
1582
1583 if run_commands
1585 || matches!(
1586 processed,
1587 ResolvedNCommand::Push(_) | ResolvedNCommand::Pop(_, _)
1588 )
1589 {
1590 let result = self.run_command(processed)?;
1591 if let Some(output) = result {
1592 outputs.push(output);
1593 }
1594 }
1595 }
1596 }
1597 }
1598 }
1599
1600 Ok((outputs, desugared_commands))
1601 }
1602
1603 pub fn run_program(&mut self, program: Vec<Command>) -> Result<Vec<CommandOutput>, Error> {
1606 let (outputs, _desugared_commands) = self.process_program_internal(program, true)?;
1607 Ok(outputs)
1608 }
1609
1610 pub fn desugar_program(
1613 &mut self,
1614 filename: Option<String>,
1615 input: &str,
1616 ) -> Result<Vec<ResolvedCommand>, Error> {
1617 let parsed = self.parser.get_program_from_string(filename, input)?;
1618 let (_outputs, desugared_commands) = self.process_program_internal(parsed, false)?;
1619 Ok(desugared_commands)
1620 }
1621
1622 pub fn parse_program(
1624 &mut self,
1625 filename: Option<String>,
1626 input: &str,
1627 ) -> Result<Vec<Command>, Error> {
1628 let parsed = self.parser.get_program_from_string(filename, input)?;
1629 Ok(parsed)
1630 }
1631
1632 pub fn parse_and_run_program(
1638 &mut self,
1639 filename: Option<String>,
1640 input: &str,
1641 ) -> Result<Vec<CommandOutput>, Error> {
1642 let parsed = self.parser.get_program_from_string(filename, input)?;
1643 self.run_program(parsed)
1644 }
1645
1646 pub fn num_tuples(&self) -> usize {
1649 self.functions
1650 .values()
1651 .map(|f| self.backend.table_size(f.backend_id))
1652 .sum()
1653 }
1654
1655 pub fn get_sort<S: Sort>(&self) -> Arc<S> {
1657 self.type_info.get_sort()
1658 }
1659
1660 pub fn get_sort_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
1662 self.type_info.get_sort_by(f)
1663 }
1664
1665 pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
1667 self.type_info.get_sorts()
1668 }
1669
1670 pub fn get_sorts_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
1672 self.type_info.get_sorts_by(f)
1673 }
1674
1675 pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
1677 self.type_info.get_arcsort_by(f)
1678 }
1679
1680 pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
1682 self.type_info.get_arcsorts_by(f)
1683 }
1684
1685 pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
1687 self.type_info.get_sort_by_name(sym)
1688 }
1689
1690 pub fn get_overall_run_report(&self) -> &RunReport {
1692 &self.overall_run_report
1693 }
1694
1695 pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
1697 self.backend.base_values().unwrap::<T>(x)
1698 }
1699
1700 pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
1702 self.backend.base_values().get::<T>(x)
1703 }
1704
1705 pub fn value_to_container<T: ContainerValue>(
1712 &self,
1713 x: Value,
1714 ) -> Option<impl Deref<Target = T>> {
1715 self.backend.container_values().get_val::<T>(x)
1716 }
1717
1718 pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
1720 self.backend.with_execution_state(|state| {
1721 self.backend.container_values().register_val::<T>(x, state)
1722 })
1723 }
1724
1725 pub fn get_size(&self, func: &str) -> usize {
1729 let function_id = self.functions.get(func).unwrap().backend_id;
1730 self.backend.table_size(function_id)
1731 }
1732
1733 pub fn lookup_function(&self, name: &str, key: &[Value]) -> Option<Value> {
1738 let func = self.functions.get(name).unwrap().backend_id;
1739 self.backend.lookup_id(func, key)
1740 }
1741
1742 pub fn get_function(&self, name: &str) -> Option<&Function> {
1746 self.functions.get(name)
1747 }
1748
1749 pub fn set_report_level(&mut self, level: ReportLevel) {
1750 self.backend.set_report_level(level);
1751 }
1752
1753 pub fn dump_debug_info(&self) {
1757 self.backend.dump_debug_info();
1758 }
1759
1760 pub fn get_canonical_value(&self, val: Value, sort: &ArcSort) -> Value {
1762 self.backend
1763 .get_canon_repr(val, sort.column_ty(&self.backend))
1764 }
1765}
1766
1767struct BackendRule<'a> {
1768 rb: egglog_bridge::RuleBuilder<'a>,
1769 entries: HashMap<core::ResolvedAtomTerm, QueryEntry>,
1770 functions: &'a IndexMap<String, Function>,
1771 type_info: &'a TypeInfo,
1772}
1773
impl<'a> BackendRule<'a> {
    /// Wraps the rule builder `rb` with an empty term-to-entry cache.
    fn new(
        rb: egglog_bridge::RuleBuilder<'a>,
        functions: &'a IndexMap<String, Function>,
        type_info: &'a TypeInfo,
    ) -> BackendRule<'a> {
        BackendRule {
            rb,
            functions,
            type_info,
            entries: Default::default(),
        }
    }

    /// Returns the query entry for atom term `x`, creating it on first use.
    ///
    /// The entry depends on the kind of term:
    /// - A variable becomes a named backend variable of its sort's column type.
    /// - A literal becomes a constant.
    ///
    /// The entry is cached in `self.entries`, so later uses of the same term in
    /// this rule reuse it.
    ///
    /// # Panics
    /// Panics on `Global` terms, which must already have been desugared away.
    fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry {
        self.entries
            .entry(x.clone())
            .or_insert_with(|| match x {
                core::GenericAtomTerm::Var(_, v) => self
                    .rb
                    .new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name),
                core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l),
                core::GenericAtomTerm::Global(..) => {
                    panic!("Globals should have been desugared")
                }
            })
            .clone()
    }

    /// Maps a typechecked function to its backend table id.
    ///
    /// Indexing panics if `f.name` is not a declared function.
    fn func(&self, f: &typechecking::FuncType) -> egglog_bridge::FunctionId {
        self.functions[&f.name].backend_id
    }

    /// Translates a primitive call into its backend form: the external function
    /// id, the argument entries, and the output column type.
    ///
    /// `unstable-fn` is special-cased. Its first argument must be a string literal
    /// naming a function or a primitive, and that argument's entry is replaced by
    /// a `ResolvedFunction` constant. The constant holds:
    /// - the callee: a table lookup for a function, or the single primitive
    ///   overload that accepts the partially-applied argument sorts followed by the
    ///   inputs/output of some declared `FunctionSort`;
    /// - the sorts of the partially-applied arguments (all but the first input).
    ///
    /// # Panics
    /// Panics in these cases:
    /// - The `unstable-fn` name is not a string literal.
    /// - The name refers to nothing callable.
    /// - Zero or several primitive overloads match.
    fn prim(
        &mut self,
        prim: &core::SpecializedPrimitive,
        args: &[core::ResolvedAtomTerm],
    ) -> (ExternalFunctionId, Vec<QueryEntry>, ColumnTy) {
        let mut qe_args = self.args(args);

        if prim.name() == "unstable-fn" {
            let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else {
                panic!("expected string literal after `unstable-fn`")
            };
            let id = if let Some(f) = self.type_info.get_func_type(name) {
                ResolvedFunctionId::Lookup(egglog_bridge::TableAction::new(
                    self.rb.egraph(),
                    self.func(f),
                ))
            } else if let Some(possible) = self.type_info.get_prims(name) {
                // Keep only overloads that type-check for at least one function sort.
                let mut ps: Vec<_> = possible.iter().collect();
                ps.retain(|p| {
                    self.type_info
                        .get_sorts::<FunctionSort>()
                        .into_iter()
                        .any(|f| {
                            let types: Vec<_> = prim
                                .input()
                                .iter()
                                .skip(1)
                                .chain(f.inputs())
                                .chain([&f.output()])
                                .cloned()
                                .collect();
                            p.accept(&types, self.type_info)
                        })
                });
                assert!(ps.len() == 1, "options for {name}: {ps:?}");
                ResolvedFunctionId::Prim(ps.into_iter().next().unwrap().id)
            } else {
                panic!("no callable for {name}");
            };
            let partial_arcsorts = prim.input().iter().skip(1).cloned().collect();

            // Replace the name literal with the resolved function constant.
            qe_args[0] = self.rb.egraph().base_value_constant(ResolvedFunction {
                id,
                partial_arcsorts,
                name: name.clone(),
            });
        }

        (
            prim.external_id(),
            qe_args,
            prim.output().column_ty(self.rb.egraph()),
        )
    }

    /// Returns the query entry for each term in `args`, in order.
    fn args<'b>(
        &mut self,
        args: impl IntoIterator<Item = &'b core::ResolvedAtomTerm>,
    ) -> Vec<QueryEntry> {
        args.into_iter().map(|x| self.entry(x)).collect()
    }

    /// Adds every atom of `query` to the rule body.
    ///
    /// Function atoms become table queries. The subsumption filter passed to
    /// `query_table` is `None` when `include_subsumed` is true and `Some(false)`
    /// otherwise, which presumably restricts matches to rows that are not
    /// subsumed (NOTE(review): confirm against `RuleBuilder::query_table`).
    /// Primitive atoms become primitive queries.
    ///
    /// # Panics
    /// Panics if the builder rejects an atom.
    fn query(&mut self, query: &core::Query<ResolvedCall, ResolvedVar>, include_subsumed: bool) {
        for atom in &query.atoms {
            match &atom.head {
                ResolvedCall::Func(f) => {
                    let f = self.func(f);
                    let args = self.args(&atom.args);
                    let is_subsumed = match include_subsumed {
                        true => None,
                        false => Some(false),
                    };
                    self.rb.query_table(f, &args, is_subsumed).unwrap();
                }
                ResolvedCall::Primitive(p) => {
                    let (p, args, ty) = self.prim(p, &atom.args);
                    self.rb.query_prim(p, &args, ty).unwrap()
                }
            }
        }
    }

    /// Appends the rule's actions to the rule head.
    ///
    /// Each action kind is translated as follows:
    /// - `Let`: binds the result of a function lookup or primitive call. The
    ///   failure message names the span and the callee.
    /// - `LetAtomTerm`: aliases an existing entry.
    /// - `Set`: writes a row.
    /// - `Change`: removes a row, or marks it subsumed.
    /// - `Union` and `Panic`: map directly onto the builder.
    ///
    /// # Errors
    /// [`Error::SubsumeMergeError`] when `subsume` targets a function whose
    /// `can_subsume` flag is false.
    ///
    /// # Panics
    /// Panics on `set!` or `change` applied to a primitive.
    fn actions(&mut self, actions: &core::ResolvedCoreActions) -> Result<(), Error> {
        for action in &actions.0 {
            match action {
                core::GenericCoreAction::Let(span, v, f, args) => {
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let y = match f {
                        ResolvedCall::Func(f) => {
                            let name = f.name.clone();
                            let f = self.func(f);
                            let args = self.args(args);
                            let span = span.clone();
                            self.rb.lookup(f, &args, move || {
                                format!("{span}: lookup of function {name} failed")
                            })
                        }
                        ResolvedCall::Primitive(p) => {
                            let name = p.name().to_owned();
                            let (p, args, ty) = self.prim(p, args);
                            let span = span.clone();
                            self.rb.call_external_func(p, &args, ty, move || {
                                format!("{span}: call of primitive {name} failed")
                            })
                        }
                    };
                    // Later references to `v` resolve to the computed entry.
                    self.entries.insert(v, y.into());
                }
                core::GenericCoreAction::LetAtomTerm(span, v, x) => {
                    let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
                    let x = self.entry(x);
                    self.entries.insert(v, x);
                }
                core::GenericCoreAction::Set(_, f, xs, y) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive set!"),
                    ResolvedCall::Func(f) => {
                        let f = self.func(f);
                        // Row layout: inputs followed by the output value.
                        let args = self.args(xs.iter().chain([y]));
                        self.rb.set(f, &args)
                    }
                },
                core::GenericCoreAction::Change(span, change, f, args) => match f {
                    ResolvedCall::Primitive(..) => panic!("runtime primitive change!"),
                    ResolvedCall::Func(f) => {
                        let name = f.name.clone();
                        let can_subsume = self.functions[&f.name].can_subsume;
                        let f = self.func(f);
                        let args = self.args(args);
                        match change {
                            Change::Delete => self.rb.remove(f, &args),
                            Change::Subsume if can_subsume => self.rb.subsume(f, &args),
                            Change::Subsume => {
                                return Err(Error::SubsumeMergeError(name, span.clone()));
                            }
                        }
                    }
                },
                core::GenericCoreAction::Union(_, x, y) => {
                    let x = self.entry(x);
                    let y = self.entry(y);
                    self.rb.union(x, y)
                }
                core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()),
            }
        }
        Ok(())
    }

    /// Finishes the rule and returns its backend id.
    fn build(self) -> egglog_bridge::RuleId {
        self.rb.build()
    }
}
1959
1960fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry {
1961 match l {
1962 Literal::Int(x) => egraph.base_value_constant::<i64>(*x),
1963 Literal::Float(x) => egraph.base_value_constant::<sort::F>(x.into()),
1964 Literal::String(x) => egraph.base_value_constant::<sort::S>(sort::S::new(x.clone())),
1965 Literal::Bool(x) => egraph.base_value_constant::<bool>(*x),
1966 Literal::Unit => egraph.base_value_constant::<()>(()),
1967 }
1968}
1969
1970fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value {
1971 match l {
1972 Literal::Int(x) => egraph.base_values().get::<i64>(*x),
1973 Literal::Float(x) => egraph.base_values().get::<sort::F>(x.into()),
1974 Literal::String(x) => egraph.base_values().get::<sort::S>(sort::S::new(x.clone())),
1975 Literal::Bool(x) => egraph.base_values().get::<bool>(*x),
1976 Literal::Unit => egraph.base_values().get::<()>(()),
1977 }
1978}
1979
1980#[derive(Debug, Error)]
1981pub enum Error {
1982 #[error(transparent)]
1983 ParseError(#[from] ParseError),
1984 #[error(transparent)]
1985 NotFoundError(#[from] NotFoundError),
1986 #[error(transparent)]
1987 TypeError(#[from] TypeError),
1988 #[error("Errors:\n{}", ListDisplay(.0, "\n"))]
1989 TypeErrors(Vec<TypeError>),
1990 #[error("{}\nCheck failed: \n{}", .1, ListDisplay(.0, "\n"))]
1991 CheckError(Vec<Fact>, Span),
1992 #[error("{1}\nNo such ruleset: {0}")]
1993 NoSuchRuleset(String, Span),
1994 #[error(
1995 "{1}\nAttempted to add a rule to combined ruleset {0}. Combined rulesets may only depend on other rulesets."
1996 )]
1997 CombinedRulesetError(String, Span),
1998 #[error("{0}")]
1999 BackendError(String),
2000 #[error("{0}\nTried to pop too much")]
2001 Pop(Span),
2002 #[error("{0}\nCommand should have failed.")]
2003 ExpectFail(Span),
2004 #[error("{2}\nIO error: {0}: {1}")]
2005 IoError(PathBuf, std::io::Error, Span),
2006 #[error("{1}\nCannot subsume function with merge: {0}")]
2007 SubsumeMergeError(String, Span),
2008 #[error("extraction failure: {:?}", .0)]
2009 ExtractError(String),
2010 #[error("{span}\n{error}")]
2011 ProofError {
2012 span: Span,
2013 #[source]
2014 error: ProveExistsError,
2015 },
2016 #[error("{1}\n{2}\nShadowing is not allowed, but found {0}")]
2017 Shadowing(String, Span, Span),
2018 #[error("{1}\nCommand already exists: {0}")]
2019 CommandAlreadyExists(String, Span),
2020 #[error("Incorrect format in file '{0}'.")]
2021 InputFileFormatError(String),
2022 #[error(
2023 "Command is not supported by the current proof term encoding implementation.\n\
2024 This typically means the command uses constructs that cannot yet be represented as proof terms.\n\
2025 Consider disabling proof term encoding for this run or rewriting the command to avoid unsupported features.\n\
2026 Offending command: {command}"
2027 )]
2028 UnsupportedProofCommand { command: String },
2029}
2030
#[cfg(test)]
mod tests {
    use crate::constraint::SimpleTypeConstraint;
    use crate::sort::*;
    use crate::*;

    /// User-defined primitive: the dot product of two integer vectors.
    #[derive(Clone)]
    struct InnerProduct {
        /// The vector sort both arguments must have.
        vec: ArcSort,
    }

    impl Primitive for InnerProduct {
        fn name(&self) -> &str {
            "inner-product"
        }

        fn get_type_constraints(&self, span: &Span) -> Box<dyn crate::constraint::TypeConstraint> {
            // Signature: (vec, vec) -> i64.
            let signature = vec![self.vec.clone(), self.vec.clone(), I64Sort.to_arcsort()];
            SimpleTypeConstraint::new(self.name(), signature, span.clone()).into_box()
        }

        fn apply(&self, exec_state: &mut ExecutionState<'_>, args: &[Value]) -> Option<Value> {
            let lhs = exec_state
                .container_values()
                .get_val::<VecContainer>(args[0])
                .unwrap();
            let rhs = exec_state
                .container_values()
                .get_val::<VecContainer>(args[1])
                .unwrap();
            assert_eq!(lhs.data.len(), rhs.data.len());
            let dot: i64 = lhs
                .data
                .iter()
                .zip(rhs.data.iter())
                .map(|(a, b)| {
                    exec_state.base_values().unwrap::<i64>(*a)
                        * exec_state.base_values().unwrap::<i64>(*b)
                })
                .sum();
            Some(exec_state.base_values().get::<i64>(dot))
        }
    }

    /// Parses and runs `program` on `egraph`, panicking if anything fails.
    fn run(egraph: &mut EGraph, program: &str) -> Vec<CommandOutput> {
        egraph.parse_and_run_program(None, program).unwrap()
    }

    /// Reads column 0 of the last row visited in the backend table of `name`.
    ///
    /// The tests call this on nullary constructors, whose rows presumably consist
    /// of just the output id.
    fn value_of(egraph: &EGraph, name: &str) -> Value {
        let table = egraph.functions.get(name).unwrap().backend_id;
        let mut last = None;
        egraph.backend.for_each(table, |row| last = Some(row.vals[0]));
        last.unwrap()
    }

    #[test]
    fn test_user_defined_primitive() {
        let mut egraph = EGraph::default();
        run(&mut egraph, "(sort IntVec (Vec i64))");

        let int_vec_sort = egraph.get_arcsort_by(|sort| {
            sort.value_type() == Some(std::any::TypeId::of::<VecContainer>())
                && sort.inner_sorts()[0].name() == I64Sort.name()
        });
        egraph.add_primitive(InnerProduct { vec: int_vec_sort });

        run(
            &mut egraph,
            "
            (let a (vec-of 1 2 3 4 5 6))
            (let b (vec-of 6 5 4 3 2 1))
            (check (= (inner-product a b) 56))
            ",
        );
    }

    #[test]
    fn test_egraph_send_sync() {
        // Compile-time check: this only type-checks if EGraph is Send + Sync.
        fn assert_send_sync<T: Send + Sync>(_: &T) {}
        assert_send_sync(&EGraph::default());
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_arg() {
        let mut egraph = EGraph::default();
        run(
            &mut egraph,
            r#"
            (datatype Math)
            (constructor container (Math) Math)
            (constructor exp () Math :cost 100)
            (constructor cheap () Math)
            (constructor cheap-1 () Math)
            ; we make the container cheap so that it will be extracted if possible, but then we mark it as subsumed
            ; so the (exp) expr should be extracted instead
            (let res (container (cheap)))
            (union res (exp))
            (cheap)
            (cheap-1)
            (subsume (container (cheap)))
            "#,
        );
        let cheap_before = value_of(&egraph, "cheap");
        let cheap_1_before = value_of(&egraph, "cheap-1");
        assert_ne!(cheap_before, cheap_1_before);

        run(&mut egraph, "(union (cheap-1) (cheap))");
        let cheap_after = value_of(&egraph, "cheap");
        let cheap_1_after = value_of(&egraph, "cheap-1");
        assert_eq!(cheap_after, cheap_1_after);
        // The union must have changed at least one of the two ids.
        assert!(cheap_after != cheap_before || cheap_1_after != cheap_1_before);

        let outputs = run(&mut egraph, "(extract res)");
        assert_eq!(outputs[0].to_string(), "(exp)\n");
    }

    #[test]
    fn test_subsumed_unextractable_rebuild_self() {
        let mut egraph = EGraph::default();
        run(
            &mut egraph,
            r#"
            (datatype Math)
            (constructor container (Math) Math)
            (constructor exp () Math :cost 100)
            (constructor cheap () Math)
            (exp)
            (let x (cheap))
            (subsume (cheap))
            "#,
        );

        let cheap_before = value_of(&egraph, "cheap");
        run(&mut egraph, "(union (exp) x)");
        let cheap_after = value_of(&egraph, "cheap");
        assert_ne!(cheap_after, cheap_before);

        let outputs = run(&mut egraph, "(extract x)");
        assert_eq!(outputs[0].to_string(), "(exp)\n");
    }
}