egglog/
constraint.rs

1use crate::{
2    core::{
3        Atom, CoreAction, CoreRule, GenericCoreActions, GenericCoreRule, HeadOrEq, Query,
4        StringOrEq,
5    },
6    *,
7};
8use std::{cmp, rc::Rc};
// Use an immutable hashmap (`im_rc::HashMap`, imported below) for performance:
// cloning assignments is common and O(1) with an immutable hashmap.
11use egglog_ast::generic_ast::{GenericAction, GenericActions, GenericExpr, GenericFact};
12use egglog_ast::span::Span;
13use im_rc::HashMap;
14use std::{fmt::Debug, iter::once, mem::swap};
15
/// Represents constraints that are logically impossible to satisfy.
/// These are used to signal type errors during constraint solving.
///
/// Wrapped by [`impossible`], whose constraint fails on every update with
/// [`ConstraintError::ImpossibleCaseIdentified`].
#[derive(Clone, Debug)]
pub enum ImpossibleConstraint {
    /// An atom was given a different number of arguments than its head requires.
    ArityMismatch {
        /// The atom as written.
        atom: Atom<String>,
        /// The expected arity for this atom, counting its trailing output argument
        /// (the function-atom case builds it as `inputs + 1`).
        expected: usize,
    },
    /// An expected function signature (output sort and input sorts) differs from the
    /// actual one. `to_type_error` turns this into `TypeError::FunctionTypeMismatch`.
    FunctionMismatch {
        expected_output: ArcSort,
        expected_input: Vec<ArcSort>,
        actual_output: ArcSort,
        actual_input: Vec<ArcSort>,
    },
}
32
/// A constraint that can be applied to variable assignments.
/// Constraints are used in type inference to represent relationships between variables and values.
///
/// Constraints are handled as `Box<dyn Constraint<_, _>>` and are cloneable through
/// `dyn_clone`. This is what lets combinators such as `Xor` snapshot their sub-constraints
/// and restore them.
pub trait Constraint<Var, Value>: dyn_clone::DynClone {
    /// Updates the assignment based on this constraint.
    ///
    /// Returns `Ok(true)` if the assignment was modified, `Ok(false)` if no changes were made,
    /// or `Err` if the constraint cannot be satisfied.
    ///
    /// `update` is allowed to modify the constraint itself, e.g. to convert a delayed constraint
    /// into an immediate one.
    ///
    /// The `key` function maps a value to a string. The built-in constraints in this module
    /// treat two values as equal exactly when their keys are equal. The combinators `And` and
    /// `Xor` restore the assignment before returning `Err`.
    fn update(
        &mut self,
        assignment: &mut Assignment<Var, Value>,
        key: fn(&Value) -> &str,
    ) -> Result<bool, ConstraintError<Var, Value>>;

    /// Returns a human-readable string representation of this constraint.
    fn pretty(&self) -> String;
}

// Implements `Clone` for `Box<dyn Constraint<Var, Value>>` via `dyn_clone`.
dyn_clone::clone_trait_object!(<Var, Value> Constraint<Var, Value>);
53
54/// Creates an equality constraint between two variables.
55/// If one of the variable has a known value, the constraint propagates value to the other variable.
56/// If both variables have known but different values, the constraint fails.
57pub fn eq<Var, Value>(x: Var, y: Var) -> Box<dyn Constraint<Var, Value>>
58where
59    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
60    Value: Clone + Debug + 'static,
61{
62    Box::new(Eq(x, y))
63}
64
65/// Creates an assignment constraint that binds a variable to a specific value.
66/// The constraint fails if the variable is already assigned to a different value.
67pub fn assign<Var, Value>(x: Var, v: Value) -> Box<dyn Constraint<Var, Value>>
68where
69    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
70    Value: Clone + Debug + 'static,
71{
72    Box::new(Assign(x, v))
73}
74
75/// Creates a conjunction constraint that requires all sub-constraints to be satisfied.
76pub fn and<Var, Value>(cs: Vec<Box<dyn Constraint<Var, Value>>>) -> Box<dyn Constraint<Var, Value>>
77where
78    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
79    Value: Clone + Debug + 'static,
80{
81    Box::new(And(cs))
82}
83
84/// Creates an exclusive-or constraint that requires exactly one sub-constraint to be satisfied.
85/// The constraint proceeds if exactly one sub-constraint can be satisfied and all others lead to failure.
86/// The constraint fails if zero sub-constraints can be satisfied.
87pub fn xor<Var, Value>(cs: Vec<Box<dyn Constraint<Var, Value>>>) -> Box<dyn Constraint<Var, Value>>
88where
89    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
90    Value: Clone + Debug + 'static,
91{
92    Box::new(Xor(cs))
93}
94
95/// Creates a constraint that always fails with the given impossible constraint.
96/// This is used to signal type errors during constraint solving.
97pub fn impossible<Var, Value>(constraint: ImpossibleConstraint) -> Box<dyn Constraint<Var, Value>>
98where
99    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
100    Value: Clone + Debug + 'static,
101{
102    Box::new(Impossible { constraint })
103}
104
105/// Creates an implication constraint that activates when all watch variables are assigned.
106/// The constraint function is called with the values of the watch variables to generate the actual constraint.
107pub fn implies<Var, Value>(
108    name: String,
109    watch_vars: Vec<Var>,
110    constraint: DelayedConstraintFn<Var, Value>,
111) -> Box<dyn Constraint<Var, Value>>
112where
113    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
114    Value: Clone + Debug + 'static,
115{
116    Box::new(Implies {
117        name,
118        watch_vars,
119        constraint: DelayedConstraint::Delayed(constraint),
120    })
121}
122
/// Builds the actual constraint of an [`implies`] once its watch variables are known.
/// The closure receives the watch variables' values in the same order as the watch variables.
pub type DelayedConstraintFn<Var, Value> = Rc<dyn Fn(&[&Value]) -> Box<dyn Constraint<Var, Value>>>;

/// State of an `Implies` constraint: still waiting on its watch variables, or already
/// materialized into a concrete constraint.
#[derive(Clone)]
enum DelayedConstraint<Var, Value> {
    /// Not every watch variable has a value yet; holds the builder to call later.
    Delayed(DelayedConstraintFn<Var, Value>),
    /// The constraint produced by the builder.
    Constraint(Box<dyn Constraint<Var, Value>>),
}
130
131#[derive(Clone)]
132struct Implies<Var, Value> {
133    name: String,
134    watch_vars: Vec<Var>,
135    constraint: DelayedConstraint<Var, Value>,
136}
137
138impl<Var, Value> Constraint<Var, Value> for Implies<Var, Value>
139where
140    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
141    Value: Clone + Debug,
142{
143    fn update(
144        &mut self,
145        assignment: &mut Assignment<Var, Value>,
146        key: fn(&Value) -> &str,
147    ) -> Result<bool, ConstraintError<Var, Value>> {
148        let mut updated = false;
149        // If the constraint is delayed, either make it immediate or return.
150        if let DelayedConstraint::Delayed(delayed) = &self.constraint {
151            let watch_vals: Option<Vec<&Value>> =
152                self.watch_vars.iter().map(|v| assignment.get(v)).collect();
153            let Some(watch_vals) = watch_vals else {
154                return Ok(false);
155            };
156            let constraint = delayed(&watch_vals);
157            self.constraint = DelayedConstraint::Constraint(constraint);
158            updated = true;
159        };
160
161        // The constraint must be immediate now.
162        let DelayedConstraint::Constraint(constraint) = &mut self.constraint else {
163            unreachable!("update");
164        };
165        updated |= constraint.update(assignment, key)?;
166        Ok(updated)
167    }
168
169    fn pretty(&self) -> String {
170        let vars: String = self
171            .watch_vars
172            .iter()
173            .map(|v| format!("{:?}", v))
174            .collect::<Vec<_>>()
175            .join(", ");
176        format!("{} => {}({})", vars, self.name, vars)
177    }
178}
179
180#[derive(Clone)]
181struct Eq<Var>(Var, Var);
182
183impl<Var, Value> Constraint<Var, Value> for Eq<Var>
184where
185    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
186    Value: Clone + Debug,
187{
188    fn update(
189        &mut self,
190        assignment: &mut Assignment<Var, Value>,
191        key: fn(&Value) -> &str,
192    ) -> Result<bool, ConstraintError<Var, Value>> {
193        match (assignment.0.get(&self.0), assignment.0.get(&self.1)) {
194            (Some(value), None) => {
195                assignment.insert(self.1.clone(), value.clone());
196                Ok(true)
197            }
198            (None, Some(value)) => {
199                assignment.insert(self.0.clone(), value.clone());
200                Ok(true)
201            }
202            (Some(v1), Some(v2)) => {
203                if key(v1) == key(v2) {
204                    Ok(false)
205                } else {
206                    Err(ConstraintError::InconsistentConstraint(
207                        self.0.clone(),
208                        v1.clone(),
209                        v2.clone(),
210                    ))
211                }
212            }
213            (None, None) => Ok(false),
214        }
215    }
216
217    fn pretty(&self) -> String {
218        format!("{:?} = {:?}", self.0, self.1)
219    }
220}
221
222#[derive(Clone)]
223struct Assign<Var, Value>(Var, Value);
224
225impl<Var, Value> Constraint<Var, Value> for Assign<Var, Value>
226where
227    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
228    Value: Clone + Debug,
229{
230    fn update(
231        &mut self,
232        assignment: &mut Assignment<Var, Value>,
233        key: fn(&Value) -> &str,
234    ) -> Result<bool, ConstraintError<Var, Value>> {
235        match assignment.0.get(&self.0) {
236            None => {
237                assignment.insert(self.0.clone(), self.1.clone());
238                Ok(true)
239            }
240            Some(value) => {
241                if key(value) == key(&self.1) {
242                    Ok(false)
243                } else {
244                    Err(ConstraintError::InconsistentConstraint(
245                        self.0.clone(),
246                        self.1.clone(),
247                        value.clone(),
248                    ))
249                }
250            }
251        }
252    }
253
254    fn pretty(&self) -> String {
255        format!("{:?} = {:?}", self.0, self.1)
256    }
257}
258
259#[derive(Clone)]
260struct And<Var, Value>(Vec<Box<dyn Constraint<Var, Value>>>);
261
262impl<Var, Value> Constraint<Var, Value> for And<Var, Value>
263where
264    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
265    Value: Clone + Debug,
266{
267    fn update(
268        &mut self,
269        assignment: &mut Assignment<Var, Value>,
270        key: fn(&Value) -> &str,
271    ) -> Result<bool, ConstraintError<Var, Value>> {
272        let orig_assignment = assignment.clone();
273        let mut updated = false;
274        for c in self.0.iter_mut() {
275            match c.update(assignment, key) {
276                Ok(upd) => updated |= upd,
277                Err(error) => {
278                    // In the case of failure,
279                    // we need to restore the assignment
280                    *assignment = orig_assignment;
281                    return Err(error);
282                }
283            }
284        }
285        Ok(updated)
286    }
287
288    fn pretty(&self) -> String {
289        format!(
290            "({})",
291            self.0
292                .iter()
293                .map(|c| c.pretty())
294                .collect::<Vec<_>>()
295                .join(" /\\ ")
296        )
297    }
298}
299
/// Exclusive choice among alternative constraints (built by [`xor`]).
#[derive(Clone)]
struct Xor<Var, Value>(Vec<Box<dyn Constraint<Var, Value>>>);

impl<Var, Value> Constraint<Var, Value> for Xor<Var, Value>
where
    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
    Value: Clone + Debug,
{
    /// Tries every alternative and commits only if exactly one of them succeeds.
    ///
    /// - Exactly one success: the other alternatives are pruned, the winner's assignment is
    ///   kept, and whether it changed anything is returned.
    /// - Several successes: both the alternatives and the assignment are restored and
    ///   `Ok(false)` is returned, leaving the choice for a later update.
    /// - No success: alternatives and assignment are restored, and every alternative's
    ///   error is returned in [`ConstraintError::NoConstraintSatisfied`].
    fn update(
        &mut self,
        assignment: &mut Assignment<Var, Value>,
        key: fn(&Value) -> &str,
    ) -> Result<bool, ConstraintError<Var, Value>> {
        let mut success_count = 0;
        // Snapshots used to roll back when no alternative, or more than one, succeeds.
        let orig_assignment = assignment.clone();
        let orig_cs = self.0.clone();
        // Receives the assignment produced by the first successful alternative.
        let mut result_assignment = assignment.clone();
        let mut assignment_updated = false;
        let mut errors = vec![];
        let mut result_constraint = None;

        // Move the alternatives out so each one can be kept or dropped individually.
        let cs = std::mem::take(&mut self.0);
        for mut c in cs {
            let result = c.update(assignment, key);
            match result {
                Ok(updated) => {
                    success_count += 1;
                    // A second success already makes the choice ambiguous; stop early.
                    if success_count > 1 {
                        break;
                    }

                    result_constraint = Some(c);
                    if updated {
                        // Keep this alternative's result in `result_assignment` and put the
                        // pre-loop clone back into `assignment`, so later alternatives do
                        // not start from this alternative's updates.
                        swap(&mut result_assignment, assignment);
                    }
                    assignment_updated = updated;
                }
                Err(error) => errors.push(error),
            }
        }

        // Success roughly means "the constraint is compatible with the current assignment".
        //
        // If update is successful for only one sub constraint, then we have nailed down the only true constraint.
        // If update is successful for more than one constraint, then Xor succeeds with no updates.
        // If update fails for every constraint, then Xor fails
        match success_count.cmp(&1) {
            std::cmp::Ordering::Equal => {
                // Prune all other constraints. This is sound since the constraints are monotonic.
                self.0 = vec![result_constraint.unwrap()];
                *assignment = result_assignment;
                Ok(assignment_updated)
            }
            std::cmp::Ordering::Greater => {
                self.0 = orig_cs;
                *assignment = orig_assignment;
                Ok(false)
            }
            std::cmp::Ordering::Less => {
                self.0 = orig_cs;
                *assignment = orig_assignment;
                Err(ConstraintError::NoConstraintSatisfied(errors))
            }
        }
    }

    fn pretty(&self) -> String {
        format!(
            "({})",
            self.0
                .iter()
                .map(|c| c.pretty())
                .collect::<Vec<_>>()
                .join(" \\/ ")
        )
    }
}
377
378#[derive(Clone)]
379struct Impossible {
380    constraint: ImpossibleConstraint,
381}
382
383impl<Var, Value> Constraint<Var, Value> for Impossible
384where
385    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
386    Value: Clone + Debug,
387{
388    fn update(
389        &mut self,
390        _assignment: &mut Assignment<Var, Value>,
391        _key: fn(&Value) -> &str,
392    ) -> Result<bool, ConstraintError<Var, Value>> {
393        Err(ConstraintError::ImpossibleCaseIdentified(
394            self.constraint.clone(),
395        ))
396    }
397
398    fn pretty(&self) -> String {
399        format!("{:?}", self.constraint)
400    }
401}
402
/// Errors that can occur during constraint solving.
/// These represent various ways that constraint satisfaction can fail.
#[derive(Debug)]
pub enum ConstraintError<Var, Value> {
    /// A variable was assigned two different, incompatible values.
    ///
    /// The fields are the variable followed by the two conflicting values. `to_type_error`
    /// reports the two values as the expected and the actual value, in that order.
    InconsistentConstraint(Var, Value, Value),
    /// A variable in the constraint range was not assigned any value
    UnconstrainedVar(Var),
    /// None of the alternative constraints in an XOR constraint could be satisfied.
    /// Holds the error produced by each alternative.
    NoConstraintSatisfied(Vec<ConstraintError<Var, Value>>),
    /// An impossible constraint was encountered during solving
    ImpossibleCaseIdentified(ImpossibleConstraint),
}
416
417impl ConstraintError<AtomTerm, ArcSort> {
418    /// Converts a [`ConstraintError`] produced by type checking into a type error.
419    pub fn to_type_error(&self) -> TypeError {
420        match &self {
421            ConstraintError::InconsistentConstraint(x, v1, v2) => TypeError::Mismatch {
422                expr: x.to_expr(),
423                expected: v1.clone(),
424                actual: v2.clone(),
425            },
426            ConstraintError::UnconstrainedVar(v) => TypeError::InferenceFailure(v.to_expr()),
427            ConstraintError::NoConstraintSatisfied(constraints) => TypeError::AllAlternativeFailed(
428                constraints.iter().map(|c| c.to_type_error()).collect(),
429            ),
430            ConstraintError::ImpossibleCaseIdentified(ImpossibleConstraint::ArityMismatch {
431                atom,
432                expected,
433            }) => TypeError::Arity {
434                expr: atom.to_expr(),
435                expected: *expected - 1,
436            },
437            ConstraintError::ImpossibleCaseIdentified(ImpossibleConstraint::FunctionMismatch {
438                expected_output,
439                expected_input,
440                actual_output,
441                actual_input,
442            }) => TypeError::FunctionTypeMismatch(
443                expected_output.clone(),
444                expected_input.clone(),
445                actual_output.clone(),
446                actual_input.clone(),
447            ),
448        }
449    }
450}
451
/// A constraint satisfaction problem consisting of constraints and a range of variables to solve for.
/// The problem is considered solved when *all* variables in the range are assigned.
pub struct Problem<Var, Value> {
    /// The list of constraints that must be satisfied.
    /// `solve` updates them repeatedly, in this order, until none of them changes the assignment.
    pub constraints: Vec<Box<dyn Constraint<Var, Value>>>,
    /// The set of variables that must be assigned a value for the problem to be considered solved
    pub range: HashSet<Var>,
}
460
461impl Debug for Problem<AtomTerm, ArcSort> {
462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463        f.debug_struct("Problem")
464            .field(
465                "constraints",
466                &self
467                    .constraints
468                    .iter()
469                    .map(|c| c.pretty())
470                    .collect::<Vec<_>>(),
471            )
472            .field("range", &self.range)
473            .finish()
474    }
475}
476
477impl<Var, Value> Default for Problem<Var, Value> {
478    fn default() -> Self {
479        Self {
480            constraints: vec![],
481            range: HashSet::default(),
482        }
483    }
484}
485
/// A mapping from variables to their assigned values.
/// This is the result of constraint solving.
/// Uses an immutable HashMap for efficient cloning during constraint solving: `And` and `Xor`
/// clone the assignment before trying sub-constraints so they can roll back on failure.
#[derive(Clone)]
pub struct Assignment<Var, Value>(pub HashMap<Var, Value>);
491
492impl<Var, Value> Assignment<Var, Value>
493where
494    Var: Hash + cmp::Eq + PartialEq + Clone,
495    Value: Clone,
496{
497    /// Insert into the assignment.
498    pub fn insert(&mut self, var: Var, value: Value) -> Option<Value> {
499        self.0.insert(var, value)
500    }
501
502    /// Get the value from the assignment.
503    pub fn get(&self, var: &Var) -> Option<&Value> {
504        self.0.get(var)
505    }
506}
507
508impl Assignment<AtomTerm, ArcSort> {
509    pub(crate) fn annotate_expr(
510        &self,
511        expr: &GenericExpr<CorrespondingVar<String, String>, String>,
512        typeinfo: &TypeInfo,
513    ) -> ResolvedExpr {
514        match &expr {
515            GenericExpr::Lit(span, literal) => ResolvedExpr::Lit(span.clone(), literal.clone()),
516            GenericExpr::Var(span, var) => {
517                let global_sort = typeinfo.get_global_sort(var);
518                let ty = global_sort
519                    // Span is ignored when looking up atom_terms
520                    .or_else(|| self.get(&AtomTerm::Var(Span::Panic, var.clone())))
521                    .expect("All variables should be assigned before annotation");
522                ResolvedExpr::Var(
523                    span.clone(),
524                    ResolvedVar {
525                        name: var.clone(),
526                        sort: ty.clone(),
527                        is_global_ref: global_sort.is_some(),
528                    },
529                )
530            }
531            GenericExpr::Call(
532                span,
533                CorrespondingVar {
534                    head,
535                    to: corresponding_var,
536                },
537                args,
538            ) => {
539                // get the resolved call using resolve_rule
540                let args: Vec<_> = args
541                    .iter()
542                    .map(|arg| self.annotate_expr(arg, typeinfo))
543                    .collect();
544                let types: Vec<_> = args
545                    .iter()
546                    .map(|arg| arg.output_type())
547                    .chain(once(
548                        self.get(&AtomTerm::Var(span.clone(), corresponding_var.clone()))
549                            .unwrap()
550                            .clone(),
551                    ))
552                    .collect();
553                let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
554                GenericExpr::Call(span.clone(), resolved_call, args)
555            }
556        }
557    }
558
559    pub(crate) fn annotate_fact(
560        &self,
561        facts: &GenericFact<CorrespondingVar<String, String>, String>,
562        typeinfo: &TypeInfo,
563    ) -> ResolvedFact {
564        match facts {
565            GenericFact::Eq(span, e1, e2) => ResolvedFact::Eq(
566                span.clone(),
567                self.annotate_expr(e1, typeinfo),
568                self.annotate_expr(e2, typeinfo),
569            ),
570            GenericFact::Fact(expr) => ResolvedFact::Fact(self.annotate_expr(expr, typeinfo)),
571        }
572    }
573
574    pub(crate) fn annotate_facts(
575        &self,
576        mapped_facts: &[GenericFact<CorrespondingVar<String, String>, String>],
577        typeinfo: &TypeInfo,
578    ) -> Vec<ResolvedFact> {
579        mapped_facts
580            .iter()
581            .map(|fact| self.annotate_fact(fact, typeinfo))
582            .collect()
583    }
584
585    pub(crate) fn annotate_action(
586        &self,
587        action: &MappedAction,
588        typeinfo: &TypeInfo,
589    ) -> Result<ResolvedAction, TypeError> {
590        match action {
591            GenericAction::Let(span, var, expr) => {
592                let ty = self
593                    .get(&AtomTerm::Var(span.clone(), var.clone()))
594                    .expect("All variables should be assigned before annotation");
595                Ok(ResolvedAction::Let(
596                    span.clone(),
597                    ResolvedVar {
598                        name: var.clone(),
599                        sort: ty.clone(),
600                        is_global_ref: false,
601                    },
602                    self.annotate_expr(expr, typeinfo),
603                ))
604            }
605            // Note mapped_var for set is a dummy variable that does not mean anything
606            GenericAction::Set(
607                span,
608                CorrespondingVar {
609                    head,
610                    to: _mapped_var,
611                },
612                children,
613                rhs,
614            ) => {
615                let children: Vec<_> = children
616                    .iter()
617                    .map(|child| self.annotate_expr(child, typeinfo))
618                    .collect();
619                let rhs = self.annotate_expr(rhs, typeinfo);
620                let types: Vec<_> = children
621                    .iter()
622                    .map(|child| child.output_type())
623                    .chain(once(rhs.output_type()))
624                    .collect();
625                let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
626                if !matches!(resolved_call, ResolvedCall::Func(_)) {
627                    return Err(TypeError::UnboundFunction(head.clone(), span.clone()));
628                }
629                Ok(ResolvedAction::Set(
630                    span.clone(),
631                    resolved_call,
632                    children,
633                    rhs,
634                ))
635            }
636            // Note mapped_var for delete is a dummy variable that does not mean anything
637            GenericAction::Change(
638                span,
639                change,
640                CorrespondingVar {
641                    head,
642                    to: _mapped_var,
643                },
644                children,
645            ) => {
646                let children: Vec<_> = children
647                    .iter()
648                    .map(|child| self.annotate_expr(child, typeinfo))
649                    .collect();
650                let types: Vec<_> = children.iter().map(|child| child.output_type()).collect();
651                let resolved_call =
652                    ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
653                        .ok_or_else(|| TypeError::UnboundFunction(head.clone(), span.clone()))?;
654                Ok(ResolvedAction::Change(
655                    span.clone(),
656                    *change,
657                    resolved_call,
658                    children.clone(),
659                ))
660            }
661            GenericAction::Union(span, lhs, rhs) => {
662                let lhs = self.annotate_expr(lhs, typeinfo);
663                let rhs = self.annotate_expr(rhs, typeinfo);
664
665                let sort = lhs.output_type();
666                assert_eq!(sort.name(), rhs.output_type().name());
667                if !sort.is_eq_sort() {
668                    return Err(TypeError::NonEqsortUnion(sort, span.clone()));
669                }
670
671                Ok(ResolvedAction::Union(span.clone(), lhs, rhs))
672            }
673            GenericAction::Panic(span, msg) => Ok(ResolvedAction::Panic(span.clone(), msg.clone())),
674            GenericAction::Expr(span, expr) => Ok(ResolvedAction::Expr(
675                span.clone(),
676                self.annotate_expr(expr, typeinfo),
677            )),
678        }
679    }
680
681    pub(crate) fn annotate_actions(
682        &self,
683        mapped_actions: &GenericActions<CorrespondingVar<String, String>, String>,
684        typeinfo: &TypeInfo,
685    ) -> Result<ResolvedActions, TypeError> {
686        let actions = mapped_actions
687            .iter()
688            .map(|action| self.annotate_action(action, typeinfo))
689            .collect::<Result<_, _>>()?;
690
691        Ok(ResolvedActions::new(actions))
692    }
693}
694
695impl<Var, Value> Problem<Var, Value>
696where
697    Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
698    Value: Clone + Debug + 'static,
699{
700    pub(crate) fn solve(
701        mut self,
702        key: fn(&Value) -> &str,
703    ) -> Result<Assignment<Var, Value>, ConstraintError<Var, Value>> {
704        let mut assignment = Assignment(HashMap::default());
705        let mut changed = true;
706        while changed {
707            changed = false;
708            for constraint in self.constraints.iter_mut() {
709                changed |= constraint.update(&mut assignment, key)?;
710            }
711        }
712
713        for v in self.range.iter() {
714            if !assignment.0.contains_key(v) {
715                return Err(ConstraintError::UnconstrainedVar(v.clone()));
716            }
717        }
718        Ok(assignment)
719    }
720
721    pub(crate) fn add_binding(&mut self, var: Var, clone: Value) {
722        self.constraints.push(constraint::assign(var, clone));
723    }
724}
725
726impl Problem<AtomTerm, ArcSort> {
727    pub(crate) fn add_query(
728        &mut self,
729        query: &Query<StringOrEq, String>,
730        typeinfo: &TypeInfo,
731    ) -> Result<(), TypeError> {
732        self.constraints.extend(query.get_constraints(typeinfo)?);
733        self.range.extend(query.atom_terms());
734        Ok(())
735    }
736
737    pub(crate) fn add_actions(
738        &mut self,
739        actions: &GenericCoreActions<String, String>,
740        typeinfo: &TypeInfo,
741        symbol_gen: &mut SymbolGen,
742    ) -> Result<(), TypeError> {
743        for action in actions.0.iter() {
744            self.constraints
745                .extend(action.get_constraints(typeinfo, symbol_gen)?);
746
747            // bound vars are added to range
748            match action {
749                CoreAction::Let(span, var, _, _) => {
750                    self.range.insert(AtomTerm::Var(span.clone(), var.clone()));
751                }
752                CoreAction::LetAtomTerm(span, v, _) => {
753                    self.range.insert(AtomTerm::Var(span.clone(), v.clone()));
754                }
755                _ => (),
756            }
757        }
758        Ok(())
759    }
760
761    pub(crate) fn add_rule(
762        &mut self,
763        rule: &CoreRule,
764        typeinfo: &TypeInfo,
765        symbol_gen: &mut SymbolGen,
766    ) -> Result<(), TypeError> {
767        let CoreRule {
768            span: _,
769            head,
770            body,
771        } = rule;
772        self.add_query(body, typeinfo)?;
773        self.add_actions(head, typeinfo, symbol_gen)?;
774        Ok(())
775    }
776
777    pub(crate) fn assign_local_var_type(
778        &mut self,
779        var: &str,
780        span: Span,
781        sort: ArcSort,
782    ) -> Result<(), TypeError> {
783        self.add_binding(AtomTerm::Var(span.clone(), var.to_owned()), sort);
784        self.range.insert(AtomTerm::Var(span, var.to_owned()));
785        Ok(())
786    }
787}
788
789impl CoreAction {
790    pub(crate) fn get_constraints(
791        &self,
792        typeinfo: &TypeInfo,
793        symbol_gen: &mut SymbolGen,
794    ) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
795        match self {
796            CoreAction::Let(span, symbol, f, args) => {
797                let mut args = args.clone();
798                args.push(AtomTerm::Var(span.clone(), symbol.clone()));
799
800                Ok(get_literal_and_global_constraints(&args, typeinfo)
801                    .chain(get_atom_application_constraints(f, &args, span, typeinfo)?)
802                    .collect())
803            }
804            CoreAction::Set(span, head, args, rhs) => {
805                let mut args = args.clone();
806                args.push(rhs.clone());
807
808                Ok(get_literal_and_global_constraints(&args, typeinfo)
809                    .chain(get_atom_application_constraints(
810                        head, &args, span, typeinfo,
811                    )?)
812                    .collect())
813            }
814            CoreAction::Change(span, _change, head, args) => {
815                let mut args = args.clone();
816                // Add a dummy last output argument
817                let var = symbol_gen.fresh(head);
818                args.push(AtomTerm::Var(span.clone(), var));
819
820                Ok(get_literal_and_global_constraints(&args, typeinfo)
821                    .chain(get_atom_application_constraints(
822                        head, &args, span, typeinfo,
823                    )?)
824                    .collect())
825            }
826            CoreAction::Union(_ann, lhs, rhs) => Ok(get_literal_and_global_constraints(
827                &[lhs.clone(), rhs.clone()],
828                typeinfo,
829            )
830            .chain(once(constraint::eq(lhs.clone(), rhs.clone())))
831            .collect()),
832            CoreAction::Panic(_ann, _) => Ok(vec![]),
833            CoreAction::LetAtomTerm(span, v, at) => {
834                Ok(get_literal_and_global_constraints(&[at.clone()], typeinfo)
835                    .chain(once(constraint::eq(
836                        AtomTerm::Var(span.clone(), v.clone()),
837                        at.clone(),
838                    )))
839                    .collect())
840            }
841        }
842    }
843}
844
845impl Atom<StringOrEq> {
846    pub(crate) fn get_constraints(
847        &self,
848        type_info: &TypeInfo,
849    ) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
850        let literal_constraints = get_literal_and_global_constraints(&self.args, type_info);
851        match &self.head {
852            StringOrEq::Eq => {
853                assert_eq!(self.args.len(), 2);
854                let constraints = literal_constraints
855                    .chain(once(constraint::eq(
856                        self.args[0].clone(),
857                        self.args[1].clone(),
858                    )))
859                    .collect();
860                Ok(constraints)
861            }
862            StringOrEq::Head(head) => Ok(literal_constraints
863                .chain(get_atom_application_constraints(
864                    head, &self.args, &self.span, type_info,
865                )?)
866                .collect()),
867        }
868    }
869}
870
871fn get_atom_application_constraints(
872    head: &str,
873    args: &[AtomTerm],
874    span: &Span,
875    type_info: &TypeInfo,
876) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
877    // An atom can have potentially different semantics due to polymorphism
878    // e.g. (set-empty) can mean any empty set with some element type.
879    // To handle this, we collect each possible instantiations of an atom
880    // (where each instantiation is a vec of constraints, thus vec of vec)
881    // into `xor_constraints`.
882    // `constraint::xor` means one and only one of the instantiation can hold.
883    let mut xor_constraints: Vec<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>> = vec![];
884
885    // function atom constraints
886    if let Some(typ) = type_info.get_func_type(head) {
887        let mut constraints = vec![];
888        // arity mismatch
889        if typ.input.len() + 1 != args.len() {
890            constraints.push(constraint::impossible(
891                ImpossibleConstraint::ArityMismatch {
892                    atom: Atom {
893                        span: span.clone(),
894                        head: head.to_owned(),
895                        args: args.to_vec(),
896                    },
897                    expected: typ.input.len() + 1,
898                },
899            ));
900        } else {
901            for (arg_typ, arg) in typ
902                .input
903                .iter()
904                .cloned()
905                .chain(once(typ.output.clone()))
906                .zip(args.iter().cloned())
907            {
908                constraints.push(constraint::assign(arg, arg_typ));
909            }
910        }
911        xor_constraints.push(constraints);
912    }
913
914    // primitive atom constraints
915    if let Some(primitives) = type_info.get_prims(head) {
916        for p in primitives {
917            let constraints = p.0.get_type_constraints(span).get(args, type_info);
918            xor_constraints.push(constraints);
919        }
920    }
921
922    // do literal and global variable constraints first
923    // as they are the most "informative"
924    match xor_constraints.len() {
925        0 => Err(TypeError::UnboundFunction(head.to_owned(), span.clone())),
926        1 => Ok(xor_constraints.pop().unwrap()),
927        _ => Ok(vec![constraint::xor(
928            xor_constraints.into_iter().map(constraint::and).collect(),
929        )]),
930    }
931}
932
933fn get_literal_and_global_constraints<'a>(
934    args: &'a [AtomTerm],
935    type_info: &'a TypeInfo,
936) -> impl Iterator<Item = Box<dyn Constraint<AtomTerm, ArcSort>>> + 'a {
937    args.iter().filter_map(|arg| {
938        match arg {
939            AtomTerm::Var(_, _) => None,
940            // Literal to type constraint
941            AtomTerm::Literal(_, lit) => {
942                let typ = crate::sort::literal_sort(lit);
943                Some(constraint::assign(arg.clone(), typ) as Box<dyn Constraint<AtomTerm, ArcSort>>)
944            }
945            AtomTerm::Global(_, v) => {
946                if let Some(typ) = type_info.get_global_sort(v) {
947                    Some(constraint::assign(arg.clone(), typ.clone()))
948                } else {
949                    panic!("All global variables should be bound before type checking")
950                }
951            }
952        }
953    })
954}
955
/// A trait for generating type constraints from atom applications.
/// This is used to create constraints that ensure proper typing of function/primitive applications.
///
/// In this file, `get_atom_application_constraints` calls `.get(args, type_info)` on each
/// primitive's `get_type_constraints(span)` result (presumably a `TypeConstraint`). The list
/// returned by one call is one candidate typing of the atom. When a head has several
/// candidates, they are combined with `constraint::xor`.
pub trait TypeConstraint {
    /// Generates constraints for the given arguments based on this type constraint.
    /// The constraints ensure that the arguments have compatible types.
    ///
    /// `arguments` holds every argument of the application, with the output argument last.
    /// The implementations in this file report a wrong argument count by returning an
    /// impossible `ArityMismatch` constraint rather than failing.
    fn get(
        &self,
        arguments: &[AtomTerm],
        typeinfo: &TypeInfo,
    ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>;
}
967
968/// A type constraint that assigns specific sorts to each argument position.
969/// Constructs a set of `Assign` constraints that fully constrain the type of arguments.
970pub struct SimpleTypeConstraint {
971    name: String,
972    sorts: Vec<ArcSort>,
973    span: Span,
974}
975
976impl SimpleTypeConstraint {
977    /// Constructs a `SimpleTypeConstraint`
978    pub fn new(name: &str, sorts: Vec<ArcSort>, span: Span) -> SimpleTypeConstraint {
979        let name = name.to_owned();
980        SimpleTypeConstraint { name, sorts, span }
981    }
982
983    /// Converts self to a boxed type constraint.
984    pub fn into_box(self) -> Box<dyn TypeConstraint> {
985        Box::new(self)
986    }
987}
988
989impl TypeConstraint for SimpleTypeConstraint {
990    fn get(
991        &self,
992        arguments: &[AtomTerm],
993        _typeinfo: &TypeInfo,
994    ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
995        if arguments.len() != self.sorts.len() {
996            vec![constraint::impossible(
997                ImpossibleConstraint::ArityMismatch {
998                    atom: Atom {
999                        span: self.span.clone(),
1000                        head: self.name.clone(),
1001                        args: arguments.to_vec(),
1002                    },
1003                    expected: self.sorts.len(),
1004                },
1005            )]
1006        } else {
1007            arguments
1008                .iter()
1009                .cloned()
1010                .zip(self.sorts.iter().cloned())
1011                .map(|(arg, sort)| constraint::assign(arg, sort))
1012                .collect()
1013        }
1014    }
1015}
1016
1017/// A type constraint that requires all or some arguments to have the same type.
1018///
1019/// See the `with_all_arguments_sort`, `with_exact_length`, and `with_output_sort` methods
1020/// for configuring the constraint.
1021pub struct AllEqualTypeConstraint {
1022    name: String,
1023    sort: Option<ArcSort>,
1024    exact_length: Option<usize>,
1025    output: Option<ArcSort>,
1026    span: Span,
1027}
1028
1029impl AllEqualTypeConstraint {
1030    /// Creates the `AllEqualTypeConstraint`.
1031    pub fn new(name: &str, span: Span) -> AllEqualTypeConstraint {
1032        AllEqualTypeConstraint {
1033            name: name.to_owned(),
1034            sort: None,
1035            exact_length: None,
1036            output: None,
1037            span,
1038        }
1039    }
1040
1041    /// Converts self into a boxed type constraint.
1042    pub fn into_box(self) -> Box<dyn TypeConstraint> {
1043        Box::new(self)
1044    }
1045
1046    /// Requires all arguments to have the given sort.
1047    /// If `with_output_sort` is not specified, this requirement
1048    /// also applies to the output argument.
1049    pub fn with_all_arguments_sort(mut self, sort: ArcSort) -> Self {
1050        self.sort = Some(sort);
1051        self
1052    }
1053
1054    /// Requires the length of arguments to be `exact_length`.
1055    /// Note this includes both input arguments and output argument.
1056    pub fn with_exact_length(mut self, exact_length: usize) -> Self {
1057        self.exact_length = Some(exact_length);
1058        self
1059    }
1060
1061    /// Requires the output argument to have the given sort.
1062    pub fn with_output_sort(mut self, output_sort: ArcSort) -> Self {
1063        self.output = Some(output_sort);
1064        self
1065    }
1066}
1067
1068impl TypeConstraint for AllEqualTypeConstraint {
1069    fn get(
1070        &self,
1071        mut arguments: &[AtomTerm],
1072        _typeinfo: &TypeInfo,
1073    ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
1074        if arguments.is_empty() {
1075            panic!("all arguments should have length > 0")
1076        }
1077
1078        match self.exact_length {
1079            Some(exact_length) if exact_length != arguments.len() => {
1080                return vec![constraint::impossible(
1081                    ImpossibleConstraint::ArityMismatch {
1082                        atom: Atom {
1083                            span: self.span.clone(),
1084                            head: self.name.clone(),
1085                            args: arguments.to_vec(),
1086                        },
1087                        expected: exact_length,
1088                    },
1089                )];
1090            }
1091            _ => (),
1092        }
1093
1094        let mut constraints = vec![];
1095        if let Some(output) = self.output.clone() {
1096            let (out, inputs) = arguments.split_last().unwrap();
1097            constraints.push(constraint::assign(out.clone(), output));
1098            arguments = inputs;
1099        }
1100
1101        if let Some(sort) = self.sort.clone() {
1102            constraints.extend(
1103                arguments
1104                    .iter()
1105                    .cloned()
1106                    .map(|arg| constraint::assign(arg, sort.clone())),
1107            )
1108        } else if let Some((first, rest)) = arguments.split_first() {
1109            constraints.extend(
1110                rest.iter()
1111                    .cloned()
1112                    .map(|arg| constraint::eq(arg, first.clone())),
1113            );
1114        }
1115        constraints
1116    }
1117}
1118
1119/// Checks that all variables in a rule's body are properly grounded.
1120/// A variable is grounded if it appears in a function call or is equal to a grounded variable.
1121/// This pass happens after type resolution and lowering to core rules, but before canonicalization.
1122pub(crate) fn grounded_check(
1123    rule: &GenericCoreRule<HeadOrEq<ResolvedCall>, ResolvedCall, ResolvedVar>,
1124) -> Result<(), TypeError> {
1125    use crate::core::ResolvedAtomTerm;
1126    let body = &rule.body;
1127
1128    let range = rule
1129        .body
1130        .get_vars()
1131        .into_iter()
1132        .map(|v| ResolvedAtomTerm::Var(rule.span.clone(), v))
1133        .collect();
1134    let mut problem: Problem<ResolvedAtomTerm, ()> = Problem {
1135        constraints: vec![],
1136        range,
1137    };
1138
1139    for atom in body.atoms.iter() {
1140        let mut add_global_and_literal = false;
1141        match &atom.head {
1142            HeadOrEq::Head(ResolvedCall::Func(_)) => {
1143                for arg in atom.args.iter() {
1144                    problem.constraints.push(assign(arg.clone(), ()));
1145                }
1146            }
1147            HeadOrEq::Head(ResolvedCall::Primitive(_)) => {
1148                let (out, inp) = atom.args.split_last().unwrap();
1149                let out = out.clone();
1150                problem.constraints.push(implies(
1151                    format!("grounded_{:?}", out),
1152                    inp.to_vec(),
1153                    Rc::new(move |_| assign(out.clone(), ())),
1154                ));
1155                add_global_and_literal = true;
1156            }
1157            HeadOrEq::Eq => {
1158                assert_eq!(atom.args.len(), 2);
1159                problem
1160                    .constraints
1161                    .push(eq(atom.args[0].clone(), atom.args[1].clone()));
1162                add_global_and_literal = true;
1163            }
1164        }
1165        if add_global_and_literal {
1166            for arg in atom.args.iter() {
1167                match arg {
1168                    ResolvedAtomTerm::Global(..) | ResolvedAtomTerm::Literal(..) => {
1169                        problem.constraints.push(assign(arg.clone(), ()));
1170                    }
1171                    ResolvedAtomTerm::Var(..) => {}
1172                }
1173            }
1174        }
1175    }
1176
1177    let _assignment = problem.solve(|_| "grounded").map_err(|err| match err {
1178        ConstraintError::UnconstrainedVar(ResolvedAtomTerm::Var(span, v)) => {
1179            TypeError::Ungrounded(v.to_string(), span)
1180        }
1181        _ => panic!(
1182            "unexpected constraint error in groundedness check {:?}",
1183            err
1184        ),
1185    })?;
1186
1187    Ok(())
1188}