1use std::hash::Hasher;
2
3use crate::Context;
4use crate::{
5 core::{CoreActionContext, CoreRule, GenericActionsExt, ResolvedCall},
6 *,
7};
8use ast::{
9 MappedExprExt, ResolvedAction, ResolvedExpr, ResolvedFact, ResolvedRule, ResolvedVar, Rule,
10 RuleEvalMode,
11};
12use core_relations::ExternalFunction;
13use egglog_ast::generic_ast::GenericAction;
14use egglog_bridge::ActionRegistry;
15use enum_map::EnumMap;
16use std::sync::{Arc, RwLock};
17
/// Adapts a [`PurePrim`] into a backend [`ExternalFunction`] bound to a single
/// [`Context`]; one wrapper is registered per context the primitive is valid in.
#[derive(Clone)]
struct PurePrimWrapper<T> {
    /// The wrapped pure primitive.
    prim: T,
    /// Context this wrapper was registered for; handed to `PureState::wrap` on
    /// every invocation.
    ctx: Context,
}
31
32impl<T: PurePrim + Clone> ExternalFunction for PurePrimWrapper<T> {
33 fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
34 self.prim.apply(PureState::wrap(exec_state, self.ctx), args)
35 }
36}
37
/// Adapts a primitive that needs access to the backend's [`ActionRegistry`] into
/// an [`ExternalFunction`] bound to a single [`Context`].
///
/// The strategy type `S` (one of the `RegistryWrap` implementors) selects which
/// state wrapper the primitive is invoked with.
#[derive(Clone)]
struct RegistryPrimWrapper<T, S> {
    /// The wrapped primitive.
    prim: T,
    /// Shared handle to the backend's action registry; read-locked per call.
    registry: Arc<RwLock<ActionRegistry>>,
    /// Context this wrapper was registered for.
    ctx: Context,
    /// Marks the strategy `S` at the type level. `fn() -> S` means the wrapper
    /// neither owns an `S` nor inherits `S`'s auto traits.
    _wrap: std::marker::PhantomData<fn() -> S>,
}
50
/// Strategy that decides how a registry-backed primitive of type `T` is invoked.
///
/// The implementors below (`WrapRead`, `WrapWrite`, `WrapFull`) each build a
/// different state view (`ReadState`, `WriteState`, `FullState`) before calling
/// the primitive's `apply`.
trait RegistryWrap<T>: Clone + Send + Sync {
    /// Invokes `prim` on `args` using `exec_state`, the wrapper's `ctx`, and a
    /// borrowed (already locked) `registry`.
    fn invoke(
        prim: &T,
        exec_state: &mut ExecutionState,
        ctx: Context,
        args: &[Value],
        registry: &ActionRegistry,
    ) -> Option<Value>;
}
60
61#[derive(Clone)]
62struct WrapRead;
63impl<T: ReadPrim> RegistryWrap<T> for WrapRead {
64 #[inline]
65 fn invoke(
66 prim: &T,
67 exec_state: &mut ExecutionState,
68 ctx: Context,
69 args: &[Value],
70 registry: &ActionRegistry,
71 ) -> Option<Value> {
72 prim.apply(ReadState::wrap(exec_state, registry, ctx), args)
73 }
74}
75#[derive(Clone)]
76struct WrapWrite;
77impl<T: WritePrim> RegistryWrap<T> for WrapWrite {
78 #[inline]
79 fn invoke(
80 prim: &T,
81 exec_state: &mut ExecutionState,
82 ctx: Context,
83 args: &[Value],
84 registry: &ActionRegistry,
85 ) -> Option<Value> {
86 prim.apply(WriteState::wrap(exec_state, registry, ctx), args)
87 }
88}
89#[derive(Clone)]
90struct WrapFull;
91impl<T: FullPrim> RegistryWrap<T> for WrapFull {
92 #[inline]
93 fn invoke(
94 prim: &T,
95 exec_state: &mut ExecutionState,
96 ctx: Context,
97 args: &[Value],
98 registry: &ActionRegistry,
99 ) -> Option<Value> {
100 prim.apply(FullState::wrap(exec_state, registry, ctx), args)
101 }
102}
103
104impl<T: Clone + Send + Sync + 'static, S: RegistryWrap<T> + 'static> ExternalFunction
105 for RegistryPrimWrapper<T, S>
106{
107 fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
108 let registry = self.registry.read().unwrap();
109 S::invoke(&self.prim, exec_state, self.ctx, args, ®istry)
110 }
111}
112
/// Resolved signature of a user-declared function or constructor.
///
/// Equality and hashing (see the `PartialEq`/`Hash` impls) compare sorts by
/// name, not by `Arc` identity.
#[derive(Clone, Debug)]
pub struct FuncType {
    /// Declared function name.
    pub name: String,
    /// Kind of function (e.g. constructor or custom function).
    pub subtype: FunctionSubtype,
    /// Argument sorts, in declaration order.
    pub input: Vec<ArcSort>,
    /// Output sort.
    pub output: ArcSort,
}
120
121impl PartialEq for FuncType {
122 fn eq(&self, other: &Self) -> bool {
123 if self.name == other.name
124 && self.subtype == other.subtype
125 && self.output.name() == other.output.name()
126 {
127 if self.input.len() != other.input.len() {
128 return false;
129 }
130 for (a, b) in self.input.iter().zip(other.input.iter()) {
131 if a.name() != b.name() {
132 return false;
133 }
134 }
135 true
136 } else {
137 false
138 }
139 }
140}
141
// `FuncType`'s `PartialEq` compares plain names and subtypes, which is
// reflexive, symmetric and transitive, so the full `Eq` contract holds.
impl Eq for FuncType {}
143
144impl Hash for FuncType {
145 fn hash<H: Hasher>(&self, state: &mut H) {
146 self.name.hash(state);
147 self.subtype.hash(state);
148 self.output.name().hash(state);
149 for inp in &self.input {
150 inp.name().hash(state);
151 }
152 }
153}
/// Optional per-primitive hook stored in [`PrimitiveWithId::validator`].
///
/// It receives a mutable [`TermDag`] and a slice of term ids and may return a
/// new term id. NOTE(review): presumably the ids are the primitive's argument
/// terms and the result is the term for its output; confirm against callers.
pub type PrimitiveValidator = Arc<dyn Fn(&mut TermDag, &[TermId]) -> Option<TermId> + Send + Sync>;
158
/// A registered primitive together with the backend external-function ids it
/// was registered under, one per context.
#[derive(Clone)]
pub struct PrimitiveWithId {
    /// The primitive implementation (shared across all per-context wrappers).
    pub(crate) primitive: Arc<dyn Primitive>,
    /// Optional validator hook supplied at registration time.
    pub(crate) validator: Option<PrimitiveValidator>,
    /// Backend function id for each context; `None` where the primitive is not
    /// valid in that context.
    pub(crate) context_ids: EnumMap<Context, Option<ExternalFunctionId>>,
}
169
170impl PrimitiveWithId {
171 pub fn accept(&self, tys: &[Arc<dyn Sort>], typeinfo: &TypeInfo) -> bool {
174 let mut constraints = vec![];
175 let lits: Vec<_> = (0..tys.len())
176 .map(|i| AtomTerm::Literal(Span::Panic, Literal::Int(i as i64)))
177 .collect();
178 for (lit, ty) in lits.iter().zip(tys.iter()) {
179 constraints.push(constraint::assign(lit.clone(), ty.clone()))
180 }
181 constraints.extend(
182 self.primitive
183 .get_type_constraints(&Span::Panic)
184 .get(&lits, typeinfo),
185 );
186 let problem = Problem {
187 constraints,
188 range: HashSet::default(),
189 };
190 problem.solve(|sort| sort.name()).is_ok()
191 }
192
193 pub fn is_valid_in_context(&self, context: Context) -> bool {
195 self.context_ids[context].is_some()
196 }
197}
198
199impl Debug for PrimitiveWithId {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 write!(f, "Prim({})", self.primitive.name())
202 }
203}
204
/// Typing environment used by the typechecker: presorts, declared sorts,
/// primitives, function signatures and global bindings.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// Sort constructors keyed by presort name (filled by `add_presort`).
    mksorts: HashMap<String, MkSort>,
    /// Primitive names reserved by registered presorts; `is_primitive` treats
    /// them as taken even before a primitive with that name exists.
    reserved_primitives: HashSet<&'static str>,
    /// All declared sorts keyed by sort name.
    pub(crate) sorts: HashMap<String, Arc<dyn Sort>>,
    /// Registered primitives keyed by name; one name may have several overloads.
    primitives: HashMap<String, Vec<PrimitiveWithId>>,
    /// Signatures of declared functions/constructors keyed by name.
    func_types: HashMap<String, FuncType>,
    /// Sorts of global bindings (global `let`s and internal-let functions).
    pub(crate) global_sorts: HashMap<String, ArcSort>,
    /// Names of sorts declared non-unionable; see `is_sort_unionable`.
    pub(crate) non_unionable_sorts: HashSet<String>,
}
218
219impl EGraph {
222 pub fn add_sort<S: Sort + 'static>(&mut self, sort: S, span: Span) -> Result<(), TypeError> {
226 self.add_arcsort(Arc::new(sort), span)
227 }
228
229 pub fn declare_sort(
233 &mut self,
234 name: impl Into<String>,
235 presort_and_args: &Option<(String, Vec<Expr>)>,
236 span: Span,
237 ) -> Result<(), TypeError> {
238 let name = name.into();
239 if self.type_info.func_types.contains_key(&name) {
240 return Err(TypeError::FunctionAlreadyBound(name, span));
241 }
242
243 let sort = match presort_and_args {
244 None => Arc::new(EqSort { name }),
245 Some((presort, args)) => {
246 if let Some(mksort) = self.type_info.mksorts.get(presort) {
247 mksort(&mut self.type_info, name, args)?
248 } else {
249 return Err(TypeError::PresortNotFound(presort.clone(), span));
250 }
251 }
252 };
253
254 self.add_arcsort(sort, span)
255 }
256
257 pub fn add_arcsort(&mut self, sort: ArcSort, span: Span) -> Result<(), TypeError> {
259 sort.register_type(&mut self.backend);
260
261 let name = sort.name();
262 match self.type_info.sorts.entry(name.to_owned()) {
263 HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
264 HEntry::Vacant(e) => {
265 e.insert(sort.clone());
266 sort.register_primitives(self);
267 Ok(())
268 }
269 }
270 }
271
272 pub fn add_pure_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
281 where
282 T: PurePrim + Clone,
283 {
284 self.register_per_context(x, validator, PureState::valid_contexts(), |x, ctx| {
285 Box::new(PurePrimWrapper { prim: x, ctx })
286 });
287 }
288
289 pub fn add_write_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
292 where
293 T: WritePrim + Clone,
294 {
295 self.register_registry_primitive::<T, WrapWrite>(
296 x,
297 validator,
298 WriteState::valid_contexts(),
299 );
300 }
301
302 pub fn add_read_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
305 where
306 T: ReadPrim + Clone,
307 {
308 self.register_registry_primitive::<T, WrapRead>(x, validator, ReadState::valid_contexts());
309 }
310
311 pub fn add_full_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
314 where
315 T: FullPrim + Clone,
316 {
317 self.register_registry_primitive::<T, WrapFull>(x, validator, FullState::valid_contexts());
318 }
319
320 fn register_registry_primitive<T, S>(
321 &mut self,
322 x: T,
323 validator: Option<PrimitiveValidator>,
324 valid_ctxs: &[Context],
325 ) where
326 T: Primitive + Clone,
327 S: RegistryWrap<T> + 'static,
328 {
329 let registry = self.backend.action_registry().clone();
330 self.register_per_context(x, validator, valid_ctxs, move |x, ctx| {
331 Box::new(RegistryPrimWrapper::<T, S> {
332 prim: x,
333 registry: registry.clone(),
334 ctx,
335 _wrap: std::marker::PhantomData,
336 })
337 });
338 }
339
340 fn register_per_context<T, F>(
350 &mut self,
351 x: T,
352 validator: Option<PrimitiveValidator>,
353 valid_ctxs: &[Context],
354 mut build_wrapper: F,
355 ) where
356 T: Primitive + Clone,
357 F: FnMut(T, Context) -> Box<dyn ExternalFunction>,
358 {
359 let primitive: Arc<dyn Primitive> = Arc::new(x.clone());
360 let name = primitive.name().to_owned();
361 let context_ids = EnumMap::from_fn(|ctx| {
362 valid_ctxs.contains(&ctx).then(|| {
363 self.backend
364 .register_external_func(build_wrapper(x.clone(), ctx))
365 })
366 });
367 self.type_info
368 .primitives
369 .entry(name)
370 .or_default()
371 .push(PrimitiveWithId {
372 primitive,
373 validator,
374 context_ids,
375 });
376 }
377}
378
impl EGraph {
    /// Typechecks every command of `program` in order and returns the resolved
    /// commands, stopping at the first error.
    ///
    /// Order matters: earlier commands (sort/function declarations, global
    /// `let`s) update `self.type_info`, and later commands are checked against it.
    pub(crate) fn typecheck_program(
        &mut self,
        program: &Vec<NCommand>,
    ) -> Result<Vec<ResolvedNCommand>, TypeError> {
        let mut result = vec![];
        for command in program {
            result.push(self.typecheck_command(command)?);
        }
        Ok(result)
    }

    /// Typechecks a single normalized command and returns its resolved form.
    ///
    /// Besides resolving, this updates the typing environment for declarations:
    /// sorts are declared (and optionally marked non-unionable), function types
    /// are recorded, and global `let`s / internal-let functions register their
    /// sort in `global_sorts`. Resolved rules are finally passed to
    /// `warn_for_prefixed_non_globals_in_rule`.
    fn typecheck_command(&mut self, command: &NCommand) -> Result<ResolvedNCommand, TypeError> {
        // Borrowed from `self.parser` only; the arms below rely on this being
        // disjoint from `self.type_info`.
        let symbol_gen = &mut self.parser.symbol_gen;

        let command: ResolvedNCommand = match command {
            NCommand::Function(fdecl) => {
                let resolved = self.type_info.typecheck_function(symbol_gen, fdecl)?;
                if resolved.internal_let {
                    // Internal-let functions are also registered as globals
                    // with their output sort. The `unwrap` relies on
                    // `typecheck_function` having succeeded, which resolves the
                    // output sort.
                    let output_sort = self.type_info.sorts.get(&fdecl.schema.output).unwrap();
                    self.type_info
                        .global_sorts
                        .insert(fdecl.name.clone(), output_sort.clone());
                }
                ResolvedNCommand::Function(resolved)
            }
            NCommand::NormRule { rule } => ResolvedNCommand::NormRule {
                rule: self
                    .type_info
                    .typecheck_rule(symbol_gen, rule, self.seminaive)?,
            },
            NCommand::Sort {
                span,
                name,
                presort_and_args,
                uf,
                proof_func,
                unionable,
            } => {
                self.declare_sort(name.clone(), presort_and_args, span.clone())?;
                // Remember non-unionable sorts so `is_sort_unionable` rejects them.
                if !unionable {
                    self.type_info.non_unionable_sorts.insert(name.clone());
                }
                ResolvedNCommand::Sort {
                    span: span.clone(),
                    name: name.clone(),
                    presort_and_args: presort_and_args.clone(),
                    uf: uf.clone(),
                    proof_func: proof_func.clone(),
                    unionable: *unionable,
                }
            }
            // A top-level `let` defines a global: typecheck it, enforce the
            // global naming prefix, then record the variable's sort.
            NCommand::CoreAction(action @ Action::Let(span, var, _)) => {
                let action = self.type_info.typecheck_standalone_action(
                    symbol_gen,
                    action,
                    &Default::default(),
                    Context::Full,
                )?;
                self.ensure_global_name_prefix(span, var)?;
                let ResolvedAction::Let(_, resolved_var, _) = &action else {
                    unreachable!("typechecking an Action::Let should return ResolvedAction::Let")
                };
                self.type_info
                    .global_sorts
                    .insert(resolved_var.name.clone(), resolved_var.sort.clone());
                ResolvedNCommand::CoreAction(action)
            }
            NCommand::CoreAction(action) => {
                ResolvedNCommand::CoreAction(self.type_info.typecheck_standalone_action(
                    symbol_gen,
                    action,
                    &Default::default(),
                    Context::Full,
                )?)
            }
            NCommand::Extract(span, expr, variants) => {
                let res_expr = self.type_info.typecheck_standalone_expr(
                    symbol_gen,
                    expr,
                    &Default::default(),
                    Context::Full,
                )?;

                let res_variants = self.type_info.typecheck_standalone_expr(
                    symbol_gen,
                    variants,
                    &Default::default(),
                    Context::Full,
                )?;
                // The variant count must be an `i64` expression.
                if res_variants.output_type().name() != I64Sort.name() {
                    return Err(TypeError::Mismatch {
                        expr: variants.clone(),
                        expected: I64Sort.to_arcsort(),
                        actual: res_variants.output_type(),
                    });
                }

                ResolvedNCommand::Extract(span.clone(), res_expr, res_variants)
            }
            NCommand::Check(span, facts) => ResolvedNCommand::Check(
                span.clone(),
                self.type_info.typecheck_facts(symbol_gen, facts)?,
            ),
            NCommand::Fail(span, cmd) => {
                ResolvedNCommand::Fail(span.clone(), Box::new(self.typecheck_command(cmd)?))
            }
            NCommand::RunSchedule(schedule) => ResolvedNCommand::RunSchedule(
                self.type_info.typecheck_schedule(symbol_gen, schedule)?,
            ),
            NCommand::Pop(span, n) => ResolvedNCommand::Pop(span.clone(), *n),
            NCommand::Push(n) => ResolvedNCommand::Push(*n),
            NCommand::AddRuleset(span, ruleset) => {
                ResolvedNCommand::AddRuleset(span.clone(), ruleset.clone())
            }
            NCommand::UnstableCombinedRuleset(span, name, sub_rulesets) => {
                ResolvedNCommand::UnstableCombinedRuleset(
                    span.clone(),
                    name.clone(),
                    sub_rulesets.clone(),
                )
            }
            NCommand::PrintOverallStatistics(span, file) => {
                ResolvedNCommand::PrintOverallStatistics(span.clone(), file.clone())
            }
            NCommand::PrintFunction(span, table, size, file, mode) => {
                ResolvedNCommand::PrintFunction(
                    span.clone(),
                    table.clone(),
                    *size,
                    file.clone(),
                    *mode,
                )
            }
            NCommand::PrintSize(span, n) => {
                ResolvedNCommand::PrintSize(span.clone(), n.clone())
            }
            // `prove-exists` only applies to constructors.
            NCommand::ProveExists(span, constructor) => {
                let func_type = self
                    .type_info
                    .get_func_type(constructor)
                    .ok_or_else(|| TypeError::UnboundFunction(constructor.clone(), span.clone()))?;
                if func_type.subtype != FunctionSubtype::Constructor {
                    return Err(TypeError::ProveExistsRequiresConstructor(
                        constructor.clone(),
                        span.clone(),
                    ));
                }
                ResolvedNCommand::ProveExists(span.clone(), ResolvedCall::Func(func_type.clone()))
            }
            NCommand::Output { span, file, exprs } => {
                let exprs = exprs
                    .iter()
                    .map(|expr| {
                        self.type_info.typecheck_standalone_expr(
                            symbol_gen,
                            expr,
                            &Default::default(),
                            Context::Full,
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                ResolvedNCommand::Output {
                    span: span.clone(),
                    file: file.clone(),
                    exprs,
                }
            }
            NCommand::Input { span, name, file } => ResolvedNCommand::Input {
                span: span.clone(),
                name: name.clone(),
                file: file.clone(),
            },
            NCommand::UserDefined(span, name, exprs) => {
                ResolvedNCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
            }
        };
        if let ResolvedNCommand::NormRule { rule } = &command {
            self.warn_for_prefixed_non_globals_in_rule(rule)?;
        }
        Ok(command)
    }

    /// For a non-global variable whose name starts with `GLOBAL_NAME_PREFIX`,
    /// delegates to `warn_prefixed_non_globals` and propagates its result.
    /// Global references are skipped.
    fn warn_for_prefixed_non_globals_in_var(
        &mut self,
        span: &Span,
        var: &ResolvedVar,
    ) -> Result<(), TypeError> {
        if var.is_global_ref {
            return Ok(());
        }
        if var.name.starts_with(crate::GLOBAL_NAME_PREFIX) {
            self.warn_prefixed_non_globals(span, &var.name)?;
        }
        Ok(())
    }

    /// Runs `warn_for_prefixed_non_globals_in_var` over every variable in the
    /// rule's body facts and then its head, returning the first error.
    fn warn_for_prefixed_non_globals_in_rule(
        &mut self,
        rule: &ResolvedRule,
    ) -> Result<(), TypeError> {
        // `visit_vars` takes a closure that cannot return early, so the first
        // error is latched here and later visits become no-ops.
        let mut res: Result<(), TypeError> = Ok(());

        for fact in &rule.body {
            fact.visit_vars(&mut |span, var| {
                if res.is_ok() {
                    res = self.warn_for_prefixed_non_globals_in_var(span, var);
                }
            });
        }

        rule.head.visit_vars(&mut |span, var| {
            if res.is_ok() {
                res = self.warn_for_prefixed_non_globals_in_var(span, var);
            }
        });
        res
    }
}
604
605impl TypeInfo {
606 pub fn add_presort<S: Presort>(&mut self, span: Span) -> Result<(), TypeError> {
608 let name = S::presort_name();
609 match self.mksorts.entry(name.to_owned()) {
610 HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
611 HEntry::Vacant(e) => {
612 e.insert(S::make_sort);
613 self.reserved_primitives.extend(S::reserved_primitives());
614 Ok(())
615 }
616 }
617 }
618
619 pub fn get_sorts_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
621 let mut results = Vec::new();
622 for sort in self.sorts.values() {
623 let sort = sort.clone().as_arc_any();
624 if let Ok(sort) = Arc::downcast(sort)
625 && pred(&sort)
626 {
627 results.push(sort);
628 }
629 }
630 results
631 }
632
633 pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
635 self.get_sorts_by(|_| true)
636 }
637
638 pub fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
640 let results = self.get_sorts_by(pred);
641 assert_eq!(
642 results.len(),
643 1,
644 "Expected exactly one sort for type {}",
645 std::any::type_name::<S>()
646 );
647 results.into_iter().next().unwrap()
648 }
649
650 pub fn get_sort<S: Sort>(&self) -> Arc<S> {
652 self.get_sort_by(|_| true)
653 }
654
655 pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
657 self.sorts.values().filter(|&x| f(x)).cloned().collect()
658 }
659
660 pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
662 let results = self.get_arcsorts_by(f);
663 assert_eq!(
664 results.len(),
665 1,
666 "Expected exactly one sort matching the given predicate"
667 );
668 results.into_iter().next().unwrap()
669 }
670
671 pub fn get_arcsort_for_value_type<T: 'static>(&self) -> ArcSort {
673 let results = self.get_arcsorts_by(|s| s.value_type() == Some(std::any::TypeId::of::<T>()));
674 assert_eq!(
675 results.len(),
676 1,
677 "Expected exactly one sort for type `{}`",
678 std::any::type_name::<T>()
679 );
680 results.into_iter().next().unwrap()
681 }
682
683 pub fn is_sort_unionable(&self, sort: &ArcSort) -> bool {
687 sort.is_eq_sort() && !self.non_unionable_sorts.contains(sort.name())
688 }
689
690 fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
691 let input = func
692 .schema
693 .input
694 .iter()
695 .map(|name| {
696 if let Some(sort) = self.sorts.get(name) {
697 Ok(sort.clone())
698 } else {
699 Err(TypeError::UndefinedSort(name.clone(), func.span.clone()))
700 }
701 })
702 .collect::<Result<Vec<_>, _>>()?;
703 let output = if let Some(sort) = self.sorts.get(&func.schema.output) {
704 Ok(sort.clone())
705 } else {
706 Err(TypeError::UndefinedSort(
707 func.schema.output.clone(),
708 func.span.clone(),
709 ))
710 }?;
711
712 Ok(FuncType {
713 name: func.name.clone(),
714 subtype: func.subtype,
715 input,
716 output: output.clone(),
717 })
718 }
719
720 fn typecheck_function(
721 &mut self,
722 symbol_gen: &mut SymbolGen,
723 fdecl: &FunctionDecl,
724 ) -> Result<ResolvedFunctionDecl, TypeError> {
725 if self.sorts.contains_key(&fdecl.name) {
726 return Err(TypeError::SortAlreadyBound(
727 fdecl.name.clone(),
728 fdecl.span.clone(),
729 ));
730 }
731 if self.is_primitive(&fdecl.name) {
732 return Err(TypeError::PrimitiveAlreadyBound(
733 fdecl.name.clone(),
734 fdecl.span.clone(),
735 ));
736 }
737 if fdecl.term_constructor.is_some() && fdecl.schema.input.is_empty() {
739 return Err(TypeError::TermConstructorNoInputs(
740 fdecl.name.clone(),
741 fdecl.span.clone(),
742 ));
743 }
744 let ftype = self.function_to_functype(fdecl)?;
745 if self.func_types.insert(fdecl.name.clone(), ftype).is_some() {
746 return Err(TypeError::FunctionAlreadyBound(
747 fdecl.name.clone(),
748 fdecl.span.clone(),
749 ));
750 }
751 let mut bound_vars = IndexMap::default();
752 let output_type = self.sorts.get(&fdecl.schema.output).unwrap();
753 if fdecl.subtype == FunctionSubtype::Constructor && !output_type.is_eq_sort() {
754 return Err(TypeError::ConstructorOutputNotSort(
755 fdecl.name.clone(),
756 fdecl.span.clone(),
757 ));
758 }
759 bound_vars.insert("old", (fdecl.span.clone(), output_type.clone()));
760 bound_vars.insert("new", (fdecl.span.clone(), output_type.clone()));
761
762 Ok(ResolvedFunctionDecl {
763 name: fdecl.name.clone(),
764 subtype: fdecl.subtype,
765 schema: fdecl.schema.clone(),
766 resolved_schema: ResolvedCall::Func(self.func_types.get(&fdecl.name).unwrap().clone()),
767 merge: match &fdecl.merge {
768 Some(merge) => Some(self.typecheck_standalone_expr(
772 symbol_gen,
773 merge,
774 &bound_vars,
775 Context::Write,
776 )?),
777 None => None,
778 },
779 cost: fdecl.cost,
780 unextractable: fdecl.unextractable,
781 internal_hidden: fdecl.internal_hidden,
782 internal_let: fdecl.internal_let,
783 span: fdecl.span.clone(),
784 term_constructor: fdecl.term_constructor.clone(),
785 })
786 }
787
788 fn typecheck_schedule(
789 &self,
790 symbol_gen: &mut SymbolGen,
791 schedule: &Schedule,
792 ) -> Result<ResolvedSchedule, TypeError> {
793 let schedule = match schedule {
794 Schedule::Repeat(span, times, schedule) => ResolvedSchedule::Repeat(
795 span.clone(),
796 *times,
797 Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
798 ),
799 Schedule::Sequence(span, schedules) => {
800 let schedules = schedules
801 .iter()
802 .map(|schedule| self.typecheck_schedule(symbol_gen, schedule))
803 .collect::<Result<Vec<_>, _>>()?;
804 ResolvedSchedule::Sequence(span.clone(), schedules)
805 }
806 Schedule::Saturate(span, schedule) => ResolvedSchedule::Saturate(
807 span.clone(),
808 Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
809 ),
810 Schedule::Run(span, RunConfig { ruleset, until }) => {
811 let until = until
812 .as_ref()
813 .map(|facts| self.typecheck_facts(symbol_gen, facts))
814 .transpose()?;
815 ResolvedSchedule::Run(
816 span.clone(),
817 ResolvedRunConfig {
818 ruleset: ruleset.clone(),
819 until,
820 },
821 )
822 }
823 };
824
825 Result::Ok(schedule)
826 }
827
828 fn typecheck_rule(
829 &self,
830 symbol_gen: &mut SymbolGen,
831 rule: &Rule,
832 global_seminaive: bool,
833 ) -> Result<ResolvedRule, TypeError> {
834 let Rule {
835 span,
836 head,
837 body,
838 name,
839 ruleset,
840 eval_mode,
841 no_decomp,
842 include_subsumed,
843 } = rule;
844 let mut constraints = vec![];
845
846 let read_contexts = !global_seminaive
850 || matches!(
851 eval_mode,
852 RuleEvalMode::Naive | RuleEvalMode::UnsafeSeminaive
853 );
854 let (query_ctx, action_ctx) = if read_contexts {
855 (Context::Read, Context::Full)
856 } else {
857 (Context::Pure, Context::Write)
858 };
859
860 let (query, mapped_query) = Facts(body.clone()).to_query(self, symbol_gen);
861 constraints.extend(query.get_constraints(self, query_ctx)?);
862
863 let mut binding = query.get_vars();
864 let mut ctx = CoreActionContext::new(self, &mut binding, symbol_gen, false);
867 let (actions, mapped_action) = head.to_core_actions(&mut ctx)?;
868
869 let mut problem = Problem::default();
870 problem.add_rule(
871 &CoreRule {
872 span: span.clone(),
873 body: query,
874 head: actions,
875 },
876 self,
877 symbol_gen,
878 query_ctx,
879 action_ctx,
880 )?;
881
882 let assignment = problem
883 .solve(|sort: &ArcSort| sort.name())
884 .map_err(|e| e.to_type_error())?;
885
886 let body: Vec<ResolvedFact> = assignment.annotate_facts(&mapped_query, self, query_ctx);
887 let actions: ResolvedActions =
888 assignment.annotate_actions(&mapped_action, self, action_ctx)?;
889
890 if !read_contexts {
893 self.check_no_function_lookups_in_actions(&actions)?;
894 }
895
896 Ok(ResolvedRule {
897 span: span.clone(),
898 body,
899 head: actions,
900 name: name.clone(),
901 ruleset: ruleset.clone(),
902 eval_mode: *eval_mode,
903 no_decomp: *no_decomp,
904 include_subsumed: *include_subsumed,
905 })
906 }
907
908 fn check_lookup_expr(&self, expr: &ResolvedExpr) -> Result<(), TypeError> {
909 if let Some(span) = self.expr_has_function_lookup(expr) {
910 return Err(TypeError::LookupInRuleDisallowed(
911 "function".to_string(),
912 span,
913 ));
914 }
915 Ok(())
916 }
917
918 fn check_no_function_lookups_in_actions(
919 &self,
920 actions: &ResolvedActions,
921 ) -> Result<(), TypeError> {
922 for action in actions.iter() {
923 match action {
924 GenericAction::Let(_, _, rhs) => self.check_lookup_expr(rhs)?,
925 GenericAction::Set(_, _, args, rhs) => {
926 for arg in args.iter() {
927 self.check_lookup_expr(arg)?;
928 }
929 self.check_lookup_expr(rhs)?;
930 }
931 GenericAction::Union(_, lhs, rhs) => {
932 self.check_lookup_expr(lhs)?;
933 self.check_lookup_expr(rhs)?;
934 }
935 GenericAction::Change(_, _, _, args) => {
936 for arg in args.iter() {
937 self.check_lookup_expr(arg)?;
938 }
939 }
940 GenericAction::Panic(..) => {}
941 GenericAction::Expr(_, expr) => self.check_lookup_expr(expr)?,
942 }
943 }
944 Ok(())
945 }
946
947 pub fn typecheck_facts(
948 &self,
949 symbol_gen: &mut SymbolGen,
950 facts: &[Fact],
951 ) -> Result<Vec<ResolvedFact>, TypeError> {
952 let (query, mapped_facts) = Facts(facts.to_vec()).to_query(self, symbol_gen);
953 let mut problem = Problem::default();
954 problem.add_query(&query, self, Context::Read)?;
957 let assignment = problem
958 .solve(|sort: &ArcSort| sort.name())
959 .map_err(|e| e.to_type_error())?;
960 let annotated_facts = assignment.annotate_facts(&mapped_facts, self, Context::Read);
961 Ok(annotated_facts)
962 }
963
964 fn typecheck_standalone_actions(
968 &self,
969 symbol_gen: &mut SymbolGen,
970 actions: &Actions,
971 binding: &IndexMap<&str, (Span, ArcSort)>,
972 context: Context,
973 ) -> Result<ResolvedActions, TypeError> {
974 let mut binding_set: IndexSet<String> =
975 binding.keys().copied().map(str::to_string).collect();
976 let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
979 let (actions, mapped_action) = actions.to_core_actions(&mut ctx)?;
980 let mut problem = Problem::default();
981
982 problem.add_actions(&actions, self, symbol_gen, context)?;
983
984 for (var, (span, sort)) in binding {
986 problem.assign_local_var_type(var, span.clone(), sort.clone())?;
987 }
988
989 let assignment = problem
990 .solve(|sort: &ArcSort| sort.name())
991 .map_err(|e| e.to_type_error())?;
992
993 let annotated_actions = assignment.annotate_actions(&mapped_action, self, context)?;
994 Ok(annotated_actions)
995 }
996
997 fn typecheck_standalone_expr(
998 &self,
999 symbol_gen: &mut SymbolGen,
1000 expr: &Expr,
1001 binding: &IndexMap<&str, (Span, ArcSort)>,
1002 context: Context,
1003 ) -> Result<ResolvedExpr, TypeError> {
1004 let action = Action::Expr(expr.span(), expr.clone());
1005 let typechecked_action =
1006 self.typecheck_standalone_action(symbol_gen, &action, binding, context)?;
1007 match typechecked_action {
1008 ResolvedAction::Expr(_, expr) => Ok(expr),
1009 _ => unreachable!(),
1010 }
1011 }
1012
1013 pub(crate) fn typecheck_expr_with_output(
1014 &self,
1015 symbol_gen: &mut SymbolGen,
1016 expr: &Expr,
1017 binding: &IndexMap<&str, (Span, ArcSort)>,
1018 output_sort: ArcSort,
1019 context: Context,
1020 ) -> Result<ResolvedExpr, TypeError> {
1021 let action = Action::Expr(expr.span(), expr.clone());
1022 let mut binding_set: IndexSet<String> =
1023 binding.keys().copied().map(str::to_string).collect();
1024 let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
1025 let (actions, mapped_action) = Actions::singleton(action).to_core_actions(&mut ctx)?;
1026 let mut problem = Problem::default();
1027
1028 problem.add_actions(&actions, self, symbol_gen, context)?;
1029
1030 for (var, (span, sort)) in binding {
1031 problem.assign_local_var_type(var, span.clone(), sort.clone())?;
1032 }
1033
1034 let [GenericAction::Expr(_, mapped_expr)] = mapped_action.0.as_slice() else {
1035 unreachable!("typechecking an expression should produce one expression action")
1036 };
1037 let output_atom = mapped_expr.get_corresponding_var_or_lit(self);
1038 problem.add_binding(output_atom, output_sort.clone());
1039
1040 let assignment = problem
1041 .solve(|sort: &ArcSort| sort.name())
1042 .map_err(|e| e.to_type_error())?;
1043
1044 let annotated_actions = assignment.annotate_actions(&mapped_action, self, context)?;
1045 match annotated_actions.0.into_iter().next().unwrap() {
1046 ResolvedAction::Expr(_, resolved_expr) => {
1047 let actual = resolved_expr.output_type();
1048 if actual.name() != output_sort.name() {
1049 return Err(TypeError::Mismatch {
1050 expr: expr.clone(),
1051 expected: output_sort,
1052 actual,
1053 });
1054 }
1055 Ok(resolved_expr)
1056 }
1057 _ => unreachable!(),
1058 }
1059 }
1060
1061 fn typecheck_standalone_action(
1062 &self,
1063 symbol_gen: &mut SymbolGen,
1064 action: &Action,
1065 binding: &IndexMap<&str, (Span, ArcSort)>,
1066 context: Context,
1067 ) -> Result<ResolvedAction, TypeError> {
1068 self.typecheck_standalone_actions(
1069 symbol_gen,
1070 &Actions::singleton(action.clone()),
1071 binding,
1072 context,
1073 )
1074 .map(|v| {
1075 assert_eq!(v.len(), 1);
1076 v.0.into_iter().next().unwrap()
1077 })
1078 }
1079
1080 pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
1081 self.sorts.get(sym)
1082 }
1083
1084 pub fn get_prims(&self, sym: &str) -> Option<&[PrimitiveWithId]> {
1085 self.primitives.get(sym).map(Vec::as_slice)
1086 }
1087
1088 pub fn is_primitive(&self, sym: &str) -> bool {
1089 self.primitives.contains_key(sym) || self.reserved_primitives.contains(sym)
1090 }
1091
1092 pub fn primitive_has_validator(&self, id: ExternalFunctionId) -> bool {
1093 self.primitives
1094 .values()
1095 .flat_map(|v| v.iter())
1096 .any(|p| p.context_ids.iter().any(|(_, pid)| *pid == Some(id)) && p.validator.is_some())
1097 }
1098
1099 pub fn get_func_type(&self, sym: &str) -> Option<&FuncType> {
1100 self.func_types.get(sym)
1101 }
1102
1103 pub fn is_constructor(&self, sym: &str) -> bool {
1104 self.func_types
1105 .get(sym)
1106 .is_some_and(|f| f.subtype == FunctionSubtype::Constructor)
1107 }
1108
1109 pub fn get_global_sort(&self, sym: &str) -> Option<&ArcSort> {
1110 self.global_sorts.get(sym)
1111 }
1112
1113 pub fn is_global(&self, sym: &str) -> bool {
1114 self.global_sorts.contains_key(sym)
1115 }
1116
1117 pub fn expr_has_function_lookup(&self, expr: &ResolvedExpr) -> Option<Span> {
1121 use ast::GenericExpr;
1122
1123 expr.find(&mut |e| {
1124 if let GenericExpr::Call(span, ResolvedCall::Func(func_type), _) = e
1125 && func_type.subtype == FunctionSubtype::Custom
1126 && !self.is_global(&func_type.name)
1127 {
1128 return Some(span.clone());
1129 }
1130 None
1131 })
1132 }
1133}
1134
/// Errors reported while declaring sorts/functions or typechecking commands,
/// rules, actions and expressions.
#[derive(Debug, Clone, Error)]
pub enum TypeError {
    /// A call has a different number of arguments than its signature expects.
    #[error("{}\nArity mismatch, expected {expected} args: {expr}", .expr.span())]
    Arity { expr: Expr, expected: usize },
    /// An expression's resolved sort differs from the required sort.
    #[error(
        "{}\n Expect expression {expr} to have type {}, but get type {}",
        .expr.span(), .expected.name(), .actual.name(),
    )]
    Mismatch {
        expr: Expr,
        expected: ArcSort,
        actual: ArcSort,
    },
    /// Reference to a symbol that is not bound.
    #[error("{1}\nUnbound symbol {0}")]
    Unbound(String, Span),
    /// A query variable is not grounded (see the message for the definition).
    #[error(
        "{1}\nVariable {0} is ungrounded. A variable is grounded when it appears as an argument to a constructor or function in the query, not just under primitives or equalities."
    )]
    Ungrounded(String, Span),
    /// A schema or expression refers to a sort that was never declared.
    #[error("{1}\nUndefined sort {0}")]
    UndefinedSort(String, Span),
    /// Reference to a function name that is not declared.
    #[error("{1}\nUnbound function {0}")]
    UnboundFunction(String, Span),
    /// `prove-exists` was given a function that is not a constructor.
    #[error("{1}\nprove-exists requires constructor function, but {0} is not a constructor")]
    ProveExistsRequiresConstructor(String, Span),
    /// A function (or sort) declaration reuses an existing function name.
    #[error("{1}\nFunction already bound {0}")]
    FunctionAlreadyBound(String, Span),
    /// A sort, presort, or function declaration reuses an existing sort name.
    #[error("{1}\nSort {0} already declared.")]
    SortAlreadyBound(String, Span),
    /// A function declaration reuses a registered or reserved primitive name.
    #[error("{1}\nPrimitive {0} already declared.")]
    PrimitiveAlreadyBound(String, Span),
    /// Signature mismatch; fields are (expected output, expected inputs,
    /// actual output, actual inputs).
    #[error("Function type mismatch: expected {} => {}, actual {} => {}", .1.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .0.name(), .3.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .2.name())]
    FunctionTypeMismatch(ArcSort, Vec<ArcSort>, ArcSort, Vec<ArcSort>),
    /// A sort declaration names a presort that is not registered.
    #[error("{1}\nPresort {0} not found.")]
    PresortNotFound(String, Span),
    /// The solver could not determine a sort for this expression.
    #[error("{}\nFailed to infer a type for: {}", .0.span(), .0)]
    InferenceFailure(Expr),
    /// A variable is defined more than once.
    #[error("{1}\nVariable {0} was already defined")]
    AlreadyDefined(String, Span),
    /// A constructor was declared with an output sort that is not an eq-sort.
    #[error("{1}\nThe output type of constructor function {0} must be sort")]
    ConstructorOutputNotSort(String, Span),
    /// A rule's actions look up a non-global custom function where that is
    /// disallowed (see `check_lookup_expr`).
    #[error("{1}\nValue lookup of non-constructor function {0} in rule is disallowed.")]
    LookupInRuleDisallowed(String, Span),
    /// `set` was applied to a constructor.
    #[error("{1}\nCannot set constructor {0}. Use `union` instead or declare {0} as a function.")]
    SetConstructorDisallowed(String, Span),
    /// Every candidate overload failed; carries each candidate's error.
    #[error("All alternative definitions considered failed\n{}", .0.iter().map(|e| format!(" {e}\n")).collect::<Vec<_>>().join(""))]
    AllAlternativeFailed(Vec<TypeError>),
    /// `union` on values of a sort that is not an eq-sort.
    #[error("{}\nCannot union values of sort {}", .1, .0.name())]
    NonEqsortUnion(ArcSort, Span),
    /// `union` on an eq-sort that was declared non-unionable.
    #[error("{}\nCannot union values of sort {} because it is marked as non-unionable (e.g. from a relation)", .1, .0.name())]
    NonUnionableSort(ArcSort, Span),
    /// A term constructor was declared with no inputs.
    #[error(
        "{1}\nView table {0} with :internal-term-constructor must have at least one input (the e-class)."
    )]
    TermConstructorNoInputs(String, Span),
    /// A non-global variable uses the reserved global-name prefix.
    #[error(
        "{span}\nNon-global variable `{name}` must not start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    NonGlobalPrefixed { name: String, span: Span },
    /// A global binding lacks the required global-name prefix.
    #[error(
        "{span}\nGlobal `{name}` must start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    GlobalMissingPrefix { name: String, span: Span },
}
1201
#[cfg(test)]
mod test {
    use crate::{EGraph, Error, typechecking::TypeError};

    /// Applying a binary relation to three arguments must be reported as an
    /// arity error whose span covers exactly the offending call.
    #[test]
    fn test_arity_mismatch() {
        let prog = "
            (relation f (i64 i64))
            (rule ((f a b c)) ())
        ";
        let mut egraph = EGraph::default();
        let outcome = egraph.parse_and_run_program(None, prog);

        if let Err(Error::TypeError(TypeError::Arity { expected: 2, expr })) = &outcome {
            assert_eq!(expr.span().string(), "(f a b c)");
        } else {
            panic!("Expected arity mismatch, got: {outcome:?}");
        }
    }
}