1use std::hash::Hasher;
2
3use crate::{
4 core::{CoreActionContext, CoreRule, GenericActionsExt, ResolvedCall},
5 *,
6};
7use ast::{ResolvedAction, ResolvedExpr, ResolvedFact, ResolvedRule, ResolvedVar, Rule};
8use core_relations::ExternalFunction;
9use egglog_ast::generic_ast::GenericAction;
10
/// Signature of a declared function: its name, subtype, input sorts and output sort.
///
/// Equality and hashing for this type compare sorts by their names rather than by
/// `Arc` identity.
#[derive(Clone, Debug)]
pub struct FuncType {
    /// Name of the function.
    pub name: String,
    /// Kind of function (e.g. constructor or custom function).
    pub subtype: FunctionSubtype,
    /// Sorts of the arguments, in order.
    pub input: Vec<ArcSort>,
    /// Sort of the result.
    pub output: ArcSort,
}
18
19impl PartialEq for FuncType {
20 fn eq(&self, other: &Self) -> bool {
21 if self.name == other.name
22 && self.subtype == other.subtype
23 && self.output.name() == other.output.name()
24 {
25 if self.input.len() != other.input.len() {
26 return false;
27 }
28 for (a, b) in self.input.iter().zip(other.input.iter()) {
29 if a.name() != b.name() {
30 return false;
31 }
32 }
33 true
34 } else {
35 false
36 }
37 }
38}
39
40impl Eq for FuncType {}
41
42impl Hash for FuncType {
43 fn hash<H: Hasher>(&self, state: &mut H) {
44 self.name.hash(state);
45 self.subtype.hash(state);
46 self.output.name().hash(state);
47 for inp in &self.input {
48 inp.name().hash(state);
49 }
50 }
51}
/// Optional hook attached to a primitive. It receives the term DAG and a slice of
/// argument term ids and may produce a term id.
/// NOTE(review): the exact contract (e.g. what a `None` result means) is not visible here.
pub type PrimitiveValidator = Arc<dyn Fn(&mut TermDag, &[TermId]) -> Option<TermId> + Send + Sync>;

/// A registered primitive together with the backend id it was registered under.
#[derive(Clone)]
pub struct PrimitiveWithId {
    /// The primitive itself (name, type constraints, implementation).
    pub(crate) primitive: Arc<dyn Primitive + Send + Sync>,
    /// Id returned by the backend's `register_external_func` for this primitive.
    pub(crate) id: ExternalFunctionId,
    /// Validator supplied through `add_primitive_with_validator`, if any.
    pub(crate) validator: Option<PrimitiveValidator>,
}
63
64impl PrimitiveWithId {
65 pub fn accept(&self, tys: &[Arc<dyn Sort>], typeinfo: &TypeInfo) -> bool {
68 let mut constraints = vec![];
69 let lits: Vec<_> = (0..tys.len())
70 .map(|i| AtomTerm::Literal(Span::Panic, Literal::Int(i as i64)))
71 .collect();
72 for (lit, ty) in lits.iter().zip(tys.iter()) {
73 constraints.push(constraint::assign(lit.clone(), ty.clone()))
74 }
75 constraints.extend(
76 self.primitive
77 .get_type_constraints(&Span::Panic)
78 .get(&lits, typeinfo),
79 );
80 let problem = Problem {
81 constraints,
82 range: HashSet::default(),
83 };
84 problem.solve(|sort| sort.name()).is_ok()
85 }
86}
87
88impl Debug for PrimitiveWithId {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(f, "Prim({})", self.primitive.name())
91 }
92}
93
/// Type environment used while typechecking: presorts, declared sorts, primitives,
/// function signatures and the sorts of global bindings.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// Presort constructors keyed by presort name (filled by `add_presort`).
    mksorts: HashMap<String, MkSort>,
    /// Primitive names reserved by registered presorts; `is_primitive` treats them as taken.
    reserved_primitives: HashSet<&'static str>,
    /// Declared sorts keyed by sort name.
    pub(crate) sorts: HashMap<String, Arc<dyn Sort>>,
    /// Registered primitives; one name may map to several overloads.
    primitives: HashMap<String, Vec<PrimitiveWithId>>,
    /// Signatures of declared functions keyed by function name.
    func_types: HashMap<String, FuncType>,
    /// Sorts of global bindings keyed by global name.
    pub(crate) global_sorts: HashMap<String, ArcSort>,
    /// Names of sorts declared as non-unionable; consulted by `is_sort_unionable`.
    pub(crate) non_unionable_sorts: HashSet<String>,
}
107
108impl EGraph {
111 pub fn add_sort<S: Sort + 'static>(&mut self, sort: S, span: Span) -> Result<(), TypeError> {
115 self.add_arcsort(Arc::new(sort), span)
116 }
117
118 pub fn declare_sort(
122 &mut self,
123 name: impl Into<String>,
124 presort_and_args: &Option<(String, Vec<Expr>)>,
125 span: Span,
126 ) -> Result<(), TypeError> {
127 let name = name.into();
128 if self.type_info.func_types.contains_key(&name) {
129 return Err(TypeError::FunctionAlreadyBound(name, span));
130 }
131
132 let sort = match presort_and_args {
133 None => Arc::new(EqSort { name }),
134 Some((presort, args)) => {
135 if let Some(mksort) = self.type_info.mksorts.get(presort) {
136 mksort(&mut self.type_info, name, args)?
137 } else {
138 return Err(TypeError::PresortNotFound(presort.clone(), span));
139 }
140 }
141 };
142
143 self.add_arcsort(sort, span)
144 }
145
146 pub fn add_arcsort(&mut self, sort: ArcSort, span: Span) -> Result<(), TypeError> {
148 sort.register_type(&mut self.backend);
149
150 let name = sort.name();
151 match self.type_info.sorts.entry(name.to_owned()) {
152 HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
153 HEntry::Vacant(e) => {
154 e.insert(sort.clone());
155 sort.register_primitives(self);
156 Ok(())
157 }
158 }
159 }
160
161 pub fn add_primitive<T>(&mut self, x: T)
163 where
164 T: Clone + Primitive + Send + Sync + 'static,
165 {
166 self.add_primitive_with_validator(x, None)
167 }
168
169 pub fn add_primitive_with_validator<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
171 where
172 T: Clone + Primitive + Send + Sync + 'static,
173 {
174 #[derive(Clone)]
179 struct Wrapper<T>(T);
180 impl<T: Clone + Primitive + Send + Sync> ExternalFunction for Wrapper<T> {
181 fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
182 self.0.apply(exec_state, args)
183 }
184 }
185
186 let primitive = Arc::new(x.clone());
187 let id = self.backend.register_external_func(Box::new(Wrapper(x)));
188 self.type_info
189 .primitives
190 .entry(primitive.name().to_owned())
191 .or_default()
192 .push(PrimitiveWithId {
193 primitive,
194 id,
195 validator,
196 });
197 }
198
199 pub(crate) fn typecheck_program(
200 &mut self,
201 program: &Vec<NCommand>,
202 ) -> Result<Vec<ResolvedNCommand>, TypeError> {
203 let mut result = vec![];
204 for command in program {
205 result.push(self.typecheck_command(command)?);
206 }
207 Ok(result)
208 }
209
210 fn typecheck_command(&mut self, command: &NCommand) -> Result<ResolvedNCommand, TypeError> {
211 let symbol_gen = &mut self.parser.symbol_gen;
212
213 let command: ResolvedNCommand = match command {
214 NCommand::Function(fdecl) => {
215 let resolved = self.type_info.typecheck_function(symbol_gen, fdecl)?;
216 if resolved.internal_let {
219 let output_sort = self.type_info.sorts.get(&fdecl.schema.output).unwrap();
220 self.type_info
221 .global_sorts
222 .insert(fdecl.name.clone(), output_sort.clone());
223 }
224 ResolvedNCommand::Function(resolved)
225 }
226 NCommand::NormRule { rule } => ResolvedNCommand::NormRule {
227 rule: self.type_info.typecheck_rule(symbol_gen, rule)?,
228 },
229 NCommand::Sort {
230 span,
231 name,
232 presort_and_args,
233 uf,
234 proof_func,
235 unionable,
236 } => {
237 self.declare_sort(name.clone(), presort_and_args, span.clone())?;
240 if !unionable {
242 self.type_info.non_unionable_sorts.insert(name.clone());
243 }
244 ResolvedNCommand::Sort {
245 span: span.clone(),
246 name: name.clone(),
247 presort_and_args: presort_and_args.clone(),
248 uf: uf.clone(),
249 proof_func: proof_func.clone(),
250 unionable: *unionable,
251 }
252 }
253 NCommand::CoreAction(Action::Let(span, var, expr)) => {
254 let expr = self
255 .type_info
256 .typecheck_expr(symbol_gen, expr, &Default::default())?;
257 let output_type = expr.output_type();
258 self.ensure_global_name_prefix(span, var)?;
259 self.type_info
260 .global_sorts
261 .insert(var.clone(), output_type.clone());
262 let var = ResolvedVar {
263 name: var.clone(),
264 sort: output_type,
265 is_global_ref: false,
267 };
268 ResolvedNCommand::CoreAction(ResolvedAction::Let(span.clone(), var, expr))
269 }
270 NCommand::CoreAction(action) => ResolvedNCommand::CoreAction(
271 self.type_info
272 .typecheck_action(symbol_gen, action, &Default::default())?,
273 ),
274 NCommand::Extract(span, expr, variants) => {
275 let res_expr =
276 self.type_info
277 .typecheck_expr(symbol_gen, expr, &Default::default())?;
278
279 let res_variants =
280 self.type_info
281 .typecheck_expr(symbol_gen, variants, &Default::default())?;
282 if res_variants.output_type().name() != I64Sort.name() {
283 return Err(TypeError::Mismatch {
284 expr: variants.clone(),
285 expected: I64Sort.to_arcsort(),
286 actual: res_variants.output_type(),
287 });
288 }
289
290 ResolvedNCommand::Extract(span.clone(), res_expr, res_variants)
291 }
292 NCommand::Check(span, facts) => ResolvedNCommand::Check(
293 span.clone(),
294 self.type_info.typecheck_facts(symbol_gen, facts)?,
295 ),
296 NCommand::Fail(span, cmd) => {
297 ResolvedNCommand::Fail(span.clone(), Box::new(self.typecheck_command(cmd)?))
298 }
299 NCommand::RunSchedule(schedule) => ResolvedNCommand::RunSchedule(
300 self.type_info.typecheck_schedule(symbol_gen, schedule)?,
301 ),
302 NCommand::Pop(span, n) => ResolvedNCommand::Pop(span.clone(), *n),
303 NCommand::Push(n) => ResolvedNCommand::Push(*n),
304 NCommand::AddRuleset(span, ruleset) => {
305 ResolvedNCommand::AddRuleset(span.clone(), ruleset.clone())
306 }
307 NCommand::UnstableCombinedRuleset(span, name, sub_rulesets) => {
308 ResolvedNCommand::UnstableCombinedRuleset(
309 span.clone(),
310 name.clone(),
311 sub_rulesets.clone(),
312 )
313 }
314 NCommand::PrintOverallStatistics(span, file) => {
315 ResolvedNCommand::PrintOverallStatistics(span.clone(), file.clone())
316 }
317 NCommand::PrintFunction(span, table, size, file, mode) => {
318 ResolvedNCommand::PrintFunction(
319 span.clone(),
320 table.clone(),
321 *size,
322 file.clone(),
323 *mode,
324 )
325 }
326 NCommand::PrintSize(span, n) => {
327 ResolvedNCommand::PrintSize(span.clone(), n.clone())
329 }
330 NCommand::ProveExists(span, constructor) => {
331 let func_type = self
332 .type_info
333 .get_func_type(constructor)
334 .ok_or_else(|| TypeError::UnboundFunction(constructor.clone(), span.clone()))?;
335 if func_type.subtype != FunctionSubtype::Constructor {
336 return Err(TypeError::ProveExistsRequiresConstructor(
337 constructor.clone(),
338 span.clone(),
339 ));
340 }
341 ResolvedNCommand::ProveExists(span.clone(), ResolvedCall::Func(func_type.clone()))
342 }
343 NCommand::Output { span, file, exprs } => {
344 let exprs = exprs
345 .iter()
346 .map(|expr| {
347 self.type_info
348 .typecheck_expr(symbol_gen, expr, &Default::default())
349 })
350 .collect::<Result<Vec<_>, _>>()?;
351 ResolvedNCommand::Output {
352 span: span.clone(),
353 file: file.clone(),
354 exprs,
355 }
356 }
357 NCommand::Input { span, name, file } => ResolvedNCommand::Input {
358 span: span.clone(),
359 name: name.clone(),
360 file: file.clone(),
361 },
362 NCommand::UserDefined(span, name, exprs) => {
363 ResolvedNCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
364 }
365 };
366 if let ResolvedNCommand::NormRule { rule } = &command {
367 self.warn_for_prefixed_non_globals_in_rule(rule)?;
368 }
369 Ok(command)
370 }
371
372 fn warn_for_prefixed_non_globals_in_var(
373 &mut self,
374 span: &Span,
375 var: &ResolvedVar,
376 ) -> Result<(), TypeError> {
377 if var.is_global_ref {
378 return Ok(());
379 }
380 if var.name.starts_with(crate::GLOBAL_NAME_PREFIX) {
381 self.warn_prefixed_non_globals(span, &var.name)?;
382 }
383 Ok(())
384 }
385
386 fn warn_for_prefixed_non_globals_in_rule(
387 &mut self,
388 rule: &ResolvedRule,
389 ) -> Result<(), TypeError> {
390 let mut res: Result<(), TypeError> = Ok(());
391
392 for fact in &rule.body {
393 fact.visit_vars(&mut |span, var| {
394 if res.is_ok() {
395 res = self.warn_for_prefixed_non_globals_in_var(span, var);
396 }
397 });
398 }
399
400 rule.head.visit_vars(&mut |span, var| {
401 if res.is_ok() {
402 res = self.warn_for_prefixed_non_globals_in_var(span, var);
403 }
404 });
405 res
406 }
407}
408
409impl TypeInfo {
410 pub fn add_presort<S: Presort>(&mut self, span: Span) -> Result<(), TypeError> {
412 let name = S::presort_name();
413 match self.mksorts.entry(name.to_owned()) {
414 HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
415 HEntry::Vacant(e) => {
416 e.insert(S::make_sort);
417 self.reserved_primitives.extend(S::reserved_primitives());
418 Ok(())
419 }
420 }
421 }
422
423 pub fn get_sorts_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
425 let mut results = Vec::new();
426 for sort in self.sorts.values() {
427 let sort = sort.clone().as_arc_any();
428 if let Ok(sort) = Arc::downcast(sort) {
429 if pred(&sort) {
430 results.push(sort);
431 }
432 }
433 }
434 results
435 }
436
437 pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
439 self.get_sorts_by(|_| true)
440 }
441
442 pub fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
444 let results = self.get_sorts_by(pred);
445 assert_eq!(
446 results.len(),
447 1,
448 "Expected exactly one sort for type {}",
449 std::any::type_name::<S>()
450 );
451 results.into_iter().next().unwrap()
452 }
453
454 pub fn get_sort<S: Sort>(&self) -> Arc<S> {
456 self.get_sort_by(|_| true)
457 }
458
459 pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
461 self.sorts.values().filter(|&x| f(x)).cloned().collect()
462 }
463
464 pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
466 let results = self.get_arcsorts_by(f);
467 assert_eq!(
468 results.len(),
469 1,
470 "Expected exactly one sort for type {}",
471 std::any::type_name::<S>()
472 );
473 results.into_iter().next().unwrap()
474 }
475
476 pub fn is_sort_unionable(&self, sort: &ArcSort) -> bool {
480 sort.is_eq_sort() && !self.non_unionable_sorts.contains(sort.name())
481 }
482
483 fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
484 let input = func
485 .schema
486 .input
487 .iter()
488 .map(|name| {
489 if let Some(sort) = self.sorts.get(name) {
490 Ok(sort.clone())
491 } else {
492 Err(TypeError::UndefinedSort(name.clone(), func.span.clone()))
493 }
494 })
495 .collect::<Result<Vec<_>, _>>()?;
496 let output = if let Some(sort) = self.sorts.get(&func.schema.output) {
497 Ok(sort.clone())
498 } else {
499 Err(TypeError::UndefinedSort(
500 func.schema.output.clone(),
501 func.span.clone(),
502 ))
503 }?;
504
505 Ok(FuncType {
506 name: func.name.clone(),
507 subtype: func.subtype,
508 input,
509 output: output.clone(),
510 })
511 }
512
513 fn typecheck_function(
514 &mut self,
515 symbol_gen: &mut SymbolGen,
516 fdecl: &FunctionDecl,
517 ) -> Result<ResolvedFunctionDecl, TypeError> {
518 if self.sorts.contains_key(&fdecl.name) {
519 return Err(TypeError::SortAlreadyBound(
520 fdecl.name.clone(),
521 fdecl.span.clone(),
522 ));
523 }
524 if self.is_primitive(&fdecl.name) {
525 return Err(TypeError::PrimitiveAlreadyBound(
526 fdecl.name.clone(),
527 fdecl.span.clone(),
528 ));
529 }
530 if fdecl.term_constructor.is_some() && fdecl.schema.input.is_empty() {
532 return Err(TypeError::TermConstructorNoInputs(
533 fdecl.name.clone(),
534 fdecl.span.clone(),
535 ));
536 }
537 let ftype = self.function_to_functype(fdecl)?;
538 if self.func_types.insert(fdecl.name.clone(), ftype).is_some() {
539 return Err(TypeError::FunctionAlreadyBound(
540 fdecl.name.clone(),
541 fdecl.span.clone(),
542 ));
543 }
544 let mut bound_vars = IndexMap::default();
545 let output_type = self.sorts.get(&fdecl.schema.output).unwrap();
546 if fdecl.subtype == FunctionSubtype::Constructor && !output_type.is_eq_sort() {
547 return Err(TypeError::ConstructorOutputNotSort(
548 fdecl.name.clone(),
549 fdecl.span.clone(),
550 ));
551 }
552 bound_vars.insert("old", (fdecl.span.clone(), output_type.clone()));
553 bound_vars.insert("new", (fdecl.span.clone(), output_type.clone()));
554
555 Ok(ResolvedFunctionDecl {
556 name: fdecl.name.clone(),
557 subtype: fdecl.subtype,
558 schema: fdecl.schema.clone(),
559 resolved_schema: ResolvedCall::Func(self.func_types.get(&fdecl.name).unwrap().clone()),
560 merge: match &fdecl.merge {
561 Some(merge) => Some(self.typecheck_expr(symbol_gen, merge, &bound_vars)?),
562 None => None,
563 },
564 cost: fdecl.cost,
565 unextractable: fdecl.unextractable,
566 internal_hidden: fdecl.internal_hidden,
567 internal_let: fdecl.internal_let,
568 span: fdecl.span.clone(),
569 term_constructor: fdecl.term_constructor.clone(),
570 })
571 }
572
573 fn typecheck_schedule(
574 &self,
575 symbol_gen: &mut SymbolGen,
576 schedule: &Schedule,
577 ) -> Result<ResolvedSchedule, TypeError> {
578 let schedule = match schedule {
579 Schedule::Repeat(span, times, schedule) => ResolvedSchedule::Repeat(
580 span.clone(),
581 *times,
582 Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
583 ),
584 Schedule::Sequence(span, schedules) => {
585 let schedules = schedules
586 .iter()
587 .map(|schedule| self.typecheck_schedule(symbol_gen, schedule))
588 .collect::<Result<Vec<_>, _>>()?;
589 ResolvedSchedule::Sequence(span.clone(), schedules)
590 }
591 Schedule::Saturate(span, schedule) => ResolvedSchedule::Saturate(
592 span.clone(),
593 Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
594 ),
595 Schedule::Run(span, RunConfig { ruleset, until }) => {
596 let until = until
597 .as_ref()
598 .map(|facts| self.typecheck_facts(symbol_gen, facts))
599 .transpose()?;
600 ResolvedSchedule::Run(
601 span.clone(),
602 ResolvedRunConfig {
603 ruleset: ruleset.clone(),
604 until,
605 },
606 )
607 }
608 };
609
610 Result::Ok(schedule)
611 }
612
613 fn typecheck_rule(
614 &self,
615 symbol_gen: &mut SymbolGen,
616 rule: &Rule,
617 ) -> Result<ResolvedRule, TypeError> {
618 let Rule {
619 span,
620 head,
621 body,
622 name,
623 ruleset,
624 } = rule;
625 let mut constraints = vec![];
626
627 let (query, mapped_query) = Facts(body.clone()).to_query(self, symbol_gen);
628 constraints.extend(query.get_constraints(self)?);
629
630 let mut binding = query.get_vars();
631 let mut ctx = CoreActionContext::new(self, &mut binding, symbol_gen, false);
634 let (actions, mapped_action) = head.to_core_actions(&mut ctx)?;
635
636 let mut problem = Problem::default();
637 problem.add_rule(
638 &CoreRule {
639 span: span.clone(),
640 body: query,
641 head: actions,
642 },
643 self,
644 symbol_gen,
645 )?;
646
647 let assignment = problem
648 .solve(|sort: &ArcSort| sort.name())
649 .map_err(|e| e.to_type_error())?;
650
651 let body: Vec<ResolvedFact> = assignment.annotate_facts(&mapped_query, self);
652 let actions: ResolvedActions = assignment.annotate_actions(&mapped_action, self)?;
653
654 self.check_lookup_actions(&actions)?;
655
656 Ok(ResolvedRule {
657 span: span.clone(),
658 body,
659 head: actions,
660 name: name.clone(),
661 ruleset: ruleset.clone(),
662 })
663 }
664
665 fn check_lookup_expr(&self, expr: &ResolvedExpr) -> Result<(), TypeError> {
666 if let Some(span) = self.expr_has_function_lookup(expr) {
667 return Err(TypeError::LookupInRuleDisallowed(
668 "function".to_string(),
669 span,
670 ));
671 }
672 Ok(())
673 }
674
675 fn check_lookup_actions(&self, actions: &ResolvedActions) -> Result<(), TypeError> {
676 for action in actions.iter() {
677 match action {
678 GenericAction::Let(_, _, rhs) => self.check_lookup_expr(rhs)?,
679 GenericAction::Set(_, _, args, rhs) => {
680 for arg in args.iter() {
681 self.check_lookup_expr(arg)?;
682 }
683 self.check_lookup_expr(rhs)?;
684 }
685 GenericAction::Union(_, lhs, rhs) => {
686 self.check_lookup_expr(lhs)?;
687 self.check_lookup_expr(rhs)?;
688 }
689 GenericAction::Change(_, _, _, args) => {
690 for arg in args.iter() {
691 self.check_lookup_expr(arg)?;
692 }
693 }
694 GenericAction::Panic(..) => {}
695 GenericAction::Expr(_, expr) => self.check_lookup_expr(expr)?,
696 }
697 }
698 Ok(())
699 }
700
701 pub fn typecheck_facts(
702 &self,
703 symbol_gen: &mut SymbolGen,
704 facts: &[Fact],
705 ) -> Result<Vec<ResolvedFact>, TypeError> {
706 let (query, mapped_facts) = Facts(facts.to_vec()).to_query(self, symbol_gen);
707 let mut problem = Problem::default();
708 problem.add_query(&query, self)?;
709 let assignment = problem
710 .solve(|sort: &ArcSort| sort.name())
711 .map_err(|e| e.to_type_error())?;
712 let annotated_facts = assignment.annotate_facts(&mapped_facts, self);
713 Ok(annotated_facts)
714 }
715
716 fn typecheck_actions(
717 &self,
718 symbol_gen: &mut SymbolGen,
719 actions: &Actions,
720 binding: &IndexMap<&str, (Span, ArcSort)>,
721 ) -> Result<ResolvedActions, TypeError> {
722 let mut binding_set: IndexSet<String> =
723 binding.keys().copied().map(str::to_string).collect();
724 let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
727 let (actions, mapped_action) = actions.to_core_actions(&mut ctx)?;
728 let mut problem = Problem::default();
729
730 problem.add_actions(&actions, self, symbol_gen)?;
732
733 for (var, (span, sort)) in binding {
735 problem.assign_local_var_type(var, span.clone(), sort.clone())?;
736 }
737
738 let assignment = problem
739 .solve(|sort: &ArcSort| sort.name())
740 .map_err(|e| e.to_type_error())?;
741
742 let annotated_actions = assignment.annotate_actions(&mapped_action, self)?;
743 Ok(annotated_actions)
744 }
745
746 fn typecheck_expr(
747 &self,
748 symbol_gen: &mut SymbolGen,
749 expr: &Expr,
750 binding: &IndexMap<&str, (Span, ArcSort)>,
751 ) -> Result<ResolvedExpr, TypeError> {
752 let action = Action::Expr(expr.span(), expr.clone());
753 let typechecked_action = self.typecheck_action(symbol_gen, &action, binding)?;
754 match typechecked_action {
755 ResolvedAction::Expr(_, expr) => Ok(expr),
756 _ => unreachable!(),
757 }
758 }
759
760 fn typecheck_action(
761 &self,
762 symbol_gen: &mut SymbolGen,
763 action: &Action,
764 binding: &IndexMap<&str, (Span, ArcSort)>,
765 ) -> Result<ResolvedAction, TypeError> {
766 self.typecheck_actions(symbol_gen, &Actions::singleton(action.clone()), binding)
767 .map(|v| {
768 assert_eq!(v.len(), 1);
769 v.0.into_iter().next().unwrap()
770 })
771 }
772
773 pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
774 self.sorts.get(sym)
775 }
776
777 pub fn get_prims(&self, sym: &str) -> Option<&[PrimitiveWithId]> {
778 self.primitives.get(sym).map(Vec::as_slice)
779 }
780
781 pub fn is_primitive(&self, sym: &str) -> bool {
782 self.primitives.contains_key(sym) || self.reserved_primitives.contains(sym)
783 }
784
785 pub fn primitive_has_validator(&self, id: ExternalFunctionId) -> bool {
786 self.primitives
787 .values()
788 .flat_map(|v| v.iter())
789 .any(|p| p.id == id && p.validator.is_some())
790 }
791
792 pub fn get_func_type(&self, sym: &str) -> Option<&FuncType> {
793 self.func_types.get(sym)
794 }
795
796 pub fn is_constructor(&self, sym: &str) -> bool {
797 self.func_types
798 .get(sym)
799 .is_some_and(|f| f.subtype == FunctionSubtype::Constructor)
800 }
801
802 pub fn get_global_sort(&self, sym: &str) -> Option<&ArcSort> {
803 self.global_sorts.get(sym)
804 }
805
806 pub fn is_global(&self, sym: &str) -> bool {
807 self.global_sorts.contains_key(sym)
808 }
809
810 pub fn expr_has_function_lookup(&self, expr: &ResolvedExpr) -> Option<Span> {
814 use ast::GenericExpr;
815
816 expr.find(&mut |e| {
817 if let GenericExpr::Call(span, ResolvedCall::Func(func_type), _) = e {
818 if func_type.subtype == FunctionSubtype::Custom {
819 if !self.is_global(&func_type.name) {
821 return Some(span.clone());
822 }
823 }
824 }
825 None
826 })
827 }
828}
829
/// Errors reported while typechecking commands, rules, actions and expressions.
#[derive(Debug, Clone, Error)]
pub enum TypeError {
    /// A call was given a different number of arguments than expected.
    #[error("{}\nArity mismatch, expected {expected} args: {expr}", .expr.span())]
    Arity { expr: Expr, expected: usize },
    /// An expression's sort differs from the expected sort.
    #[error(
        "{}\n Expect expression {expr} to have type {}, but get type {}",
        .expr.span(), .expected.name(), .actual.name(),
    )]
    Mismatch {
        expr: Expr,
        expected: ArcSort,
        actual: ArcSort,
    },
    /// A symbol is not bound.
    #[error("{1}\nUnbound symbol {0}")]
    Unbound(String, Span),
    /// A variable is ungrounded.
    #[error("{1}\nVariable {0} is ungrounded")]
    Ungrounded(String, Span),
    /// A sort name does not refer to a declared sort.
    #[error("{1}\nUndefined sort {0}")]
    UndefinedSort(String, Span),
    /// A sort definition was rejected; the second field carries the reason.
    #[error("{2}\nSort {0} definition is disallowed: {1}")]
    DisallowedSort(String, String, Span),
    /// A function name does not refer to a declared function.
    #[error("{1}\nUnbound function {0}")]
    UnboundFunction(String, Span),
    /// `prove-exists` was given a function that is not a constructor.
    #[error("{1}\nprove-exists requires constructor function, but {0} is not a constructor")]
    ProveExistsRequiresConstructor(String, Span),
    /// A function (or sort) name clashes with an existing function.
    #[error("{1}\nFunction already bound {0}")]
    FunctionAlreadyBound(String, Span),
    /// A sort, presort or function name clashes with an existing sort.
    #[error("{1}\nSort {0} already declared.")]
    SortAlreadyBound(String, Span),
    /// A function name clashes with a registered or reserved primitive.
    #[error("{1}\nPrimitive {0} already declared.")]
    PrimitiveAlreadyBound(String, Span),
    /// Fields: expected output, expected inputs, actual output, actual inputs.
    #[error("Function type mismatch: expected {} => {}, actual {} => {}", .1.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .0.name(), .3.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .2.name())]
    FunctionTypeMismatch(ArcSort, Vec<ArcSort>, ArcSort, Vec<ArcSort>),
    /// A sort declaration named a presort that has not been registered.
    #[error("{1}\nPresort {0} not found.")]
    PresortNotFound(String, Span),
    /// Type inference could not determine a sort for an expression.
    #[error("{}\nFailed to infer a type for: {}", .0.span(), .0)]
    InferenceFailure(Expr),
    /// A variable was defined more than once.
    #[error("{1}\nVariable {0} was already defined")]
    AlreadyDefined(String, Span),
    /// A constructor was declared with an output that is not an eq-sort.
    #[error("{1}\nThe output type of constructor function {0} must be sort")]
    ConstructorOutputNotSort(String, Span),
    /// A rule looks up the value of a non-constructor function.
    #[error("{1}\nValue lookup of non-constructor function {0} in rule is disallowed.")]
    LookupInRuleDisallowed(String, Span),
    /// A `set` targets a constructor.
    #[error("{1}\nCannot set constructor {0}. Use `union` instead or declare {0} as a function.")]
    SetConstructorDisallowed(String, Span),
    /// Every alternative typing attempt failed; each underlying error is listed.
    #[error("All alternative definitions considered failed\n{}", .0.iter().map(|e| format!(" {e}\n")).collect::<Vec<_>>().join(""))]
    AllAlternativeFailed(Vec<TypeError>),
    /// A `union` was applied to values of a sort that is not an eq-sort.
    #[error("{}\nCannot union values of sort {}", .1, .0.name())]
    NonEqsortUnion(ArcSort, Span),
    /// A `union` was applied to values of a sort declared non-unionable.
    #[error("{}\nCannot union values of sort {} because it is marked as non-unionable (e.g. from a relation)", .1, .0.name())]
    NonUnionableSort(ArcSort, Span),
    /// A table with `:term-constructor` was declared without inputs.
    #[error(
        "{1}\nView table {0} with :term-constructor must have at least one input (the e-class)."
    )]
    TermConstructorNoInputs(String, Span),
    /// A non-global variable uses the global-name prefix.
    #[error(
        "{span}\nNon-global variable `{name}` must not start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    NonGlobalPrefixed { name: String, span: Span },
    /// A global binding lacks the global-name prefix.
    #[error(
        "{span}\nGlobal `{name}` must start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    GlobalMissingPrefix { name: String, span: Span },
}
896
#[cfg(test)]
mod test {
    use crate::{EGraph, Error, typechecking::TypeError};

    /// Calling a two-column relation with three arguments must report an arity error
    /// whose expression span covers exactly the offending call.
    #[test]
    fn test_arity_mismatch() {
        let prog = "
        (relation f (i64 i64))
        (rule ((f a b c)) ())
        ";
        let mut egraph = EGraph::default();
        let res = egraph.parse_and_run_program(None, prog);
        if let Err(Error::TypeError(TypeError::Arity {
            expected: 2,
            expr: e,
        })) = res
        {
            assert_eq!(e.span().string(), "(f a b c)");
        } else {
            panic!("Expected arity mismatch, got: {res:?}");
        }
    }
}