egglog/
typechecking.rs

1use crate::{
2    core::{CoreRule, GenericActionsExt},
3    *,
4};
5use ast::{ResolvedAction, ResolvedExpr, ResolvedFact, ResolvedRule, ResolvedVar, Rule};
6use core_relations::ExternalFunction;
7use egglog_ast::generic_ast::GenericAction;
8
/// The resolved signature of a declared function: its name, its subtype,
/// and the sorts of its arguments and of its output.
#[derive(Clone, Debug)]
pub struct FuncType {
    /// Name the function was declared under.
    pub name: String,
    /// Subtype copied from the function declaration.
    pub subtype: FunctionSubtype,
    /// Sorts of the arguments, in declaration order.
    pub input: Vec<ArcSort>,
    /// Sort of the function's output.
    pub output: ArcSort,
}
16
/// A primitive (field `0`) paired with the id of the external function
/// that was registered with the backend for it (field `1`).
#[derive(Clone)]
pub struct PrimitiveWithId(pub Arc<dyn Primitive + Send + Sync>, pub ExternalFunctionId);
19
20impl PrimitiveWithId {
21    /// Takes the full signature of a primitive (both input and output types).
22    /// Returns whether the primitive is compatible with this signature.
23    pub fn accept(&self, tys: &[Arc<dyn Sort>], typeinfo: &TypeInfo) -> bool {
24        let mut constraints = vec![];
25        let lits: Vec<_> = (0..tys.len())
26            .map(|i| AtomTerm::Literal(Span::Panic, Literal::Int(i as i64)))
27            .collect();
28        for (lit, ty) in lits.iter().zip(tys.iter()) {
29            constraints.push(constraint::assign(lit.clone(), ty.clone()))
30        }
31        constraints.extend(
32            self.0
33                .get_type_constraints(&Span::Panic)
34                .get(&lits, typeinfo),
35        );
36        let problem = Problem {
37            constraints,
38            range: HashSet::default(),
39        };
40        problem.solve(|sort| sort.name()).is_ok()
41    }
42}
43
44impl Debug for PrimitiveWithId {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "Prim({})", self.0.name())
47    }
48}
49
/// Stores resolved typechecking information.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// Sort constructors, keyed by presort name (filled by `add_presort`).
    mksorts: HashMap<String, MkSort>,
    /// Primitive names reserved by registered presorts. `is_primitive`
    /// reports these names as primitives in addition to the registered ones.
    // TODO(yz): I want to get rid of this as now we have user-defined primitives and constraint based type checking
    reserved_primitives: HashSet<&'static str>,
    /// Declared sorts, keyed by sort name.
    sorts: HashMap<String, Arc<dyn Sort>>,
    /// Registered primitives, keyed by name; each name may have several
    /// entries, one per `add_primitive` call made with that name.
    primitives: HashMap<String, Vec<PrimitiveWithId>>,
    /// Signatures of declared functions, keyed by function name.
    func_types: HashMap<String, FuncType>,
    /// Sorts of top-level `let` bindings, keyed by the bound name.
    global_sorts: HashMap<String, ArcSort>,
}
61
62// These methods need to be on the `EGraph` in order to
63// register sorts and primitives with the backend.
64impl EGraph {
65    /// Add a user-defined sort to the e-graph.
66    ///
67    /// Also look at [`prelude::add_base_sort`] for a convenience method for adding user-defined sorts
68    pub fn add_sort<S: Sort + 'static>(&mut self, sort: S, span: Span) -> Result<(), TypeError> {
69        self.add_arcsort(Arc::new(sort), span)
70    }
71
72    /// Declare a sort. This corresponds to the `sort` keyword in egglog.
73    /// It can either declares a new [`EqSort`] if `presort_and_args` is not provided,
74    /// or an instantiation of a presort (e.g., containers like `Vec`).
75    pub fn declare_sort(
76        &mut self,
77        name: impl Into<String>,
78        presort_and_args: &Option<(String, Vec<Expr>)>,
79        span: Span,
80    ) -> Result<(), TypeError> {
81        let name = name.into();
82        if self.type_info.func_types.contains_key(&name) {
83            return Err(TypeError::FunctionAlreadyBound(name, span));
84        }
85
86        let sort = match presort_and_args {
87            None => Arc::new(EqSort { name }),
88            Some((presort, args)) => {
89                if let Some(mksort) = self.type_info.mksorts.get(presort) {
90                    mksort(&mut self.type_info, name, args)?
91                } else {
92                    return Err(TypeError::PresortNotFound(presort.clone(), span));
93                }
94            }
95        };
96
97        self.add_arcsort(sort, span)
98    }
99
100    /// Add a user-defined sort to the e-graph.
101    pub fn add_arcsort(&mut self, sort: ArcSort, span: Span) -> Result<(), TypeError> {
102        sort.register_type(&mut self.backend);
103
104        let name = sort.name();
105        match self.type_info.sorts.entry(name.to_owned()) {
106            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
107            HEntry::Vacant(e) => {
108                e.insert(sort.clone());
109                sort.register_primitives(self);
110                Ok(())
111            }
112        }
113    }
114
115    /// Add a user-defined primitive
116    pub fn add_primitive<T>(&mut self, x: T)
117    where
118        T: Clone + Primitive + Send + Sync + 'static,
119    {
120        // We need to use a wrapper because of the orphan rule.
121        // If we just try to implement `ExternalFunction` directly on
122        // all `PrimitiveLike`s then it would be possible for a
123        // downstream crate to create a conflict.
124        #[derive(Clone)]
125        struct Wrapper<T>(T);
126        impl<T: Clone + Primitive + Send + Sync> ExternalFunction for Wrapper<T> {
127            fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
128                self.0.apply(exec_state, args)
129            }
130        }
131
132        let prim = Arc::new(x.clone());
133        let ext = self.backend.register_external_func(Wrapper(x));
134        self.type_info
135            .primitives
136            .entry(prim.name().to_owned())
137            .or_default()
138            .push(PrimitiveWithId(prim, ext));
139    }
140
141    pub(crate) fn typecheck_program(
142        &mut self,
143        program: &Vec<NCommand>,
144    ) -> Result<Vec<ResolvedNCommand>, TypeError> {
145        let mut result = vec![];
146        for command in program {
147            result.push(self.typecheck_command(command)?);
148        }
149        Ok(result)
150    }
151
    /// Typechecks a single command and returns its resolved form.
    ///
    /// Commands carrying functions, rules, actions, expressions, facts or
    /// schedules are resolved through `self.type_info`; the remaining
    /// commands are copied over field by field. Two kinds of command also
    /// update the type information: `Sort` declares the sort, and a top-level
    /// `let` records the sort of the bound name in the global sorts.
    ///
    /// A resolved rule is additionally passed through
    /// `warn_for_prefixed_non_globals_in_rule` before being returned.
    ///
    /// # Errors
    ///
    /// Returns the first [`TypeError`] raised by any of the steps above,
    /// including [`TypeError::Mismatch`] when the variants expression of an
    /// `Extract` command does not have the `i64` sort.
    fn typecheck_command(&mut self, command: &NCommand) -> Result<ResolvedNCommand, TypeError> {
        let symbol_gen = &mut self.parser.symbol_gen;

        let command: ResolvedNCommand = match command {
            NCommand::Function(fdecl) => {
                ResolvedNCommand::Function(self.type_info.typecheck_function(symbol_gen, fdecl)?)
            }
            NCommand::NormRule { rule } => ResolvedNCommand::NormRule {
                rule: self.type_info.typecheck_rule(symbol_gen, rule)?,
            },
            NCommand::Sort(span, sort, presort_and_args) => {
                // Note this is bad since typechecking should be pure and idempotent
                // Otherwise typechecking the same program twice will fail
                self.declare_sort(sort.clone(), presort_and_args, span.clone())?;
                ResolvedNCommand::Sort(span.clone(), sort.clone(), presort_and_args.clone())
            }
            // A top-level `let`: typecheck the bound expression with no local
            // bindings, then record the sort of the new global.
            NCommand::CoreAction(Action::Let(span, var, expr)) => {
                let expr = self
                    .type_info
                    .typecheck_expr(symbol_gen, expr, &Default::default())?;
                let output_type = expr.output_type();
                self.ensure_global_name_prefix(span, var)?;
                self.type_info
                    .global_sorts
                    .insert(var.clone(), output_type.clone());
                let var = ResolvedVar {
                    name: var.clone(),
                    sort: output_type,
                    // not a global reference, but a global binding
                    is_global_ref: false,
                };
                ResolvedNCommand::CoreAction(ResolvedAction::Let(span.clone(), var, expr))
            }
            // Every other top-level action is typechecked with no local bindings.
            NCommand::CoreAction(action) => ResolvedNCommand::CoreAction(
                self.type_info
                    .typecheck_action(symbol_gen, action, &Default::default())?,
            ),
            NCommand::Extract(span, expr, variants) => {
                let res_expr =
                    self.type_info
                        .typecheck_expr(symbol_gen, expr, &Default::default())?;

                let res_variants =
                    self.type_info
                        .typecheck_expr(symbol_gen, variants, &Default::default())?;
                // The variants expression must have the `i64` sort (compared by name).
                if res_variants.output_type().name() != I64Sort.name() {
                    return Err(TypeError::Mismatch {
                        expr: variants.clone(),
                        expected: I64Sort.to_arcsort(),
                        actual: res_variants.output_type(),
                    });
                }

                ResolvedNCommand::Extract(span.clone(), res_expr, res_variants)
            }
            NCommand::Check(span, facts) => ResolvedNCommand::Check(
                span.clone(),
                self.type_info.typecheck_facts(symbol_gen, facts)?,
            ),
            // The wrapped command is typechecked recursively; an error from
            // it is propagated like any other.
            NCommand::Fail(span, cmd) => {
                ResolvedNCommand::Fail(span.clone(), Box::new(self.typecheck_command(cmd)?))
            }
            NCommand::RunSchedule(schedule) => ResolvedNCommand::RunSchedule(
                self.type_info.typecheck_schedule(symbol_gen, schedule)?,
            ),
            // The commands below carry nothing to resolve and are copied as-is.
            NCommand::Pop(span, n) => ResolvedNCommand::Pop(span.clone(), *n),
            NCommand::Push(n) => ResolvedNCommand::Push(*n),
            NCommand::AddRuleset(span, ruleset) => {
                ResolvedNCommand::AddRuleset(span.clone(), ruleset.clone())
            }
            NCommand::UnstableCombinedRuleset(span, name, sub_rulesets) => {
                ResolvedNCommand::UnstableCombinedRuleset(
                    span.clone(),
                    name.clone(),
                    sub_rulesets.clone(),
                )
            }
            NCommand::PrintOverallStatistics(span, file) => {
                ResolvedNCommand::PrintOverallStatistics(span.clone(), file.clone())
            }
            NCommand::PrintFunction(span, table, size, file, mode) => {
                ResolvedNCommand::PrintFunction(
                    span.clone(),
                    table.clone(),
                    *size,
                    file.clone(),
                    *mode,
                )
            }
            NCommand::PrintSize(span, n) => {
                // Should probably also resolve the function symbol here
                ResolvedNCommand::PrintSize(span.clone(), n.clone())
            }
            NCommand::Output { span, file, exprs } => {
                let exprs = exprs
                    .iter()
                    .map(|expr| {
                        self.type_info
                            .typecheck_expr(symbol_gen, expr, &Default::default())
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                ResolvedNCommand::Output {
                    span: span.clone(),
                    file: file.clone(),
                    exprs,
                }
            }
            NCommand::Input { span, name, file } => ResolvedNCommand::Input {
                span: span.clone(),
                name: name.clone(),
                file: file.clone(),
            },
            // The expressions of a user-defined command are passed through unresolved.
            NCommand::UserDefined(span, name, exprs) => {
                ResolvedNCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
            }
        };
        // Rules get one extra pass over their variables after resolution.
        if let ResolvedNCommand::NormRule { rule } = &command {
            self.warn_for_prefixed_non_globals_in_rule(rule)?;
        }
        Ok(command)
    }
273
274    fn warn_for_prefixed_non_globals_in_var(
275        &mut self,
276        span: &Span,
277        var: &ResolvedVar,
278    ) -> Result<(), TypeError> {
279        if var.is_global_ref {
280            return Ok(());
281        }
282        if let Some(stripped) = var.name.strip_prefix(crate::GLOBAL_NAME_PREFIX) {
283            self.warn_missing_global_prefix(span, stripped)?;
284        }
285        Ok(())
286    }
287
288    fn warn_for_prefixed_non_globals_in_rule(
289        &mut self,
290        rule: &ResolvedRule,
291    ) -> Result<(), TypeError> {
292        let mut res: Result<(), TypeError> = Ok(());
293
294        for fact in &rule.body {
295            fact.visit_vars(&mut |span, var| {
296                if res.is_ok() {
297                    res = self.warn_for_prefixed_non_globals_in_var(span, var);
298                }
299            });
300        }
301
302        rule.head.visit_vars(&mut |span, var| {
303            if res.is_ok() {
304                res = self.warn_for_prefixed_non_globals_in_var(span, var);
305            }
306        });
307        res
308    }
309}
310
311impl TypeInfo {
312    /// Adds a sort constructor to the typechecker's known set of types.
313    pub fn add_presort<S: Presort>(&mut self, span: Span) -> Result<(), TypeError> {
314        let name = S::presort_name();
315        match self.mksorts.entry(name.to_owned()) {
316            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
317            HEntry::Vacant(e) => {
318                e.insert(S::make_sort);
319                self.reserved_primitives.extend(S::reserved_primitives());
320                Ok(())
321            }
322        }
323    }
324
325    /// Returns all sorts that satisfy the type and predicate.
326    pub fn get_sorts_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
327        let mut results = Vec::new();
328        for sort in self.sorts.values() {
329            let sort = sort.clone().as_arc_any();
330            if let Ok(sort) = Arc::downcast(sort) {
331                if pred(&sort) {
332                    results.push(sort);
333                }
334            }
335        }
336        results
337    }
338
339    /// Returns all sorts based on the type.
340    pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
341        self.get_sorts_by(|_| true)
342    }
343
344    /// Returns a sort that satisfies the type and predicate.
345    pub fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
346        let results = self.get_sorts_by(pred);
347        assert_eq!(
348            results.len(),
349            1,
350            "Expected exactly one sort for type {}",
351            std::any::type_name::<S>()
352        );
353        results.into_iter().next().unwrap()
354    }
355
356    /// Returns a sort based on the type.
357    pub fn get_sort<S: Sort>(&self) -> Arc<S> {
358        self.get_sort_by(|_| true)
359    }
360
361    /// Returns all sorts that satisfy the predicate.
362    pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
363        self.sorts.values().filter(|&x| f(x)).cloned().collect()
364    }
365
366    /// Returns a sort based on the predicate.
367    pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
368        let results = self.get_arcsorts_by(f);
369        assert_eq!(
370            results.len(),
371            1,
372            "Expected exactly one sort for type {}",
373            std::any::type_name::<S>()
374        );
375        results.into_iter().next().unwrap()
376    }
377
378    fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
379        let input = func
380            .schema
381            .input
382            .iter()
383            .map(|name| {
384                if let Some(sort) = self.sorts.get(name) {
385                    Ok(sort.clone())
386                } else {
387                    Err(TypeError::UndefinedSort(name.clone(), func.span.clone()))
388                }
389            })
390            .collect::<Result<Vec<_>, _>>()?;
391        let output = if let Some(sort) = self.sorts.get(&func.schema.output) {
392            Ok(sort.clone())
393        } else {
394            Err(TypeError::UndefinedSort(
395                func.schema.output.clone(),
396                func.span.clone(),
397            ))
398        }?;
399
400        Ok(FuncType {
401            name: func.name.clone(),
402            subtype: func.subtype,
403            input,
404            output: output.clone(),
405        })
406    }
407
408    fn typecheck_function(
409        &mut self,
410        symbol_gen: &mut SymbolGen,
411        fdecl: &FunctionDecl,
412    ) -> Result<ResolvedFunctionDecl, TypeError> {
413        if self.sorts.contains_key(&fdecl.name) {
414            return Err(TypeError::SortAlreadyBound(
415                fdecl.name.clone(),
416                fdecl.span.clone(),
417            ));
418        }
419        if self.is_primitive(&fdecl.name) {
420            return Err(TypeError::PrimitiveAlreadyBound(
421                fdecl.name.clone(),
422                fdecl.span.clone(),
423            ));
424        }
425        let ftype = self.function_to_functype(fdecl)?;
426        if self.func_types.insert(fdecl.name.clone(), ftype).is_some() {
427            return Err(TypeError::FunctionAlreadyBound(
428                fdecl.name.clone(),
429                fdecl.span.clone(),
430            ));
431        }
432        let mut bound_vars = IndexMap::default();
433        let output_type = self.sorts.get(&fdecl.schema.output).unwrap();
434        if fdecl.subtype == FunctionSubtype::Constructor && !output_type.is_eq_sort() {
435            return Err(TypeError::ConstructorOutputNotSort(
436                fdecl.name.clone(),
437                fdecl.span.clone(),
438            ));
439        }
440        bound_vars.insert("old", (fdecl.span.clone(), output_type.clone()));
441        bound_vars.insert("new", (fdecl.span.clone(), output_type.clone()));
442
443        Ok(ResolvedFunctionDecl {
444            name: fdecl.name.clone(),
445            subtype: fdecl.subtype,
446            schema: fdecl.schema.clone(),
447            merge: match &fdecl.merge {
448                Some(merge) => Some(self.typecheck_expr(symbol_gen, merge, &bound_vars)?),
449                None => None,
450            },
451            cost: fdecl.cost,
452            unextractable: fdecl.unextractable,
453            let_binding: fdecl.let_binding,
454            span: fdecl.span.clone(),
455        })
456    }
457
458    fn typecheck_schedule(
459        &self,
460        symbol_gen: &mut SymbolGen,
461        schedule: &Schedule,
462    ) -> Result<ResolvedSchedule, TypeError> {
463        let schedule = match schedule {
464            Schedule::Repeat(span, times, schedule) => ResolvedSchedule::Repeat(
465                span.clone(),
466                *times,
467                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
468            ),
469            Schedule::Sequence(span, schedules) => {
470                let schedules = schedules
471                    .iter()
472                    .map(|schedule| self.typecheck_schedule(symbol_gen, schedule))
473                    .collect::<Result<Vec<_>, _>>()?;
474                ResolvedSchedule::Sequence(span.clone(), schedules)
475            }
476            Schedule::Saturate(span, schedule) => ResolvedSchedule::Saturate(
477                span.clone(),
478                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
479            ),
480            Schedule::Run(span, RunConfig { ruleset, until }) => {
481                let until = until
482                    .as_ref()
483                    .map(|facts| self.typecheck_facts(symbol_gen, facts))
484                    .transpose()?;
485                ResolvedSchedule::Run(
486                    span.clone(),
487                    ResolvedRunConfig {
488                        ruleset: ruleset.clone(),
489                        until,
490                    },
491                )
492            }
493        };
494
495        Result::Ok(schedule)
496    }
497
498    fn typecheck_rule(
499        &self,
500        symbol_gen: &mut SymbolGen,
501        rule: &Rule,
502    ) -> Result<ResolvedRule, TypeError> {
503        let Rule {
504            span,
505            head,
506            body,
507            name,
508            ruleset,
509        } = rule;
510        let mut constraints = vec![];
511
512        let (query, mapped_query) = Facts(body.clone()).to_query(self, symbol_gen);
513        constraints.extend(query.get_constraints(self)?);
514
515        let mut binding = query.get_vars();
516        let (actions, mapped_action) = head.to_core_actions(self, &mut binding, symbol_gen)?;
517
518        let mut problem = Problem::default();
519        problem.add_rule(
520            &CoreRule {
521                span: span.clone(),
522                body: query,
523                head: actions,
524            },
525            self,
526            symbol_gen,
527        )?;
528
529        let assignment = problem
530            .solve(|sort: &ArcSort| sort.name())
531            .map_err(|e| e.to_type_error())?;
532
533        let body: Vec<ResolvedFact> = assignment.annotate_facts(&mapped_query, self);
534        let actions: ResolvedActions = assignment.annotate_actions(&mapped_action, self)?;
535
536        Self::check_lookup_actions(&actions)?;
537
538        Ok(ResolvedRule {
539            span: span.clone(),
540            body,
541            head: actions,
542            name: name.clone(),
543            ruleset: ruleset.clone(),
544        })
545    }
546
547    fn check_lookup_expr(expr: &GenericExpr<ResolvedCall, ResolvedVar>) -> Result<(), TypeError> {
548        match expr {
549            GenericExpr::Call(span, head, args) => {
550                match head {
551                    ResolvedCall::Func(t) => {
552                        // Only allowed to lookup constructor or relation
553                        if t.subtype != FunctionSubtype::Constructor
554                            && t.subtype != FunctionSubtype::Relation
555                        {
556                            Err(TypeError::LookupInRuleDisallowed(
557                                head.to_string(),
558                                span.clone(),
559                            ))
560                        } else {
561                            Ok(())
562                        }
563                    }
564                    ResolvedCall::Primitive(_) => Ok(()),
565                }?;
566                for arg in args.iter() {
567                    Self::check_lookup_expr(arg)?
568                }
569                Ok(())
570            }
571            _ => Ok(()),
572        }
573    }
574
575    fn check_lookup_actions(actions: &ResolvedActions) -> Result<(), TypeError> {
576        for action in actions.iter() {
577            match action {
578                GenericAction::Let(_, _, rhs) => Self::check_lookup_expr(rhs),
579                GenericAction::Set(_, _, args, rhs) => {
580                    for arg in args.iter() {
581                        Self::check_lookup_expr(arg)?
582                    }
583                    Self::check_lookup_expr(rhs)
584                }
585                GenericAction::Union(_, lhs, rhs) => {
586                    Self::check_lookup_expr(lhs)?;
587                    Self::check_lookup_expr(rhs)
588                }
589                GenericAction::Change(_, _, _, args) => {
590                    for arg in args.iter() {
591                        Self::check_lookup_expr(arg)?
592                    }
593                    Ok(())
594                }
595                GenericAction::Panic(..) => Ok(()),
596                GenericAction::Expr(_, expr) => Self::check_lookup_expr(expr),
597            }?
598        }
599        Ok(())
600    }
601
602    fn typecheck_facts(
603        &self,
604        symbol_gen: &mut SymbolGen,
605        facts: &[Fact],
606    ) -> Result<Vec<ResolvedFact>, TypeError> {
607        let (query, mapped_facts) = Facts(facts.to_vec()).to_query(self, symbol_gen);
608        let mut problem = Problem::default();
609        problem.add_query(&query, self)?;
610        let assignment = problem
611            .solve(|sort: &ArcSort| sort.name())
612            .map_err(|e| e.to_type_error())?;
613        let annotated_facts = assignment.annotate_facts(&mapped_facts, self);
614        Ok(annotated_facts)
615    }
616
617    fn typecheck_actions(
618        &self,
619        symbol_gen: &mut SymbolGen,
620        actions: &Actions,
621        binding: &IndexMap<&str, (Span, ArcSort)>,
622    ) -> Result<ResolvedActions, TypeError> {
623        let mut binding_set: IndexSet<String> =
624            binding.keys().copied().map(str::to_string).collect();
625        let (actions, mapped_action) =
626            actions.to_core_actions(self, &mut binding_set, symbol_gen)?;
627        let mut problem = Problem::default();
628
629        // add actions to problem
630        problem.add_actions(&actions, self, symbol_gen)?;
631
632        // add bindings from the context
633        for (var, (span, sort)) in binding {
634            problem.assign_local_var_type(var, span.clone(), sort.clone())?;
635        }
636
637        let assignment = problem
638            .solve(|sort: &ArcSort| sort.name())
639            .map_err(|e| e.to_type_error())?;
640
641        let annotated_actions = assignment.annotate_actions(&mapped_action, self)?;
642        Ok(annotated_actions)
643    }
644
645    fn typecheck_expr(
646        &self,
647        symbol_gen: &mut SymbolGen,
648        expr: &Expr,
649        binding: &IndexMap<&str, (Span, ArcSort)>,
650    ) -> Result<ResolvedExpr, TypeError> {
651        let action = Action::Expr(expr.span(), expr.clone());
652        let typechecked_action = self.typecheck_action(symbol_gen, &action, binding)?;
653        match typechecked_action {
654            ResolvedAction::Expr(_, expr) => Ok(expr),
655            _ => unreachable!(),
656        }
657    }
658
659    fn typecheck_action(
660        &self,
661        symbol_gen: &mut SymbolGen,
662        action: &Action,
663        binding: &IndexMap<&str, (Span, ArcSort)>,
664    ) -> Result<ResolvedAction, TypeError> {
665        self.typecheck_actions(symbol_gen, &Actions::singleton(action.clone()), binding)
666            .map(|v| {
667                assert_eq!(v.len(), 1);
668                v.0.into_iter().next().unwrap()
669            })
670    }
671
    /// Looks up a declared sort by name; `None` if no such sort exists.
    pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
        self.sorts.get(sym)
    }
675
676    pub fn get_prims(&self, sym: &str) -> Option<&[PrimitiveWithId]> {
677        self.primitives.get(sym).map(Vec::as_slice)
678    }
679
680    pub fn is_primitive(&self, sym: &str) -> bool {
681        self.primitives.contains_key(sym) || self.reserved_primitives.contains(sym)
682    }
683
    /// Looks up the signature of a declared function by name; `None` if no
    /// such function exists.
    pub fn get_func_type(&self, sym: &str) -> Option<&FuncType> {
        self.func_types.get(sym)
    }
687
688    pub(crate) fn is_constructor(&self, sym: &str) -> bool {
689        self.func_types
690            .get(sym)
691            .is_some_and(|f| f.subtype == FunctionSubtype::Constructor)
692    }
693
    /// Looks up the sort recorded for the global named `sym`; `None` if no
    /// global with that name has been recorded.
    pub fn get_global_sort(&self, sym: &str) -> Option<&ArcSort> {
        self.global_sorts.get(sym)
    }
697
698    pub fn is_global(&self, sym: &str) -> bool {
699        self.global_sorts.contains_key(sym)
700    }
701}
702
/// Errors reported by typechecking.
///
/// Most variants carry the [`Span`] (or an expression with a span) that the
/// rendered message starts with.
#[derive(Debug, Clone, Error)]
pub enum TypeError {
    /// `expr` was given a number of arguments different from `expected`.
    #[error("{}\nArity mismatch, expected {expected} args: {expr}", .expr.span())]
    Arity { expr: Expr, expected: usize },
    /// `expr` has sort `actual` where sort `expected` is required.
    #[error(
        "{}\n Expect expression {expr} to have type {}, but get type {}",
        .expr.span(), .expected.name(), .actual.name(),
    )]
    Mismatch {
        expr: Expr,
        expected: ArcSort,
        actual: ArcSort,
    },
    /// The named symbol is not bound.
    #[error("{1}\nUnbound symbol {0}")]
    Unbound(String, Span),
    /// The named variable is ungrounded.
    #[error("{1}\nVariable {0} is ungrounded")]
    Ungrounded(String, Span),
    /// The named sort has not been declared.
    #[error("{1}\nUndefined sort {0}")]
    UndefinedSort(String, Span),
    /// Defining the sort named by field `0` is disallowed; field `1` is the
    /// reason shown in the message.
    #[error("{2}\nSort {0} definition is disallowed: {1}")]
    DisallowedSort(String, String, Span),
    /// The named function is not bound.
    #[error("{1}\nUnbound function {0}")]
    UnboundFunction(String, Span),
    /// A function with this name is already declared.
    #[error("{1}\nFunction already bound {0}")]
    FunctionAlreadyBound(String, Span),
    /// A sort (or presort) with this name is already declared.
    #[error("{1}\nSort {0} already declared.")]
    SortAlreadyBound(String, Span),
    /// A primitive with this name already exists.
    #[error("{1}\nPrimitive {0} already declared.")]
    PrimitiveAlreadyBound(String, Span),
    /// A function type did not match. The fields are, in order: expected
    /// output, expected inputs, actual output, actual inputs.
    #[error("Function type mismatch: expected {} => {}, actual {} => {}", .1.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .0.name(), .3.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .2.name())]
    FunctionTypeMismatch(ArcSort, Vec<ArcSort>, ArcSort, Vec<ArcSort>),
    /// The named presort is not registered.
    #[error("{1}\nPresort {0} not found.")]
    PresortNotFound(String, Span),
    /// No type could be inferred for the expression.
    #[error("{}\nFailed to infer a type for: {}", .0.span(), .0)]
    InferenceFailure(Expr),
    /// The named variable was already defined.
    #[error("{1}\nVariable {0} was already defined")]
    AlreadyDefined(String, Span),
    /// The named constructor was declared with an output that is not an
    /// eq-sort.
    #[error("{1}\nThe output type of constructor function {0} must be sort")]
    ConstructorOutputNotSort(String, Span),
    /// A rule looks up the value of a function that is neither a constructor
    /// nor a relation.
    #[error("{1}\nValue lookup of non-constructor function {0} in rule is disallowed.")]
    LookupInRuleDisallowed(String, Span),
    /// Every alternative that was considered failed; the individual errors
    /// are listed in the message.
    #[error("All alternative definitions considered failed\n{}", .0.iter().map(|e| format!("  {e}\n")).collect::<Vec<_>>().join(""))]
    AllAlternativeFailed(Vec<TypeError>),
    /// Values of the given sort cannot be unioned.
    #[error("{}\nCannot union values of sort {}", .1, .0.name())]
    NonEqsortUnion(ArcSort, Span),
    /// A non-global variable starts with [`crate::GLOBAL_NAME_PREFIX`].
    #[error(
        "{span}\nNon-global variable `{name}` must not start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    NonGlobalPrefixed { name: String, span: Span },
    /// A global does not start with [`crate::GLOBAL_NAME_PREFIX`].
    #[error(
        "{span}\nGlobal `{name}` must start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    GlobalMissingPrefix { name: String, span: Span },
}
759
#[cfg(test)]
mod test {
    use crate::{EGraph, Error, typechecking::TypeError};

    /// Using a two-column relation with three arguments must produce an
    /// arity error whose expression span is exactly the offending call.
    #[test]
    fn test_arity_mismatch() {
        let prog = "
            (relation f (i64 i64))
            (rule ((f a b c)) ())
       ";

        let mut egraph = EGraph::default();
        let outcome = egraph.parse_and_run_program(None, prog);

        match outcome {
            Err(Error::TypeError(TypeError::Arity { expr, expected: 2 })) => {
                assert_eq!(expr.span().string(), "(f a b c)");
            }
            _ => panic!("Expected arity mismatch, got: {:?}", outcome),
        }
    }
}
783}