egglog/
typechecking.rs

1use std::hash::Hasher;
2
3use crate::Context;
4use crate::{
5    core::{CoreActionContext, CoreRule, GenericActionsExt, ResolvedCall},
6    *,
7};
8use ast::{
9    MappedExprExt, ResolvedAction, ResolvedExpr, ResolvedFact, ResolvedRule, ResolvedVar, Rule,
10    RuleEvalMode,
11};
12use core_relations::ExternalFunction;
13use egglog_ast::generic_ast::GenericAction;
14use egglog_bridge::ActionRegistry;
15use enum_map::EnumMap;
16use std::sync::{Arc, RwLock};
17
// `ExternalFunction` wrapper for `PurePrim`. Holds the primitive
// directly so the dispatch chain `external_funcs[id].invoke(...)` →
// `T::apply(...)` is just one vtable hop plus a direct call — no
// closure indirection that defeats inlining.
#[derive(Clone)]
struct PurePrimWrapper<T> {
    /// The wrapped primitive; `invoke` forwards the arguments to its `apply`.
    prim: T,
    /// The call-site [`Context`] this wrapper stamps onto the
    /// `PureState` before dispatching. `register_per_context` commits
    /// one wrapper per valid `Context` for the trait, so the
    /// typechecker's pick at each call site is encoded directly here.
    ctx: Context,
}
31
32impl<T: PurePrim + Clone> ExternalFunction for PurePrimWrapper<T> {
33    fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
34        self.prim.apply(PureState::wrap(exec_state, self.ctx), args)
35    }
36}
37
// `ExternalFunction` wrapper for primitives that need the
// `ActionRegistry` (`ReadPrim`, `WritePrim`, `FullPrim`). One generic
// over the `Wrap` strategy that knows how to construct the right
// state type and dispatch to the primitive's `apply`.
#[derive(Clone)]
struct RegistryPrimWrapper<T, S> {
    /// The wrapped primitive.
    prim: T,
    /// Shared handle to the action registry; read-locked for the
    /// duration of each `invoke`.
    registry: Arc<RwLock<ActionRegistry>>,
    /// Stamped onto the state wrapper.
    ctx: Context,
    /// Marker selecting the `RegistryWrap` strategy `S`; no value of
    /// type `S` is ever stored.
    _wrap: std::marker::PhantomData<fn() -> S>,
}
50
/// Dispatch strategy used by `RegistryPrimWrapper`: each implementor
/// wraps the execution state (plus registry and context) in one
/// particular state type and calls the primitive's `apply` with it.
trait RegistryWrap<T>: Clone + Send + Sync {
    /// Invokes `prim` on `args`.
    ///
    /// `ctx` is the context to stamp onto the state wrapper and
    /// `registry` is the action registry the state is built over.
    /// Returns whatever the primitive's `apply` returns.
    fn invoke(
        prim: &T,
        exec_state: &mut ExecutionState,
        ctx: Context,
        args: &[Value],
        registry: &ActionRegistry,
    ) -> Option<Value>;
}
60
61#[derive(Clone)]
62struct WrapRead;
63impl<T: ReadPrim> RegistryWrap<T> for WrapRead {
64    #[inline]
65    fn invoke(
66        prim: &T,
67        exec_state: &mut ExecutionState,
68        ctx: Context,
69        args: &[Value],
70        registry: &ActionRegistry,
71    ) -> Option<Value> {
72        prim.apply(ReadState::wrap(exec_state, registry, ctx), args)
73    }
74}
75#[derive(Clone)]
76struct WrapWrite;
77impl<T: WritePrim> RegistryWrap<T> for WrapWrite {
78    #[inline]
79    fn invoke(
80        prim: &T,
81        exec_state: &mut ExecutionState,
82        ctx: Context,
83        args: &[Value],
84        registry: &ActionRegistry,
85    ) -> Option<Value> {
86        prim.apply(WriteState::wrap(exec_state, registry, ctx), args)
87    }
88}
89#[derive(Clone)]
90struct WrapFull;
91impl<T: FullPrim> RegistryWrap<T> for WrapFull {
92    #[inline]
93    fn invoke(
94        prim: &T,
95        exec_state: &mut ExecutionState,
96        ctx: Context,
97        args: &[Value],
98        registry: &ActionRegistry,
99    ) -> Option<Value> {
100        prim.apply(FullState::wrap(exec_state, registry, ctx), args)
101    }
102}
103
104impl<T: Clone + Send + Sync + 'static, S: RegistryWrap<T> + 'static> ExternalFunction
105    for RegistryPrimWrapper<T, S>
106{
107    fn invoke(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
108        let registry = self.registry.read().unwrap();
109        S::invoke(&self.prim, exec_state, self.ctx, args, &registry)
110    }
111}
112
/// The resolved type of a declared function: its name, subtype, and
/// the sorts of its inputs and output.
///
/// Equality and hashing compare sorts by name (see the manual
/// `PartialEq` and `Hash` impls for this type).
#[derive(Clone, Debug)]
pub struct FuncType {
    /// Name the function was declared with.
    pub name: String,
    /// Subtype copied from the function declaration.
    pub subtype: FunctionSubtype,
    /// Sorts of the function's inputs, in schema order.
    pub input: Vec<ArcSort>,
    /// Sort of the function's output.
    pub output: ArcSort,
}
120
121impl PartialEq for FuncType {
122    fn eq(&self, other: &Self) -> bool {
123        if self.name == other.name
124            && self.subtype == other.subtype
125            && self.output.name() == other.output.name()
126        {
127            if self.input.len() != other.input.len() {
128                return false;
129            }
130            for (a, b) in self.input.iter().zip(other.input.iter()) {
131                if a.name() != b.name() {
132                    return false;
133                }
134            }
135            true
136        } else {
137            false
138        }
139    }
140}
141
142impl Eq for FuncType {}
143
144impl Hash for FuncType {
145    fn hash<H: Hasher>(&self, state: &mut H) {
146        self.name.hash(state);
147        self.subtype.hash(state);
148        self.output.name().hash(state);
149        for inp in &self.input {
150            inp.name().hash(state);
151        }
152    }
153}
/// Validator callback for a primitive application.
///
/// Validators take a termdag and arguments (as TermIds) and return
/// a newly computed TermId if the primitive application is valid,
/// or None if it is invalid.
pub type PrimitiveValidator = Arc<dyn Fn(&mut TermDag, &[TermId]) -> Option<TermId> + Send + Sync>;
158
/// A registered primitive together with its optional validator and the
/// backend ids of its per-context runtime entrypoints.
#[derive(Clone)]
pub struct PrimitiveWithId {
    /// The primitive definition (shared, stored once).
    pub(crate) primitive: Arc<dyn Primitive>,
    /// Optional validator supplied at registration time.
    pub(crate) validator: Option<PrimitiveValidator>,
    /// Runtime entrypoints for the contexts this primitive is valid in.
    /// The primitive definition is stored once, while each context keeps
    /// its own backend id so higher-order dispatch can still recover the
    /// application context at runtime.
    pub(crate) context_ids: EnumMap<Context, Option<ExternalFunctionId>>,
}
169
170impl PrimitiveWithId {
171    /// Takes the full signature of a primitive (both input and output types).
172    /// Returns whether the primitive is compatible with this signature.
173    pub fn accept(&self, tys: &[Arc<dyn Sort>], typeinfo: &TypeInfo) -> bool {
174        let mut constraints = vec![];
175        let lits: Vec<_> = (0..tys.len())
176            .map(|i| AtomTerm::Literal(Span::Panic, Literal::Int(i as i64)))
177            .collect();
178        for (lit, ty) in lits.iter().zip(tys.iter()) {
179            constraints.push(constraint::assign(lit.clone(), ty.clone()))
180        }
181        constraints.extend(
182            self.primitive
183                .get_type_constraints(&Span::Panic)
184                .get(&lits, typeinfo),
185        );
186        let problem = Problem {
187            constraints,
188            range: HashSet::default(),
189        };
190        problem.solve(|sort| sort.name()).is_ok()
191    }
192
193    /// Returns whether this primitive has a runtime entrypoint for `context`.
194    pub fn is_valid_in_context(&self, context: Context) -> bool {
195        self.context_ids[context].is_some()
196    }
197}
198
199impl Debug for PrimitiveWithId {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        write!(f, "Prim({})", self.primitive.name())
202    }
203}
204
/// Stores resolved typechecking information.
#[derive(Clone, Default)]
pub struct TypeInfo {
    /// Registered presorts (sort constructors), keyed by presort name.
    mksorts: HashMap<String, MkSort>,
    // TODO(yz): I want to get rid of this as now we have user-defined primitives and constraint based type checking
    /// Primitive names reserved by registered presorts.
    reserved_primitives: HashSet<&'static str>,
    /// All registered sorts, keyed by sort name.
    pub(crate) sorts: HashMap<String, Arc<dyn Sort>>,
    /// Registered primitives, keyed by name; each name may have several
    /// entries.
    primitives: HashMap<String, Vec<PrimitiveWithId>>,
    /// Types of declared functions, keyed by function name.
    func_types: HashMap<String, FuncType>,
    /// Sorts of global bindings (top-level `let`s), keyed by name.
    pub(crate) global_sorts: HashMap<String, ArcSort>,
    /// Sorts that do not allow union (e.g., from `:no-union` sorts or relations).
    pub(crate) non_unionable_sorts: HashSet<String>,
}
218
219// These methods need to be on the `EGraph` in order to
220// register sorts and primitives with the backend.
221impl EGraph {
222    /// Add a user-defined sort to the e-graph.
223    ///
224    /// Also look at [`prelude::add_base_sort`] for a convenience method for adding user-defined sorts
225    pub fn add_sort<S: Sort + 'static>(&mut self, sort: S, span: Span) -> Result<(), TypeError> {
226        self.add_arcsort(Arc::new(sort), span)
227    }
228
229    /// Declare a sort. This corresponds to the `sort` keyword in egglog.
230    /// It can either declares a new [`EqSort`] if `presort_and_args` is not provided,
231    /// or an instantiation of a presort (e.g., containers like `Vec`).
232    pub fn declare_sort(
233        &mut self,
234        name: impl Into<String>,
235        presort_and_args: &Option<(String, Vec<Expr>)>,
236        span: Span,
237    ) -> Result<(), TypeError> {
238        let name = name.into();
239        if self.type_info.func_types.contains_key(&name) {
240            return Err(TypeError::FunctionAlreadyBound(name, span));
241        }
242
243        let sort = match presort_and_args {
244            None => Arc::new(EqSort { name }),
245            Some((presort, args)) => {
246                if let Some(mksort) = self.type_info.mksorts.get(presort) {
247                    mksort(&mut self.type_info, name, args)?
248                } else {
249                    return Err(TypeError::PresortNotFound(presort.clone(), span));
250                }
251            }
252        };
253
254        self.add_arcsort(sort, span)
255    }
256
257    /// Add a user-defined sort to the e-graph.
258    pub fn add_arcsort(&mut self, sort: ArcSort, span: Span) -> Result<(), TypeError> {
259        sort.register_type(&mut self.backend);
260
261        let name = sort.name();
262        match self.type_info.sorts.entry(name.to_owned()) {
263            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
264            HEntry::Vacant(e) => {
265                e.insert(sort.clone());
266                sort.register_primitives(self);
267                Ok(())
268            }
269        }
270    }
271
272    /// Register a [`PurePrim`]. Pass `None` for the validator if not
273    /// using the proof checker.
274    ///
275    /// Pick the trait whose state wrapper matches the body's needs:
276    /// [`PurePrim`] for pure ops, [`WritePrim`] for writes,
277    /// [`ReadPrim`] for table reads, [`FullPrim`] for both. The Rust
278    /// type checker enforces the body only uses methods the chosen
279    /// state allows.
280    pub fn add_pure_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
281    where
282        T: PurePrim + Clone,
283    {
284        self.register_per_context(x, validator, PureState::valid_contexts(), |x, ctx| {
285            Box::new(PurePrimWrapper { prim: x, ctx })
286        });
287    }
288
289    /// Register a [`WritePrim`]. Pass `None` for the validator if not
290    /// using the proof checker.
291    pub fn add_write_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
292    where
293        T: WritePrim + Clone,
294    {
295        self.register_registry_primitive::<T, WrapWrite>(
296            x,
297            validator,
298            WriteState::valid_contexts(),
299        );
300    }
301
302    /// Register a [`ReadPrim`]. Pass `None` for the validator if not
303    /// using the proof checker.
304    pub fn add_read_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
305    where
306        T: ReadPrim + Clone,
307    {
308        self.register_registry_primitive::<T, WrapRead>(x, validator, ReadState::valid_contexts());
309    }
310
311    /// Register a [`FullPrim`]. Pass `None` for the validator if not
312    /// using the proof checker.
313    pub fn add_full_primitive<T>(&mut self, x: T, validator: Option<PrimitiveValidator>)
314    where
315        T: FullPrim + Clone,
316    {
317        self.register_registry_primitive::<T, WrapFull>(x, validator, FullState::valid_contexts());
318    }
319
320    fn register_registry_primitive<T, S>(
321        &mut self,
322        x: T,
323        validator: Option<PrimitiveValidator>,
324        valid_ctxs: &[Context],
325    ) where
326        T: Primitive + Clone,
327        S: RegistryWrap<T> + 'static,
328    {
329        let registry = self.backend.action_registry().clone();
330        self.register_per_context(x, validator, valid_ctxs, move |x, ctx| {
331            Box::new(RegistryPrimWrapper::<T, S> {
332                prim: x,
333                registry: registry.clone(),
334                ctx,
335                _wrap: std::marker::PhantomData,
336            })
337        });
338    }
339
340    /// Shared registration engine. Stores one primitive definition, plus
341    /// one runtime id per valid [`Context`]. Each wrapper carries its
342    /// specific context stamped onto the state wrapper at invoke time.
343    ///
344    /// The typechecker filters by the context-id mask at each call site;
345    /// an `unstable-fn` value built around the primitive bakes *all*
346    /// signature-matching context ids, and `FunctionContainer::apply`
347    /// picks the one whose context matches the application ctx — so
348    /// values flow freely across contexts.
349    fn register_per_context<T, F>(
350        &mut self,
351        x: T,
352        validator: Option<PrimitiveValidator>,
353        valid_ctxs: &[Context],
354        mut build_wrapper: F,
355    ) where
356        T: Primitive + Clone,
357        F: FnMut(T, Context) -> Box<dyn ExternalFunction>,
358    {
359        let primitive: Arc<dyn Primitive> = Arc::new(x.clone());
360        let name = primitive.name().to_owned();
361        let context_ids = EnumMap::from_fn(|ctx| {
362            valid_ctxs.contains(&ctx).then(|| {
363                self.backend
364                    .register_external_func(build_wrapper(x.clone(), ctx))
365            })
366        });
367        self.type_info
368            .primitives
369            .entry(name)
370            .or_default()
371            .push(PrimitiveWithId {
372                primitive,
373                validator,
374                context_ids,
375            });
376    }
377}
378
379impl EGraph {
380    pub(crate) fn typecheck_program(
381        &mut self,
382        program: &Vec<NCommand>,
383    ) -> Result<Vec<ResolvedNCommand>, TypeError> {
384        let mut result = vec![];
385        for command in program {
386            result.push(self.typecheck_command(command)?);
387        }
388        Ok(result)
389    }
390
    /// Typechecks a single command and returns its resolved form.
    ///
    /// Besides resolving, some commands update typechecking state:
    /// - `Function` registers the function type and, for internal `let`
    ///   functions, records the output sort in `global_sorts`.
    /// - `Sort` declares the sort on the e-graph and may mark it non-unionable.
    /// - A top-level `let` action records the bound variable in `global_sorts`.
    ///
    /// After a rule is resolved, its variables are passed through
    /// `warn_for_prefixed_non_globals_in_rule`; an error from that check is
    /// propagated.
    fn typecheck_command(&mut self, command: &NCommand) -> Result<ResolvedNCommand, TypeError> {
        let symbol_gen = &mut self.parser.symbol_gen;

        let command: ResolvedNCommand = match command {
            NCommand::Function(fdecl) => {
                let resolved = self.type_info.typecheck_function(symbol_gen, fdecl)?;
                // If this is a let binding, add it to global_sorts
                // This preserves behavior for lets after desugaring
                if resolved.internal_let {
                    let output_sort = self.type_info.sorts.get(&fdecl.schema.output).unwrap();
                    self.type_info
                        .global_sorts
                        .insert(fdecl.name.clone(), output_sort.clone());
                }
                ResolvedNCommand::Function(resolved)
            }
            NCommand::NormRule { rule } => ResolvedNCommand::NormRule {
                rule: self
                    .type_info
                    .typecheck_rule(symbol_gen, rule, self.seminaive)?,
            },
            NCommand::Sort {
                span,
                name,
                presort_and_args,
                uf,
                proof_func,
                unionable,
            } => {
                // Note this is bad since typechecking should be pure and idempotent
                // Otherwise typechecking the same program twice will fail
                self.declare_sort(name.clone(), presort_and_args, span.clone())?;
                // Mark as non-unionable if the sort declaration says so
                if !unionable {
                    self.type_info.non_unionable_sorts.insert(name.clone());
                }
                ResolvedNCommand::Sort {
                    span: span.clone(),
                    name: name.clone(),
                    presort_and_args: presort_and_args.clone(),
                    uf: uf.clone(),
                    proof_func: proof_func.clone(),
                    unionable: *unionable,
                }
            }
            // Top-level `let`: typecheck it, validate the global name, then
            // record the binding's sort so later commands can refer to it.
            NCommand::CoreAction(action @ Action::Let(span, var, _)) => {
                let action = self.type_info.typecheck_standalone_action(
                    symbol_gen,
                    action,
                    &Default::default(),
                    Context::Full,
                )?;
                self.ensure_global_name_prefix(span, var)?;
                let ResolvedAction::Let(_, resolved_var, _) = &action else {
                    unreachable!("typechecking an Action::Let should return ResolvedAction::Let")
                };
                self.type_info
                    .global_sorts
                    .insert(resolved_var.name.clone(), resolved_var.sort.clone());
                ResolvedNCommand::CoreAction(action)
            }
            NCommand::CoreAction(action) => {
                ResolvedNCommand::CoreAction(self.type_info.typecheck_standalone_action(
                    symbol_gen,
                    action,
                    &Default::default(),
                    Context::Full,
                )?)
            }
            NCommand::Extract(span, expr, variants) => {
                let res_expr = self.type_info.typecheck_standalone_expr(
                    symbol_gen,
                    expr,
                    &Default::default(),
                    Context::Full,
                )?;

                let res_variants = self.type_info.typecheck_standalone_expr(
                    symbol_gen,
                    variants,
                    &Default::default(),
                    Context::Full,
                )?;
                // The variants expression must have the i64 sort.
                if res_variants.output_type().name() != I64Sort.name() {
                    return Err(TypeError::Mismatch {
                        expr: variants.clone(),
                        expected: I64Sort.to_arcsort(),
                        actual: res_variants.output_type(),
                    });
                }

                ResolvedNCommand::Extract(span.clone(), res_expr, res_variants)
            }
            NCommand::Check(span, facts) => ResolvedNCommand::Check(
                span.clone(),
                self.type_info.typecheck_facts(symbol_gen, facts)?,
            ),
            NCommand::Fail(span, cmd) => {
                ResolvedNCommand::Fail(span.clone(), Box::new(self.typecheck_command(cmd)?))
            }
            NCommand::RunSchedule(schedule) => ResolvedNCommand::RunSchedule(
                self.type_info.typecheck_schedule(symbol_gen, schedule)?,
            ),
            // The remaining commands up to `ProveExists` carry nothing to
            // resolve and are copied over field by field.
            NCommand::Pop(span, n) => ResolvedNCommand::Pop(span.clone(), *n),
            NCommand::Push(n) => ResolvedNCommand::Push(*n),
            NCommand::AddRuleset(span, ruleset) => {
                ResolvedNCommand::AddRuleset(span.clone(), ruleset.clone())
            }
            NCommand::UnstableCombinedRuleset(span, name, sub_rulesets) => {
                ResolvedNCommand::UnstableCombinedRuleset(
                    span.clone(),
                    name.clone(),
                    sub_rulesets.clone(),
                )
            }
            NCommand::PrintOverallStatistics(span, file) => {
                ResolvedNCommand::PrintOverallStatistics(span.clone(), file.clone())
            }
            NCommand::PrintFunction(span, table, size, file, mode) => {
                ResolvedNCommand::PrintFunction(
                    span.clone(),
                    table.clone(),
                    *size,
                    file.clone(),
                    *mode,
                )
            }
            NCommand::PrintSize(span, n) => {
                // Should probably also resolve the function symbol here
                ResolvedNCommand::PrintSize(span.clone(), n.clone())
            }
            NCommand::ProveExists(span, constructor) => {
                // The named function must exist and be a constructor.
                let func_type = self
                    .type_info
                    .get_func_type(constructor)
                    .ok_or_else(|| TypeError::UnboundFunction(constructor.clone(), span.clone()))?;
                if func_type.subtype != FunctionSubtype::Constructor {
                    return Err(TypeError::ProveExistsRequiresConstructor(
                        constructor.clone(),
                        span.clone(),
                    ));
                }
                ResolvedNCommand::ProveExists(span.clone(), ResolvedCall::Func(func_type.clone()))
            }
            NCommand::Output { span, file, exprs } => {
                let exprs = exprs
                    .iter()
                    .map(|expr| {
                        self.type_info.typecheck_standalone_expr(
                            symbol_gen,
                            expr,
                            &Default::default(),
                            Context::Full,
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                ResolvedNCommand::Output {
                    span: span.clone(),
                    file: file.clone(),
                    exprs,
                }
            }
            NCommand::Input { span, name, file } => ResolvedNCommand::Input {
                span: span.clone(),
                name: name.clone(),
                file: file.clone(),
            },
            NCommand::UserDefined(span, name, exprs) => {
                ResolvedNCommand::UserDefined(span.clone(), name.clone(), exprs.clone())
            }
        };
        // Rules get an extra pass over their variables once resolved.
        if let ResolvedNCommand::NormRule { rule } = &command {
            self.warn_for_prefixed_non_globals_in_rule(rule)?;
        }
        Ok(command)
    }
567
568    fn warn_for_prefixed_non_globals_in_var(
569        &mut self,
570        span: &Span,
571        var: &ResolvedVar,
572    ) -> Result<(), TypeError> {
573        if var.is_global_ref {
574            return Ok(());
575        }
576        if var.name.starts_with(crate::GLOBAL_NAME_PREFIX) {
577            self.warn_prefixed_non_globals(span, &var.name)?;
578        }
579        Ok(())
580    }
581
582    fn warn_for_prefixed_non_globals_in_rule(
583        &mut self,
584        rule: &ResolvedRule,
585    ) -> Result<(), TypeError> {
586        let mut res: Result<(), TypeError> = Ok(());
587
588        for fact in &rule.body {
589            fact.visit_vars(&mut |span, var| {
590                if res.is_ok() {
591                    res = self.warn_for_prefixed_non_globals_in_var(span, var);
592                }
593            });
594        }
595
596        rule.head.visit_vars(&mut |span, var| {
597            if res.is_ok() {
598                res = self.warn_for_prefixed_non_globals_in_var(span, var);
599            }
600        });
601        res
602    }
603}
604
605impl TypeInfo {
606    /// Adds a sort constructor to the typechecker's known set of types.
607    pub fn add_presort<S: Presort>(&mut self, span: Span) -> Result<(), TypeError> {
608        let name = S::presort_name();
609        match self.mksorts.entry(name.to_owned()) {
610            HEntry::Occupied(_) => Err(TypeError::SortAlreadyBound(name.to_owned(), span)),
611            HEntry::Vacant(e) => {
612                e.insert(S::make_sort);
613                self.reserved_primitives.extend(S::reserved_primitives());
614                Ok(())
615            }
616        }
617    }
618
619    /// Returns all sorts that satisfy the type and predicate.
620    pub fn get_sorts_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
621        let mut results = Vec::new();
622        for sort in self.sorts.values() {
623            let sort = sort.clone().as_arc_any();
624            if let Ok(sort) = Arc::downcast(sort)
625                && pred(&sort)
626            {
627                results.push(sort);
628            }
629        }
630        results
631    }
632
633    /// Returns all sorts based on the type.
634    pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
635        self.get_sorts_by(|_| true)
636    }
637
638    /// Returns a sort that satisfies the type and predicate.
639    pub fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
640        let results = self.get_sorts_by(pred);
641        assert_eq!(
642            results.len(),
643            1,
644            "Expected exactly one sort for type {}",
645            std::any::type_name::<S>()
646        );
647        results.into_iter().next().unwrap()
648    }
649
650    /// Returns a sort based on the type.
651    pub fn get_sort<S: Sort>(&self) -> Arc<S> {
652        self.get_sort_by(|_| true)
653    }
654
655    /// Returns all sorts that satisfy the predicate.
656    pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
657        self.sorts.values().filter(|&x| f(x)).cloned().collect()
658    }
659
660    /// Returns a sort based on the predicate.
661    pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
662        let results = self.get_arcsorts_by(f);
663        assert_eq!(
664            results.len(),
665            1,
666            "Expected exactly one sort matching the given predicate"
667        );
668        results.into_iter().next().unwrap()
669    }
670
671    /// Returns the unique sort whose runtime values have Rust type `T`.
672    pub fn get_arcsort_for_value_type<T: 'static>(&self) -> ArcSort {
673        let results = self.get_arcsorts_by(|s| s.value_type() == Some(std::any::TypeId::of::<T>()));
674        assert_eq!(
675            results.len(),
676            1,
677            "Expected exactly one sort for type `{}`",
678            std::any::type_name::<T>()
679        );
680        results.into_iter().next().unwrap()
681    }
682
683    /// Check if a sort allows union operations.
684    /// A sort is unionable if it's an eq_sort and not marked as non-unionable
685    /// (e.g., from `(sort Foo :no-union)` or relation desugaring).
686    pub fn is_sort_unionable(&self, sort: &ArcSort) -> bool {
687        sort.is_eq_sort() && !self.non_unionable_sorts.contains(sort.name())
688    }
689
690    fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
691        let input = func
692            .schema
693            .input
694            .iter()
695            .map(|name| {
696                if let Some(sort) = self.sorts.get(name) {
697                    Ok(sort.clone())
698                } else {
699                    Err(TypeError::UndefinedSort(name.clone(), func.span.clone()))
700                }
701            })
702            .collect::<Result<Vec<_>, _>>()?;
703        let output = if let Some(sort) = self.sorts.get(&func.schema.output) {
704            Ok(sort.clone())
705        } else {
706            Err(TypeError::UndefinedSort(
707                func.schema.output.clone(),
708                func.span.clone(),
709            ))
710        }?;
711
712        Ok(FuncType {
713            name: func.name.clone(),
714            subtype: func.subtype,
715            input,
716            output: output.clone(),
717        })
718    }
719
720    fn typecheck_function(
721        &mut self,
722        symbol_gen: &mut SymbolGen,
723        fdecl: &FunctionDecl,
724    ) -> Result<ResolvedFunctionDecl, TypeError> {
725        if self.sorts.contains_key(&fdecl.name) {
726            return Err(TypeError::SortAlreadyBound(
727                fdecl.name.clone(),
728                fdecl.span.clone(),
729            ));
730        }
731        if self.is_primitive(&fdecl.name) {
732            return Err(TypeError::PrimitiveAlreadyBound(
733                fdecl.name.clone(),
734                fdecl.span.clone(),
735            ));
736        }
737        // View tables (with term_constructor) must have at least one input (the e-class)
738        if fdecl.term_constructor.is_some() && fdecl.schema.input.is_empty() {
739            return Err(TypeError::TermConstructorNoInputs(
740                fdecl.name.clone(),
741                fdecl.span.clone(),
742            ));
743        }
744        let ftype = self.function_to_functype(fdecl)?;
745        if self.func_types.insert(fdecl.name.clone(), ftype).is_some() {
746            return Err(TypeError::FunctionAlreadyBound(
747                fdecl.name.clone(),
748                fdecl.span.clone(),
749            ));
750        }
751        let mut bound_vars = IndexMap::default();
752        let output_type = self.sorts.get(&fdecl.schema.output).unwrap();
753        if fdecl.subtype == FunctionSubtype::Constructor && !output_type.is_eq_sort() {
754            return Err(TypeError::ConstructorOutputNotSort(
755                fdecl.name.clone(),
756                fdecl.span.clone(),
757            ));
758        }
759        bound_vars.insert("old", (fdecl.span.clone(), output_type.clone()));
760        bound_vars.insert("new", (fdecl.span.clone(), output_type.clone()));
761
762        Ok(ResolvedFunctionDecl {
763            name: fdecl.name.clone(),
764            subtype: fdecl.subtype,
765            schema: fdecl.schema.clone(),
766            resolved_schema: ResolvedCall::Func(self.func_types.get(&fdecl.name).unwrap().clone()),
767            merge: match &fdecl.merge {
768                // Merge expressions run as part of action-side table updates:
769                // writes are allowed, but live DB reads would be untracked by
770                // seminaive rule execution.
771                Some(merge) => Some(self.typecheck_standalone_expr(
772                    symbol_gen,
773                    merge,
774                    &bound_vars,
775                    Context::Write,
776                )?),
777                None => None,
778            },
779            cost: fdecl.cost,
780            unextractable: fdecl.unextractable,
781            internal_hidden: fdecl.internal_hidden,
782            internal_let: fdecl.internal_let,
783            span: fdecl.span.clone(),
784            term_constructor: fdecl.term_constructor.clone(),
785        })
786    }
787
788    fn typecheck_schedule(
789        &self,
790        symbol_gen: &mut SymbolGen,
791        schedule: &Schedule,
792    ) -> Result<ResolvedSchedule, TypeError> {
793        let schedule = match schedule {
794            Schedule::Repeat(span, times, schedule) => ResolvedSchedule::Repeat(
795                span.clone(),
796                *times,
797                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
798            ),
799            Schedule::Sequence(span, schedules) => {
800                let schedules = schedules
801                    .iter()
802                    .map(|schedule| self.typecheck_schedule(symbol_gen, schedule))
803                    .collect::<Result<Vec<_>, _>>()?;
804                ResolvedSchedule::Sequence(span.clone(), schedules)
805            }
806            Schedule::Saturate(span, schedule) => ResolvedSchedule::Saturate(
807                span.clone(),
808                Box::new(self.typecheck_schedule(symbol_gen, schedule)?),
809            ),
810            Schedule::Run(span, RunConfig { ruleset, until }) => {
811                let until = until
812                    .as_ref()
813                    .map(|facts| self.typecheck_facts(symbol_gen, facts))
814                    .transpose()?;
815                ResolvedSchedule::Run(
816                    span.clone(),
817                    ResolvedRunConfig {
818                        ruleset: ruleset.clone(),
819                        until,
820                    },
821                )
822            }
823        };
824
825        Result::Ok(schedule)
826    }
827
828    fn typecheck_rule(
829        &self,
830        symbol_gen: &mut SymbolGen,
831        rule: &Rule,
832        global_seminaive: bool,
833    ) -> Result<ResolvedRule, TypeError> {
834        let Rule {
835            span,
836            head,
837            body,
838            name,
839            ruleset,
840            eval_mode,
841            no_decomp,
842            include_subsumed,
843        } = rule;
844        let mut constraints = vec![];
845
846        // Compile with the permissive Read/Full primitive contexts (so the RHS
847        // can read the database) when the whole EGraph is non-seminaive, or the
848        // rule's own mode requires it (`:naive` / `:unsafe-seminaive`).
849        let read_contexts = !global_seminaive
850            || matches!(
851                eval_mode,
852                RuleEvalMode::Naive | RuleEvalMode::UnsafeSeminaive
853            );
854        let (query_ctx, action_ctx) = if read_contexts {
855            (Context::Read, Context::Full)
856        } else {
857            (Context::Pure, Context::Write)
858        };
859
860        let (query, mapped_query) = Facts(body.clone()).to_query(self, symbol_gen);
861        constraints.extend(query.get_constraints(self, query_ctx)?);
862
863        let mut binding = query.get_vars();
864        // We lower to core actions with `union_to_set_optimization`
865        // later in the pipeline. For typechecking we do not need it.
866        let mut ctx = CoreActionContext::new(self, &mut binding, symbol_gen, false);
867        let (actions, mapped_action) = head.to_core_actions(&mut ctx)?;
868
869        let mut problem = Problem::default();
870        problem.add_rule(
871            &CoreRule {
872                span: span.clone(),
873                body: query,
874                head: actions,
875            },
876            self,
877            symbol_gen,
878            query_ctx,
879            action_ctx,
880        )?;
881
882        let assignment = problem
883            .solve(|sort: &ArcSort| sort.name())
884            .map_err(|e| e.to_type_error())?;
885
886        let body: Vec<ResolvedFact> = assignment.annotate_facts(&mapped_query, self, query_ctx);
887        let actions: ResolvedActions =
888            assignment.annotate_actions(&mapped_action, self, action_ctx)?;
889
890        // Function lookups in actions need the `Full` action context; the
891        // `Write` context (`!read_contexts`) can't express them.
892        if !read_contexts {
893            self.check_no_function_lookups_in_actions(&actions)?;
894        }
895
896        Ok(ResolvedRule {
897            span: span.clone(),
898            body,
899            head: actions,
900            name: name.clone(),
901            ruleset: ruleset.clone(),
902            eval_mode: *eval_mode,
903            no_decomp: *no_decomp,
904            include_subsumed: *include_subsumed,
905        })
906    }
907
908    fn check_lookup_expr(&self, expr: &ResolvedExpr) -> Result<(), TypeError> {
909        if let Some(span) = self.expr_has_function_lookup(expr) {
910            return Err(TypeError::LookupInRuleDisallowed(
911                "function".to_string(),
912                span,
913            ));
914        }
915        Ok(())
916    }
917
918    fn check_no_function_lookups_in_actions(
919        &self,
920        actions: &ResolvedActions,
921    ) -> Result<(), TypeError> {
922        for action in actions.iter() {
923            match action {
924                GenericAction::Let(_, _, rhs) => self.check_lookup_expr(rhs)?,
925                GenericAction::Set(_, _, args, rhs) => {
926                    for arg in args.iter() {
927                        self.check_lookup_expr(arg)?;
928                    }
929                    self.check_lookup_expr(rhs)?;
930                }
931                GenericAction::Union(_, lhs, rhs) => {
932                    self.check_lookup_expr(lhs)?;
933                    self.check_lookup_expr(rhs)?;
934                }
935                GenericAction::Change(_, _, _, args) => {
936                    for arg in args.iter() {
937                        self.check_lookup_expr(arg)?;
938                    }
939                }
940                GenericAction::Panic(..) => {}
941                GenericAction::Expr(_, expr) => self.check_lookup_expr(expr)?,
942            }
943        }
944        Ok(())
945    }
946
947    pub fn typecheck_facts(
948        &self,
949        symbol_gen: &mut SymbolGen,
950        facts: &[Fact],
951    ) -> Result<Vec<ResolvedFact>, TypeError> {
952        let (query, mapped_facts) = Facts(facts.to_vec()).to_query(self, symbol_gen);
953        let mut problem = Problem::default();
954        // Top-level query-shaped commands (e.g. `check`) are read-only:
955        // primitives may inspect the database but not write to it.
956        problem.add_query(&query, self, Context::Read)?;
957        let assignment = problem
958            .solve(|sort: &ArcSort| sort.name())
959            .map_err(|e| e.to_type_error())?;
960        let annotated_facts = assignment.annotate_facts(&mapped_facts, self, Context::Read);
961        Ok(annotated_facts)
962    }
963
964    // Standalone expressions/actions use action lowering. Top-level commands
965    // pass `Full`; function `:merge` reuses this path with `Write` because
966    // merge expressions run during table updates.
967    fn typecheck_standalone_actions(
968        &self,
969        symbol_gen: &mut SymbolGen,
970        actions: &Actions,
971        binding: &IndexMap<&str, (Span, ArcSort)>,
972        context: Context,
973    ) -> Result<ResolvedActions, TypeError> {
974        let mut binding_set: IndexSet<String> =
975            binding.keys().copied().map(str::to_string).collect();
976        // We lower to core actions with `union_to_set_optimization`
977        // later in the pipeline. For typechecking we do not need it.
978        let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
979        let (actions, mapped_action) = actions.to_core_actions(&mut ctx)?;
980        let mut problem = Problem::default();
981
982        problem.add_actions(&actions, self, symbol_gen, context)?;
983
984        // add bindings from the context
985        for (var, (span, sort)) in binding {
986            problem.assign_local_var_type(var, span.clone(), sort.clone())?;
987        }
988
989        let assignment = problem
990            .solve(|sort: &ArcSort| sort.name())
991            .map_err(|e| e.to_type_error())?;
992
993        let annotated_actions = assignment.annotate_actions(&mapped_action, self, context)?;
994        Ok(annotated_actions)
995    }
996
997    fn typecheck_standalone_expr(
998        &self,
999        symbol_gen: &mut SymbolGen,
1000        expr: &Expr,
1001        binding: &IndexMap<&str, (Span, ArcSort)>,
1002        context: Context,
1003    ) -> Result<ResolvedExpr, TypeError> {
1004        let action = Action::Expr(expr.span(), expr.clone());
1005        let typechecked_action =
1006            self.typecheck_standalone_action(symbol_gen, &action, binding, context)?;
1007        match typechecked_action {
1008            ResolvedAction::Expr(_, expr) => Ok(expr),
1009            _ => unreachable!(),
1010        }
1011    }
1012
1013    pub(crate) fn typecheck_expr_with_output(
1014        &self,
1015        symbol_gen: &mut SymbolGen,
1016        expr: &Expr,
1017        binding: &IndexMap<&str, (Span, ArcSort)>,
1018        output_sort: ArcSort,
1019        context: Context,
1020    ) -> Result<ResolvedExpr, TypeError> {
1021        let action = Action::Expr(expr.span(), expr.clone());
1022        let mut binding_set: IndexSet<String> =
1023            binding.keys().copied().map(str::to_string).collect();
1024        let mut ctx = CoreActionContext::new(self, &mut binding_set, symbol_gen, false);
1025        let (actions, mapped_action) = Actions::singleton(action).to_core_actions(&mut ctx)?;
1026        let mut problem = Problem::default();
1027
1028        problem.add_actions(&actions, self, symbol_gen, context)?;
1029
1030        for (var, (span, sort)) in binding {
1031            problem.assign_local_var_type(var, span.clone(), sort.clone())?;
1032        }
1033
1034        let [GenericAction::Expr(_, mapped_expr)] = mapped_action.0.as_slice() else {
1035            unreachable!("typechecking an expression should produce one expression action")
1036        };
1037        let output_atom = mapped_expr.get_corresponding_var_or_lit(self);
1038        problem.add_binding(output_atom, output_sort.clone());
1039
1040        let assignment = problem
1041            .solve(|sort: &ArcSort| sort.name())
1042            .map_err(|e| e.to_type_error())?;
1043
1044        let annotated_actions = assignment.annotate_actions(&mapped_action, self, context)?;
1045        match annotated_actions.0.into_iter().next().unwrap() {
1046            ResolvedAction::Expr(_, resolved_expr) => {
1047                let actual = resolved_expr.output_type();
1048                if actual.name() != output_sort.name() {
1049                    return Err(TypeError::Mismatch {
1050                        expr: expr.clone(),
1051                        expected: output_sort,
1052                        actual,
1053                    });
1054                }
1055                Ok(resolved_expr)
1056            }
1057            _ => unreachable!(),
1058        }
1059    }
1060
1061    fn typecheck_standalone_action(
1062        &self,
1063        symbol_gen: &mut SymbolGen,
1064        action: &Action,
1065        binding: &IndexMap<&str, (Span, ArcSort)>,
1066        context: Context,
1067    ) -> Result<ResolvedAction, TypeError> {
1068        self.typecheck_standalone_actions(
1069            symbol_gen,
1070            &Actions::singleton(action.clone()),
1071            binding,
1072            context,
1073        )
1074        .map(|v| {
1075            assert_eq!(v.len(), 1);
1076            v.0.into_iter().next().unwrap()
1077        })
1078    }
1079
    /// Looks up `sym` in the `sorts` table and returns the sort stored under
    /// that name, or `None` if there is no entry.
    pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
        self.sorts.get(sym)
    }
1083
1084    pub fn get_prims(&self, sym: &str) -> Option<&[PrimitiveWithId]> {
1085        self.primitives.get(sym).map(Vec::as_slice)
1086    }
1087
1088    pub fn is_primitive(&self, sym: &str) -> bool {
1089        self.primitives.contains_key(sym) || self.reserved_primitives.contains(sym)
1090    }
1091
1092    pub fn primitive_has_validator(&self, id: ExternalFunctionId) -> bool {
1093        self.primitives
1094            .values()
1095            .flat_map(|v| v.iter())
1096            .any(|p| p.context_ids.iter().any(|(_, pid)| *pid == Some(id)) && p.validator.is_some())
1097    }
1098
    /// Looks up `sym` in the `func_types` table and returns the function
    /// type stored under that name, or `None` if there is no entry.
    pub fn get_func_type(&self, sym: &str) -> Option<&FuncType> {
        self.func_types.get(sym)
    }
1102
1103    pub fn is_constructor(&self, sym: &str) -> bool {
1104        self.func_types
1105            .get(sym)
1106            .is_some_and(|f| f.subtype == FunctionSubtype::Constructor)
1107    }
1108
    /// Looks up `sym` in the `global_sorts` table and returns the sort
    /// stored under that name, or `None` if there is no entry.
    pub fn get_global_sort(&self, sym: &str) -> Option<&ArcSort> {
        self.global_sorts.get(sym)
    }
1112
    /// Returns `true` if `sym` has an entry in the `global_sorts` table.
    pub fn is_global(&self, sym: &str) -> bool {
        self.global_sorts.contains_key(sym)
    }
1116
1117    /// Check if an expression contains non-global function lookups (FunctionSubtype::Custom calls).
1118    /// Global function calls are allowed since they get desugared to constructors.
1119    /// Returns Some(span) if a lookup is found, None otherwise.
1120    pub fn expr_has_function_lookup(&self, expr: &ResolvedExpr) -> Option<Span> {
1121        use ast::GenericExpr;
1122
1123        expr.find(&mut |e| {
1124            if let GenericExpr::Call(span, ResolvedCall::Func(func_type), _) = e
1125                && func_type.subtype == FunctionSubtype::Custom
1126                && !self.is_global(&func_type.name)
1127            {
1128                return Some(span.clone());
1129            }
1130            None
1131        })
1132    }
1133}
1134
1135#[derive(Debug, Clone, Error)]
1136pub enum TypeError {
1137    #[error("{}\nArity mismatch, expected {expected} args: {expr}", .expr.span())]
1138    Arity { expr: Expr, expected: usize },
1139    #[error(
1140        "{}\n Expect expression {expr} to have type {}, but get type {}",
1141        .expr.span(), .expected.name(), .actual.name(),
1142    )]
1143    Mismatch {
1144        expr: Expr,
1145        expected: ArcSort,
1146        actual: ArcSort,
1147    },
1148    #[error("{1}\nUnbound symbol {0}")]
1149    Unbound(String, Span),
1150    #[error(
1151        "{1}\nVariable {0} is ungrounded. A variable is grounded when it appears as an argument to a constructor or function in the query, not just under primitives or equalities."
1152    )]
1153    Ungrounded(String, Span),
1154    #[error("{1}\nUndefined sort {0}")]
1155    UndefinedSort(String, Span),
1156    #[error("{1}\nUnbound function {0}")]
1157    UnboundFunction(String, Span),
1158    #[error("{1}\nprove-exists requires constructor function, but {0} is not a constructor")]
1159    ProveExistsRequiresConstructor(String, Span),
1160    #[error("{1}\nFunction already bound {0}")]
1161    FunctionAlreadyBound(String, Span),
1162    #[error("{1}\nSort {0} already declared.")]
1163    SortAlreadyBound(String, Span),
1164    #[error("{1}\nPrimitive {0} already declared.")]
1165    PrimitiveAlreadyBound(String, Span),
1166    #[error("Function type mismatch: expected {} => {}, actual {} => {}", .1.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .0.name(), .3.iter().map(|s| s.name().to_string()).collect::<Vec<_>>().join(", "), .2.name())]
1167    FunctionTypeMismatch(ArcSort, Vec<ArcSort>, ArcSort, Vec<ArcSort>),
1168    #[error("{1}\nPresort {0} not found.")]
1169    PresortNotFound(String, Span),
1170    #[error("{}\nFailed to infer a type for: {}", .0.span(), .0)]
1171    InferenceFailure(Expr),
1172    #[error("{1}\nVariable {0} was already defined")]
1173    AlreadyDefined(String, Span),
1174    #[error("{1}\nThe output type of constructor function {0} must be sort")]
1175    ConstructorOutputNotSort(String, Span),
1176    #[error("{1}\nValue lookup of non-constructor function {0} in rule is disallowed.")]
1177    LookupInRuleDisallowed(String, Span),
1178    #[error("{1}\nCannot set constructor {0}. Use `union` instead or declare {0} as a function.")]
1179    SetConstructorDisallowed(String, Span),
1180    #[error("All alternative definitions considered failed\n{}", .0.iter().map(|e| format!("  {e}\n")).collect::<Vec<_>>().join(""))]
1181    AllAlternativeFailed(Vec<TypeError>),
1182    #[error("{}\nCannot union values of sort {}", .1, .0.name())]
1183    NonEqsortUnion(ArcSort, Span),
1184    #[error("{}\nCannot union values of sort {} because it is marked as non-unionable (e.g. from a relation)", .1, .0.name())]
1185    NonUnionableSort(ArcSort, Span),
1186    #[error(
1187        "{1}\nView table {0} with :internal-term-constructor must have at least one input (the e-class)."
1188    )]
1189    TermConstructorNoInputs(String, Span),
1190    #[error(
1191        "{span}\nNon-global variable `{name}` must not start with `{}`.",
1192        crate::GLOBAL_NAME_PREFIX
1193    )]
1194    NonGlobalPrefixed { name: String, span: Span },
1195    #[error(
1196        "{span}\nGlobal `{name}` must start with `{}`.",
1197        crate::GLOBAL_NAME_PREFIX
1198    )]
1199    GlobalMissingPrefix { name: String, span: Span },
1200}
1201
#[cfg(test)]
mod test {
    use crate::{EGraph, Error, typechecking::TypeError};

    /// Calling a two-column relation with three arguments must be reported
    /// as an arity error pointing at the offending call.
    #[test]
    fn test_arity_mismatch() {
        let prog = "
            (relation f (i64 i64))
            (rule ((f a b c)) ())
       ";

        let mut egraph = EGraph::default();
        let res = egraph.parse_and_run_program(None, prog);

        let Err(Error::TypeError(TypeError::Arity { expected: 2, expr })) = &res else {
            panic!("Expected arity mismatch, got: {res:?}");
        };
        assert_eq!(expr.span().string(), "(f a b c)");
    }
}