egglog/
typechecking.rs

1use std::hash::Hasher;
2
3use crate::{
4    core::{CoreActionContext, CoreRule, GenericActionsExt, ResolvedCall},
5    *,
6};
7use ast::{ResolvedAction, ResolvedExpr, ResolvedFact, ResolvedRule, ResolvedVar, Rule};
8use core_relations::ExternalFunction;
9use egglog_ast::generic_ast::GenericAction;
10
/// The resolved signature of a declared function.
///
/// `PartialEq`/`Eq`/`Hash` are implemented manually for this type and compare
/// the input and output sorts by sort name rather than by `Arc` identity.
#[derive(Clone, Debug)]
pub struct FuncType {
    /// The function's declared name.
    pub name: String,
    /// The kind of function (e.g. `FunctionSubtype::Constructor` or `FunctionSubtype::Custom`).
    pub subtype: FunctionSubtype,
    /// The sorts of the function's arguments, in declaration order.
    pub input: Vec<ArcSort>,
    /// The sort of the function's result.
    pub output: ArcSort,
}
18
19impl PartialEq for FuncType {
20    fn eq(&self, other: &Self) -> bool {
21        if self.name == other.name
22            && self.subtype == other.subtype
23            && self.output.name() == other.output.name()
24        {
25            if self.input.len() != other.input.len() {
26                return false;
27            }
28            for (a, b) in self.input.iter().zip(other.input.iter()) {
29                if a.name() != b.name() {
30                    return false;
31                }
32            }
33            true
34        } else {
35            false
36        }
37    }
38}
39
40impl Eq for FuncType {}
41
42impl Hash for FuncType {
43    fn hash<H: Hasher>(&self, state: &mut H) {
44        self.name.hash(state);
45        self.subtype.hash(state);
46        self.output.name().hash(state);
47        for inp in &self.input {
48            inp.name().hash(state);
49        }
50    }
51}
/// A validator for applications of a primitive.
///
/// Validators take a term DAG and the application's arguments (as `TermId`s)
/// and return a newly computed `TermId` if the primitive application is valid,
/// or `None` if it is invalid. They are shared behind an `Arc` and must be
/// `Send + Sync`.
pub type PrimitiveValidator = Arc<dyn Fn(&mut TermDag, &[TermId]) -> Option<TermId> + Send + Sync>;
56
/// A registered primitive together with the id under which it was registered
/// in the backend.
#[derive(Clone)]
pub struct PrimitiveWithId {
    /// The primitive implementation (queried for its name and type constraints).
    pub(crate) primitive: Arc<dyn Primitive + Send + Sync>,
    /// The id returned by `register_external_func` when the primitive was
    /// registered with the backend.
    pub(crate) id: ExternalFunctionId,
    /// An optional validator for applications of this primitive.
    pub(crate) validator: Option<PrimitiveValidator>,
}
63
64impl PrimitiveWithId {
65    /// Takes the full signature of a primitive (both input and output types).
66    /// Returns whether the primitive is compatible with this signature.
67    pub fn accept(&self, tys: &[Arc<dyn Sort>], typeinfo: &TypeInfo) -> bool {
68        let mut constraints = vec![];
69        let lits: Vec<_> = (0..tys.len())
70            .map(|i| AtomTerm::Literal(Span::Panic, Literal::Int(i as i64)))
71            .collect();
72        for (lit, ty) in lits.iter().zip(tys.iter()) {
73            constraints.push(constraint::assign(lit.clone(), ty.clone()))
74        }
75        constraints.extend(
76            self.primitive
77                .get_type_constraints(&Span::Panic)
78                .get(&lits, typeinfo),
79        );
80        let problem = Problem {
81            constraints,
82            range: HashSet::default(),
83        };
84        problem.solve(|sort| sort.name()).is_ok()
85    }
86}
87
88impl Debug for PrimitiveWithId {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        write!(f, "Prim({})", self.primitive.name())
91    }
92}
93
/// Stores resolved typechecking information.
///
/// Tracks every name the typechecker knows about: sorts and sort
/// constructors, primitives, declared functions, and global bindings.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// Sort constructors ("presorts", e.g. containers), keyed by presort name.
    /// Filled by `add_presort` and consulted by `EGraph::declare_sort`.
    mksorts: HashMap<String, MkSort>,
    // TODO(yz): I want to get rid of this as now we have user-defined primitives and constraint based type checking
    /// Primitive names reserved by registered presorts; `is_primitive` treats
    /// them as taken even when no primitive of that name is registered.
    reserved_primitives: HashSet<&'static str>,
    /// Declared sorts, keyed by sort name.
    pub(crate) sorts: HashMap<String, Arc<dyn Sort>>,
    /// Registered primitives keyed by name; one name may map to several
    /// primitives (each `add_primitive*` call appends one entry).
    primitives: HashMap<String, Vec<PrimitiveWithId>>,
    /// Signatures of declared functions, keyed by function name.
    func_types: HashMap<String, FuncType>,
    /// Sorts of global bindings (top-level `let`s and internal-let
    /// functions), keyed by global name.
    pub(crate) global_sorts: HashMap<String, ArcSort>,
    /// Sorts that do not allow union (e.g., from `:no-union` sorts or relations).
    pub(crate) non_unionable_sorts: HashSet<String>,
}
107
108// These methods need to be on the `EGraph` in order to
109// register sorts and primitives with the backend.
110impl EGraph {
111    /// Add a user-defined sort to the e-graph.
112    ///
113    /// Also look at [`prelude::add_base_sort`] for a convenience method for adding user-defined sorts
114    pub fn add_sort<S: Sort + 'static>(&mut self, sort: S, span: Span) -> Result<(), TypeError> {
115        self.add_arcsort(Arc::new(sort), span)
116    }
117
118    /// Declare a sort. This corresponds to the `sort` keyword in egglog.
119    /// It can either declares a new [`EqSort`] if `presort_and_args` is not provided,
120    /// or an instantiation of a presort (e.g., containers like `Vec`).
121    pub fn declare_sort(
122        &mut self,
123        name: impl Into<String>,
124        presort_and_args: &Option<(String, Vec<Expr>)>,
125        span: Span,
126    ) -> Result<(), TypeError> {
127        let name = name.into();
128        if self.type_info.func_types.contains_key(&name) {
129            return Err(TypeError::FunctionAlreadyBound(name, span));
130        }
131
132        let sort = match presort_and_args {
133            None => Arc::new(EqSort { name }),
134            Some((presort, args)) => {
135                if let Some(mksort) = self.type_info.mksorts.get(presort) {
136                    mksort(&mut self.type_info, name, args)?
137                } else {
138                    return Err(TypeError::PresortNotFound(presort.clone(), span));
139                }
140            }
141        };
142
143        self.add_arcsort(sort, span)
144    }
145
146    /// Add a user-defined sort to the e-graph.
147    pub fn add_arcsort(&mut self, sort: ArcSort, span: Span) -> Result<(), TypeError> {
148        sort.register_type(&mut self.backend);
149
150        let name = sort.name();
151        match self.type_info.sorts.entry(name.to_owned()) {
152            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
153            HEntry::Vacant(e) => {
154                e.insert(sort.clone());
155                sort.register_primitives(self);
156                Ok(())
157            }
158        }
159    }
160
161    /// Add a user-defined primitive
162    pub fn add_primitive<T>(&mut self, x: T)
163    where
164        T: Clone + Primitive + Send + Sync + 'static,
165    {
166        self.add_primitive_with_validator(x, None)
167    }
168
169    /// Add a user-defined primitive with an optional validator
170    pub fn add_primitive_with_validator<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
171    where
172        T: Clone + Primitive + Send + Sync + 'static,
173    {
174        // We need to use a wrapper because of the orphan rule.
175        // If we just try to implement `ExternalFunction` directly on
176        // all `PrimitiveLike`s then it would be possible for a
177        // downstream crate to create a conflict.
178        #[derive(Clone)]
179        struct Wrapper<T>(T);
180        impl<T: Clone + Primitive + Send + Sync> ExternalFunction for Wrapper<T> {
181            fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
182                self.0.apply(exec_state, args)
183            }
184        }
185
186        let primitive = Arc::new(x.clone());
187        let id = self.backend.register_external_func(Box::new(Wrapper(x)));
188        self.type_info
189            .primitives
190            .entry(primitive.name().to_owned())
191            .or_default()
192            .push(PrimitiveWithId {
193                primitive,
194                id,
195                validator,
196            });
197    }
198
199    pub(crate) fn typecheck_program(
200        &mut self,
201        program: &Vec<NCommand>,
202    ) -> Result<Vec<ResolvedNCommand>, TypeError> {
203        let mut result = vec![];
204        for command in program {
205            result.push(self.typecheck_command(command)?);
206        }
207        Ok(result)
208    }
209
/// Typechecks a single normalized command and returns its resolved form.
///
/// Several arms also update `self.type_info` (declaring sorts, recording
/// function signatures and global bindings). Typechecking a command is
/// therefore not side-effect free.
fn typecheck_command(&mut self, command: &NCommand) -> Result<ResolvedNCommand, TypeError> {
    let symbol_gen = &mut self.parser.symbol_gen;

    let command: ResolvedNCommand = match command {
        NCommand::Function(fdecl) => {
            let resolved = self.type_info.typecheck_function(symbol_gen, fdecl)?;
            // If this is a let binding, add it to global_sorts
            // This preserves behavior for lets after desugaring
            if resolved.internal_let {
                // `typecheck_function` has already resolved the output sort,
                // so this lookup is expected to succeed.
                let output_sort = self.type_info.sorts.get(&fdecl.schema.output).unwrap();
                self.type_info
                    .global_sorts
                    .insert(fdecl.name.clone(), output_sort.clone());
            }
            ResolvedNCommand::Function(resolved)
        }
        NCommand::NormRule { rule } => ResolvedNCommand::NormRule {
            rule: self.type_info.typecheck_rule(symbol_gen, rule)?,
        },
        NCommand::Sort {
            span,
            name,
            presort_and_args,
            uf,
            proof_func,
            unionable,
        } => {
            // Note this is bad since typechecking should be pure and idempotent
            // Otherwise typechecking the same program twice will fail
            self.declare_sort(name.clone(), presort_and_args, span.clone())?;
            // Mark as non-unionable if the sort declaration says so
            if !unionable {
                self.type_info.non_unionable_sorts.insert(name.clone());
            }
            ResolvedNCommand::Sort {
                span: span.clone(),
                name: name.clone(),
                presort_and_args: presort_and_args.clone(),
                uf: uf.clone(),
                proof_func: proof_func.clone(),
                unionable: *unionable,
            }
        }
        NCommand::CoreAction(Action::Let(span, var, expr)) => {
            // A top-level `let` introduces a global binding whose sort is
            // the sort of the bound expression.
            let expr = self
                .type_info
                .typecheck_expr(symbol_gen, expr, &Default::default())?;
            let output_type = expr.output_type();
            self.ensure_global_name_prefix(span, var)?;
            self.type_info
                .global_sorts
                .insert(var.clone(), output_type.clone());
            let var = ResolvedVar {
                name: var.clone(),
                sort: output_type,
                // not a global reference, but a global binding
                is_global_ref: false,
            };
            ResolvedNCommand::CoreAction(ResolvedAction::Let(span.clone(), var, expr))
        }
        NCommand::CoreAction(action) => ResolvedNCommand::CoreAction(
            self.type_info
                .typecheck_action(symbol_gen, action, &Default::default())?,
        ),
        NCommand::Extract(span, expr, variants) => {
            let res_expr =
                self.type_info
                    .typecheck_expr(symbol_gen, expr, &Default::default())?;

            // The number of variants to extract must be an `i64`.
            let res_variants =
                self.type_info
                    .typecheck_expr(symbol_gen, variants, &Default::default())?;
            if res_variants.output_type().name() != I64Sort.name() {
                return Err(TypeError::Mismatch {
                    expr: variants.clone(),
                    expected: I64Sort.to_arcsort(),
                    actual: res_variants.output_type(),
                });
            }

            ResolvedNCommand::Extract(span.clone(), res_expr, res_variants)
        }
        NCommand::Check(span, facts) => ResolvedNCommand::Check(
            span.clone(),
            self.type_info.typecheck_facts(symbol_gen, facts)?,
        ),
        NCommand::Fail(span, cmd) => {
            ResolvedNCommand::Fail(span.clone(), Box::new(self.typecheck_command(cmd)?))
        }
        NCommand::RunSchedule(schedule) => ResolvedNCommand::RunSchedule(
            self.type_info.typecheck_schedule(symbol_gen, schedule)?,
        ),
        NCommand::Pop(span, n) => ResolvedNCommand::Pop(span.clone(), *n),
        NCommand::Push(n) => ResolvedNCommand::Push(*n),
        NCommand::AddRuleset(span, ruleset) => {
            ResolvedNCommand::AddRuleset(span.clone(), ruleset.clone())
        }
        NCommand::UnstableCombinedRuleset(span, name, sub_rulesets) => {
            ResolvedNCommand::UnstableCombinedRuleset(
                span.clone(),
                name.clone(),
                sub_rulesets.clone(),
            )
        }
        NCommand::PrintOverallStatistics(span, file) => {
            ResolvedNCommand::PrintOverallStatistics(span.clone(), file.clone())
        }
        NCommand::PrintFunction(span, table, size, file, mode) => {
            ResolvedNCommand::PrintFunction(
                span.clone(),
                table.clone(),
                *size,
                file.clone(),
                *mode,
            )
        }
        NCommand::PrintSize(span, n) => {
            // Should probably also resolve the function symbol here
            ResolvedNCommand::PrintSize(span.clone(), n.clone())
        }
        NCommand::ProveExists(span, constructor) => {
            // `prove-exists` only accepts a declared constructor.
            let func_type = self
                .type_info
                .get_func_type(constructor)
                .ok_or_else(|| TypeError::UnboundFunction(constructor.clone(), span.clone()))?;
            if func_type.subtype != FunctionSubtype::Constructor {
                return Err(TypeError::ProveExistsRequiresConstructor(
                    constructor.clone(),
                    span.clone(),
                ));
            }
            ResolvedNCommand::ProveExists(span.clone(), ResolvedCall::Func(func_type.clone()))
        }
        NCommand::Output { span, file, exprs } => {
            let exprs = exprs
                .iter()
                .map(|expr| {
                    self.type_info
                        .typecheck_expr(symbol_gen, expr, &Default::default())
                })
                .collect::<Result<Vec<_>, _>>()?;
            ResolvedNCommand::Output {
                span: span.clone(),
                file: file.clone(),
                exprs,
            }
        }
        NCommand::Input { span, name, file } => ResolvedNCommand::Input {
            span: span.clone(),
            name: name.clone(),
            file: file.clone(),
        },
        NCommand::UserDefined(span, name, exprs) => {
            ResolvedNCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
        }
    };
    // Rules get one extra pass: variables that carry the global-name prefix
    // but are not global references are reported.
    if let ResolvedNCommand::NormRule { rule } = &command {
        self.warn_for_prefixed_non_globals_in_rule(rule)?;
    }
    Ok(command)
}
371
372    fn warn_for_prefixed_non_globals_in_var(
373        &mut self,
374        span: &Span,
375        var: &ResolvedVar,
376    ) -> Result<(), TypeError> {
377        if var.is_global_ref {
378            return Ok(());
379        }
380        if var.name.starts_with(crate::GLOBAL_NAME_PREFIX) {
381            self.warn_prefixed_non_globals(span, &var.name)?;
382        }
383        Ok(())
384    }
385
386    fn warn_for_prefixed_non_globals_in_rule(
387        &mut self,
388        rule: &ResolvedRule,
389    ) -> Result<(), TypeError> {
390        let mut res: Result<(), TypeError> = Ok(());
391
392        for fact in &rule.body {
393            fact.visit_vars(&mut |span, var| {
394                if res.is_ok() {
395                    res = self.warn_for_prefixed_non_globals_in_var(span, var);
396                }
397            });
398        }
399
400        rule.head.visit_vars(&mut |span, var| {
401            if res.is_ok() {
402                res = self.warn_for_prefixed_non_globals_in_var(span, var);
403            }
404        });
405        res
406    }
407}
408
409impl TypeInfo {
410    /// Adds a sort constructor to the typechecker's known set of types.
411    pub fn add_presort<S: Presort>(&mut self, span: Span) -> Result<(), TypeError> {
412        let name = S::presort_name();
413        match self.mksorts.entry(name.to_owned()) {
414            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
415            HEntry::Vacant(e) => {
416                e.insert(S::make_sort);
417                self.reserved_primitives.extend(S::reserved_primitives());
418                Ok(())
419            }
420        }
421    }
422
423    /// Returns all sorts that satisfy the type and predicate.
424    pub fn get_sorts_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
425        let mut results = Vec::new();
426        for sort in self.sorts.values() {
427            let sort = sort.clone().as_arc_any();
428            if let Ok(sort) = Arc::downcast(sort) {
429                if pred(&sort) {
430                    results.push(sort);
431                }
432            }
433        }
434        results
435    }
436
437    /// Returns all sorts based on the type.
438    pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
439        self.get_sorts_by(|_| true)
440    }
441
442    /// Returns a sort that satisfies the type and predicate.
443    pub fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
444        let results = self.get_sorts_by(pred);
445        assert_eq!(
446            results.len(),
447            1,
448            "Expected exactly one sort for type {}",
449            std::any::type_name::<S>()
450        );
451        results.into_iter().next().unwrap()
452    }
453
454    /// Returns a sort based on the type.
455    pub fn get_sort<S: Sort>(&self) -> Arc<S> {
456        self.get_sort_by(|_| true)
457    }
458
459    /// Returns all sorts that satisfy the predicate.
460    pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
461        self.sorts.values().filter(|&x| f(x)).cloned().collect()
462    }
463
464    /// Returns a sort based on the predicate.
465    pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
466        let results = self.get_arcsorts_by(f);
467        assert_eq!(
468            results.len(),
469            1,
470            "Expected exactly one sort for type {}",
471            std::any::type_name::<S>()
472        );
473        results.into_iter().next().unwrap()
474    }
475
476    /// Check if a sort allows union operations.
477    /// A sort is unionable if it's an eq_sort and not marked as non-unionable
478    /// (e.g., from `(sort Foo :no-union)` or relation desugaring).
479    pub fn is_sort_unionable(&self, sort: &ArcSort) -> bool {
480        sort.is_eq_sort() && !self.non_unionable_sorts.contains(sort.name())
481    }
482
483    fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
484        let input = func
485            .schema
486            .input
487            .iter()
488            .map(|name| {
489                if let Some(sort) = self.sorts.get(name) {
490                    Ok(sort.clone())
491                } else {
492                    Err(TypeError::UndefinedSort(name.clone(), func.span.clone()))
493                }
494            })
495            .collect::<Result<Vec<_>, _>>()?;
496        let output = if let Some(sort) = self.sorts.get(&func.schema.output) {
497            Ok(sort.clone())
498        } else {
499            Err(TypeError::UndefinedSort(
500                func.schema.output.clone(),
501                func.span.clone(),
502            ))
503        }?;
504
505        Ok(FuncType {
506            name: func.name.clone(),
507            subtype: func.subtype,
508            input,
509            output: output.clone(),
510        })
511    }
512
513    fn typecheck_function(
514        &mut self,
515        symbol_gen: &mut SymbolGen,
516        fdecl: &FunctionDecl,
517    ) -> Result<ResolvedFunctionDecl, TypeError> {
518        if self.sorts.contains_key(&fdecl.name) {
519            return Err(TypeError::SortAlreadyBound(
520                fdecl.name.clone(),
521                fdecl.span.clone(),
522            ));
523        }
524        if self.is_primitive(&fdecl.name) {
525            return Err(TypeError::PrimitiveAlreadyBound(
526                fdecl.name.clone(),
527                fdecl.span.clone(),
528            ));
529        }
530        // View tables (with term_constructor) must have at least one input (the e-class)
531        if fdecl.term_constructor.is_some() && fdecl.schema.input.is_empty() {
532            return Err(TypeError::TermConstructorNoInputs(
533                fdecl.name.clone(),
534                fdecl.span.clone(),
535            ));
536        }
537        let ftype = self.function_to_functype(fdecl)?;
538        if self.func_types.insert(fdecl.name.clone(), ftype).is_some() {
539            return Err(TypeError::FunctionAlreadyBound(
540                fdecl.name.clone(),
541                fdecl.span.clone(),
542            ));
543        }
544        let mut bound_vars = IndexMap::default();
545        let output_type = self.sorts.get(&fdecl.schema.output).unwrap();
546        if fdecl.subtype == FunctionSubtype::Constructor && !output_type.is_eq_sort() {
547            return Err(TypeError::ConstructorOutputNotSort(
548                fdecl.name.clone(),
549                fdecl.span.clone(),
550            ));
551        }
552        bound_vars.insert("old", (fdecl.span.clone(), output_type.clone()));
553        bound_vars.insert("new", (fdecl.span.clone(), output_type.clone()));
554
555        Ok(ResolvedFunctionDecl {
556            name: fdecl.name.clone(),
557            subtype: fdecl.subtype,
558            schema: fdecl.schema.clone(),
559            resolved_schema: ResolvedCall::Func(self.func_types.get(&fdecl.name).unwrap().clone()),
560            merge: match &fdecl.merge {
561                Some(merge) => Some(self.typecheck_expr(symbol_gen, merge, &bound_vars)?),
562                None => None,
563            },
564            cost: fdecl.cost,
565            unextractable: fdecl.unextractable,
566            internal_hidden: fdecl.internal_hidden,
567            internal_let: fdecl.internal_let,
568            span: fdecl.span.clone(),
569            term_constructor: fdecl.term_constructor.clone(),
570        })
571    }
572
573    fn typecheck_schedule(
574        &self,
575        symbol_gen: &mut SymbolGen,
576        schedule: &Schedule,
577    ) -> Result<ResolvedSchedule, TypeError> {
578        let schedule = match schedule {
579            Schedule::Repeat(span, times, schedule) => ResolvedSchedule::Repeat(
580                span.clone(),
581                *times,
582                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
583            ),
584            Schedule::Sequence(span, schedules) => {
585                let schedules = schedules
586                    .iter()
587                    .map(|schedule| self.typecheck_schedule(symbol_gen, schedule))
588                    .collect::<Result<Vec<_>, _>>()?;
589                ResolvedSchedule::Sequence(span.clone(), schedules)
590            }
591            Schedule::Saturate(span, schedule) => ResolvedSchedule::Saturate(
592                span.clone(),
593                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
594            ),
595            Schedule::Run(span, RunConfig { ruleset, until }) => {
596                let until = until
597                    .as_ref()
598                    .map(|facts| self.typecheck_facts(symbol_gen, facts))
599                    .transpose()?;
600                ResolvedSchedule::Run(
601                    span.clone(),
602                    ResolvedRunConfig {
603                        ruleset: ruleset.clone(),
604                        until,
605                    },
606                )
607            }
608        };
609
610        Result::Ok(schedule)
611    }
612
613    fn typecheck_rule(
614        &self,
615        symbol_gen: &mut SymbolGen,
616        rule: &Rule,
617    ) -> Result<ResolvedRule, TypeError> {
618        let Rule {
619            span,
620            head,
621            body,
622            name,
623            ruleset,
624        } = rule;
625        let mut constraints = vec![];
626
627        let (query, mapped_query) = Facts(body.clone()).to_query(self, symbol_gen);
628        constraints.extend(query.get_constraints(self)?);
629
630        let mut binding = query.get_vars();
631        // We lower to core actions with `union_to_set_optimization`
632        // later in the pipeline. For typechecking we do not need it.
633        let mut ctx = CoreActionContext::new(self, &mut binding, symbol_gen, false);
634        let (actions, mapped_action) = head.to_core_actions(&mut ctx)?;
635
636        let mut problem = Problem::default();
637        problem.add_rule(
638            &CoreRule {
639                span: span.clone(),
640                body: query,
641                head: actions,
642            },
643            self,
644            symbol_gen,
645        )?;
646
647        let assignment = problem
648            .solve(|sort: &ArcSort| sort.name())
649            .map_err(|e| e.to_type_error())?;
650
651        let body: Vec<ResolvedFact> = assignment.annotate_facts(&mapped_query, self);
652        let actions: ResolvedActions = assignment.annotate_actions(&mapped_action, self)?;
653
654        self.check_lookup_actions(&actions)?;
655
656        Ok(ResolvedRule {
657            span: span.clone(),
658            body,
659            head: actions,
660            name: name.clone(),
661            ruleset: ruleset.clone(),
662        })
663    }
664
665    fn check_lookup_expr(&self, expr: &ResolvedExpr) -> Result<(), TypeError> {
666        if let Some(span) = self.expr_has_function_lookup(expr) {
667            return Err(TypeError::LookupInRuleDisallowed(
668                "function".to_string(),
669                span,
670            ));
671        }
672        Ok(())
673    }
674
675    fn check_lookup_actions(&self, actions: &ResolvedActions) -> Result<(), TypeError> {
676        for action in actions.iter() {
677            match action {
678                GenericAction::Let(_, _, rhs) => self.check_lookup_expr(rhs)?,
679                GenericAction::Set(_, _, args, rhs) => {
680                    for arg in args.iter() {
681                        self.check_lookup_expr(arg)?;
682                    }
683                    self.check_lookup_expr(rhs)?;
684                }
685                GenericAction::Union(_, lhs, rhs) => {
686                    self.check_lookup_expr(lhs)?;
687                    self.check_lookup_expr(rhs)?;
688                }
689                GenericAction::Change(_, _, _, args) => {
690                    for arg in args.iter() {
691                        self.check_lookup_expr(arg)?;
692                    }
693                }
694                GenericAction::Panic(..) => {}
695                GenericAction::Expr(_, expr) => self.check_lookup_expr(expr)?,
696            }
697        }
698        Ok(())
699    }
700
701    pub fn typecheck_facts(
702        &self,
703        symbol_gen: &mut SymbolGen,
704        facts: &[Fact],
705    ) -> Result<Vec<ResolvedFact>, TypeError> {
706        let (query, mapped_facts) = Facts(facts.to_vec()).to_query(self, symbol_gen);
707        let mut problem = Problem::default();
708        problem.add_query(&query, self)?;
709        let assignment = problem
710            .solve(|sort: &ArcSort| sort.name())
711            .map_err(|e| e.to_type_error())?;
712        let annotated_facts = assignment.annotate_facts(&mapped_facts, self);
713        Ok(annotated_facts)
714    }
715
716    fn typecheck_actions(
717        &self,
718        symbol_gen: &mut SymbolGen,
719        actions: &Actions,
720        binding: &IndexMap<&str, (Span, ArcSort)>,
721    ) -> Result<ResolvedActions, TypeError> {
722        let mut binding_set: IndexSet<String> =
723            binding.keys().copied().map(str::to_string).collect();
724        // We lower to core actions with `union_to_set_optimization`
725        // later in the pipeline. For typechecking we do not need it.
726        let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
727        let (actions, mapped_action) = actions.to_core_actions(&mut ctx)?;
728        let mut problem = Problem::default();
729
730        // add actions to problem
731        problem.add_actions(&actions, self, symbol_gen)?;
732
733        // add bindings from the context
734        for (var, (span, sort)) in binding {
735            problem.assign_local_var_type(var, span.clone(), sort.clone())?;
736        }
737
738        let assignment = problem
739            .solve(|sort: &ArcSort| sort.name())
740            .map_err(|e| e.to_type_error())?;
741
742        let annotated_actions = assignment.annotate_actions(&mapped_action, self)?;
743        Ok(annotated_actions)
744    }
745
746    fn typecheck_expr(
747        &self,
748        symbol_gen: &mut SymbolGen,
749        expr: &Expr,
750        binding: &IndexMap<&str, (Span, ArcSort)>,
751    ) -> Result<ResolvedExpr, TypeError> {
752        let action = Action::Expr(expr.span(), expr.clone());
753        let typechecked_action = self.typecheck_action(symbol_gen, &action, binding)?;
754        match typechecked_action {
755            ResolvedAction::Expr(_, expr) => Ok(expr),
756            _ => unreachable!(),
757        }
758    }
759
760    fn typecheck_action(
761        &self,
762        symbol_gen: &mut SymbolGen,
763        action: &Action,
764        binding: &IndexMap<&str, (Span, ArcSort)>,
765    ) -> Result<ResolvedAction, TypeError> {
766        self.typecheck_actions(symbol_gen, &Actions::singleton(action.clone()), binding)
767            .map(|v| {
768                assert_eq!(v.len(), 1);
769                v.0.into_iter().next().unwrap()
770            })
771    }
772
773    pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
774        self.sorts.get(sym)
775    }
776
777    pub fn get_prims(&self, sym: &str) -> Option<&[PrimitiveWithId]> {
778        self.primitives.get(sym).map(Vec::as_slice)
779    }
780
781    pub fn is_primitive(&self, sym: &str) -> bool {
782        self.primitives.contains_key(sym) || self.reserved_primitives.contains(sym)
783    }
784
785    pub fn primitive_has_validator(&self, id: ExternalFunctionId) -> bool {
786        self.primitives
787            .values()
788            .flat_map(|v| v.iter())
789            .any(|p| p.id == id && p.validator.is_some())
790    }
791
792    pub fn get_func_type(&self, sym: &str) -> Option<&FuncType> {
793        self.func_types.get(sym)
794    }
795
796    pub fn is_constructor(&self, sym: &str) -> bool {
797        self.func_types
798            .get(sym)
799            .is_some_and(|f| f.subtype == FunctionSubtype::Constructor)
800    }
801
802    pub fn get_global_sort(&self, sym: &str) -> Option<&ArcSort> {
803        self.global_sorts.get(sym)
804    }
805
806    pub fn is_global(&self, sym: &str) -> bool {
807        self.global_sorts.contains_key(sym)
808    }
809
810    /// Check if an expression contains non-global function lookups (FunctionSubtype::Custom calls).
811    /// Global function calls are allowed since they get desugared to constructors.
812    /// Returns Some(span) if a lookup is found, None otherwise.
813    pub fn expr_has_function_lookup(&self, expr: &ResolvedExpr) -> Option<Span> {
814        use ast::GenericExpr;
815
816        expr.find(&mut |e| {
817            if let GenericExpr::Call(span, ResolvedCall::Func(func_type), _) = e {
818                if func_type.subtype == FunctionSubtype::Custom {
819                    // Skip global functions - they get desugared to constructors
820                    if !self.is_global(&func_type.name) {
821                        return Some(span.clone());
822                    }
823                }
824            }
825            None
826        })
827    }
828}
829
/// Errors reported by the typechecker.
///
/// Most variants carry a `Span`. Their `Display` output prints that span on the
/// first line, followed by a description of the problem on the next line.
#[derive(Debug, Clone, Error)]
pub enum TypeError {
    /// An expression was applied to a number of arguments other than `expected`.
    #[error("{}\nArity mismatch, expected {expected} args: {expr}", .expr.span())]
    Arity { expr: Expr, expected: usize },
    /// `expr` was required to have sort `expected`, but its sort is `actual`.
    // NOTE(review): the message wording ("Expect ... but get", leading space after the
    // newline) is left verbatim since callers or tests may match on it — confirm before
    // rewording.
    #[error(
        "{}\n Expect expression {expr} to have type {}, but get type {}",
        .expr.span(), .expected.name(), .actual.name(),
    )]
    Mismatch {
        expr: Expr,
        expected: ArcSort,
        actual: ArcSort,
    },
    /// A referenced symbol (field 0) is unbound.
    #[error("{1}\nUnbound symbol {0}")]
    Unbound(String, Span),
    /// Variable (field 0) is ungrounded (presumably: not bound by any query atom —
    /// confirm against the grounding check).
    #[error("{1}\nVariable {0} is ungrounded")]
    Ungrounded(String, Span),
    /// A referenced sort name (field 0) has not been declared.
    #[error("{1}\nUndefined sort {0}")]
    UndefinedSort(String, Span),
    /// Defining sort (field 0) is not allowed; field 1 explains why.
    #[error("{2}\nSort {0} definition is disallowed: {1}")]
    DisallowedSort(String, String, Span),
    /// A referenced function name (field 0) is unbound.
    #[error("{1}\nUnbound function {0}")]
    UnboundFunction(String, Span),
    /// `prove-exists` was given function (field 0), which is not a constructor.
    #[error("{1}\nprove-exists requires constructor function, but {0} is not a constructor")]
    ProveExistsRequiresConstructor(String, Span),
    /// A function named (field 0) is already bound.
    #[error("{1}\nFunction already bound {0}")]
    FunctionAlreadyBound(String, Span),
    /// A sort named (field 0) has already been declared.
    #[error("{1}\nSort {0} already declared.")]
    SortAlreadyBound(String, Span),
    /// A primitive named (field 0) has already been declared.
    #[error("{1}\nPrimitive {0} already declared.")]
    PrimitiveAlreadyBound(String, Span),
    /// A function signature did not match the expected one.
    ///
    /// Fields, in order: expected output sort, expected input sorts, actual output
    /// sort, actual input sorts. Rendered as `inputs => output` for each side.
    #[error("Function type mismatch: expected {} => {}, actual {} => {}", .1.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .0.name(), .3.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .2.name())]
    FunctionTypeMismatch(ArcSort, Vec<ArcSort>, ArcSort, Vec<ArcSort>),
    /// No presort named (field 0) was found.
    #[error("{1}\nPresort {0} not found.")]
    PresortNotFound(String, Span),
    /// No type could be inferred for the given expression.
    #[error("{}\nFailed to infer a type for: {}", .0.span(), .0)]
    InferenceFailure(Expr),
    /// Variable (field 0) is being defined although it was already defined.
    #[error("{1}\nVariable {0} was already defined")]
    AlreadyDefined(String, Span),
    /// Constructor function (field 0) declares an output type that is not an
    /// acceptable sort (presumably it must be an eq-sort — confirm).
    #[error("{1}\nThe output type of constructor function {0} must be sort")]
    ConstructorOutputNotSort(String, Span),
    /// A rule performs a value lookup on non-constructor function (field 0).
    #[error("{1}\nValue lookup of non-constructor function {0} in rule is disallowed.")]
    LookupInRuleDisallowed(String, Span),
    /// `set` was applied to constructor (field 0); the message suggests `union` or
    /// declaring it as a function instead.
    #[error("{1}\nCannot set constructor {0}. Use `union` instead or declare {0} as a function.")]
    SetConstructorDisallowed(String, Span),
    /// Every alternative definition that was considered failed.
    ///
    /// Holds each individual error. They are printed one per line, each indented by
    /// two spaces.
    #[error("All alternative definitions considered failed\n{}", .0.iter().map(|e| format!("  {e}\n")).collect::<Vec<_>>().join(""))]
    AllAlternativeFailed(Vec<TypeError>),
    /// `union` was applied to values of sort (field 0), which does not support union
    /// (presumably because it is not an eq-sort — confirm).
    #[error("{}\nCannot union values of sort {}", .1, .0.name())]
    NonEqsortUnion(ArcSort, Span),
    /// `union` was applied to values of sort (field 0), which is explicitly marked
    /// non-unionable (the message cites relations as an example).
    #[error("{}\nCannot union values of sort {} because it is marked as non-unionable (e.g. from a relation)", .1, .0.name())]
    NonUnionableSort(ArcSort, Span),
    /// View table (field 0) is declared with `:term-constructor` but has no inputs.
    /// At least one input (the e-class) is required.
    #[error(
        "{1}\nView table {0} with :term-constructor must have at least one input (the e-class)."
    )]
    TermConstructorNoInputs(String, Span),
    /// A non-global variable's name starts with `crate::GLOBAL_NAME_PREFIX`.
    #[error(
        "{span}\nNon-global variable `{name}` must not start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    NonGlobalPrefixed { name: String, span: Span },
    /// A global's name does not start with `crate::GLOBAL_NAME_PREFIX`.
    #[error(
        "{span}\nGlobal `{name}` must start with `{}`.",
        crate::GLOBAL_NAME_PREFIX
    )]
    GlobalMissingPrefix { name: String, span: Span },
}
896
#[cfg(test)]
mod test {
    use crate::{EGraph, Error, typechecking::TypeError};

    /// Applying the two-column relation `f` to three arguments must be rejected with
    /// an `Arity` error whose span text is exactly the offending call.
    #[test]
    fn test_arity_mismatch() {
        let prog = "
            (relation f (i64 i64))
            (rule ((f a b c)) ())
       ";
        let mut egraph = EGraph::default();
        let res = egraph.parse_and_run_program(None, prog);
        let Err(Error::TypeError(TypeError::Arity { expected: 2, expr })) = &res else {
            panic!("Expected arity mismatch, got: {res:?}");
        };
        assert_eq!(expr.span().string(), "(f a b c)");
    }
}