1use crate::{
2 core::{
3 Atom, CoreAction, CoreRule, GenericCoreActions, GenericCoreRule, HeadOrEq, Query,
4 StringOrEq,
5 },
6 *,
7};
8use std::{cmp, rc::Rc};
9use egglog_ast::generic_ast::{GenericAction, GenericActions, GenericExpr, GenericFact};
12use egglog_ast::span::Span;
13use im_rc::HashMap;
14use std::{fmt::Debug, iter::once, mem::swap};
15
/// A reason a constraint can never be satisfied, recorded while constraints are generated.
///
/// It is reported through [`ConstraintError::ImpossibleCaseIdentified`] when the solver
/// reaches the corresponding `impossible` constraint.
#[derive(Clone, Debug)]
pub enum ImpossibleConstraint {
    /// An atom was applied to the wrong number of arguments.
    ArityMismatch {
        atom: Atom<String>,
        // At the construction sites in this file this count includes the output column
        // (e.g. `typ.input.len() + 1`); `to_type_error` subtracts one when reporting.
        expected: usize,
    },
    /// A function's signature did not match; converted field-for-field into
    /// `TypeError::FunctionTypeMismatch` by `to_type_error`.
    FunctionMismatch {
        expected_output: ArcSort,
        expected_input: Vec<ArcSort>,
        actual_output: ArcSort,
        actual_input: Vec<ArcSort>,
    },
}
32
/// A constraint over a partial [`Assignment`] of `Var`s to `Value`s.
///
/// Constraints are solved by propagation: [`Problem::solve`] calls `update` on every
/// constraint repeatedly until a full pass reports no progress.
pub trait Constraint<Var, Value>: dyn_clone::DynClone {
    /// Propagates what this constraint knows into `assignment`.
    ///
    /// Returns `Ok(true)` if the call made progress, and `Ok(false)` if it had nothing to add.
    /// Progress usually means binding a variable, but `implies` also reports instantiating
    /// itself. Returns an error if the constraint is violated.
    ///
    /// `key` projects a value to the string used to decide whether two values are equal.
    fn update(
        &mut self,
        assignment: &mut Assignment<Var, Value>,
        key: fn(&Value) -> &str,
    ) -> Result<bool, ConstraintError<Var, Value>>;

    /// Human-readable rendering, used for debugging output.
    fn pretty(&self) -> String;
}

// Makes `Box<dyn Constraint<Var, Value>>` cloneable via `dyn_clone`.
dyn_clone::clone_trait_object!(<Var, Value> Constraint<Var, Value>);
53
54pub fn eq<Var, Value>(x: Var, y: Var) -> Box<dyn Constraint<Var, Value>>
58where
59 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
60 Value: Clone + Debug + 'static,
61{
62 Box::new(Eq(x, y))
63}
64
65pub fn assign<Var, Value>(x: Var, v: Value) -> Box<dyn Constraint<Var, Value>>
68where
69 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
70 Value: Clone + Debug + 'static,
71{
72 Box::new(Assign(x, v))
73}
74
75pub fn and<Var, Value>(cs: Vec<Box<dyn Constraint<Var, Value>>>) -> Box<dyn Constraint<Var, Value>>
77where
78 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
79 Value: Clone + Debug + 'static,
80{
81 Box::new(And(cs))
82}
83
84pub fn xor<Var, Value>(cs: Vec<Box<dyn Constraint<Var, Value>>>) -> Box<dyn Constraint<Var, Value>>
88where
89 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
90 Value: Clone + Debug + 'static,
91{
92 Box::new(Xor(cs))
93}
94
95pub fn impossible<Var, Value>(constraint: ImpossibleConstraint) -> Box<dyn Constraint<Var, Value>>
98where
99 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
100 Value: Clone + Debug + 'static,
101{
102 Box::new(Impossible { constraint })
103}
104
105pub fn implies<Var, Value>(
108 name: String,
109 watch_vars: Vec<Var>,
110 constraint: DelayedConstraintFn<Var, Value>,
111) -> Box<dyn Constraint<Var, Value>>
112where
113 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
114 Value: Clone + Debug + 'static,
115{
116 Box::new(Implies {
117 name,
118 watch_vars,
119 constraint: DelayedConstraint::Delayed(constraint),
120 })
121}
122
/// Builds the real constraint of an `implies` from the values of its watched variables.
///
/// The values are passed in the same order as the `watch_vars` they belong to.
pub type DelayedConstraintFn<Var, Value> = Rc<dyn Fn(&[&Value]) -> Box<dyn Constraint<Var, Value>>>;

/// State of an `implies` constraint.
#[derive(Clone)]
enum DelayedConstraint<Var, Value> {
    /// Still waiting for its watched variables; holds the builder.
    Delayed(DelayedConstraintFn<Var, Value>),
    /// Already produced by the builder once all watched variables were assigned.
    Constraint(Box<dyn Constraint<Var, Value>>),
}
130
131#[derive(Clone)]
132struct Implies<Var, Value> {
133 name: String,
134 watch_vars: Vec<Var>,
135 constraint: DelayedConstraint<Var, Value>,
136}
137
138impl<Var, Value> Constraint<Var, Value> for Implies<Var, Value>
139where
140 Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
141 Value: Clone + Debug,
142{
143 fn update(
144 &mut self,
145 assignment: &mut Assignment<Var, Value>,
146 key: fn(&Value) -> &str,
147 ) -> Result<bool, ConstraintError<Var, Value>> {
148 let mut updated = false;
149 if let DelayedConstraint::Delayed(delayed) = &self.constraint {
151 let watch_vals: Option<Vec<&Value>> =
152 self.watch_vars.iter().map(|v| assignment.get(v)).collect();
153 let Some(watch_vals) = watch_vals else {
154 return Ok(false);
155 };
156 let constraint = delayed(&watch_vals);
157 self.constraint = DelayedConstraint::Constraint(constraint);
158 updated = true;
159 };
160
161 let DelayedConstraint::Constraint(constraint) = &mut self.constraint else {
163 unreachable!("update");
164 };
165 updated |= constraint.update(assignment, key)?;
166 Ok(updated)
167 }
168
169 fn pretty(&self) -> String {
170 let vars: String = self
171 .watch_vars
172 .iter()
173 .map(|v| format!("{:?}", v))
174 .collect::<Vec<_>>()
175 .join(", ");
176 format!("{} => {}({})", vars, self.name, vars)
177 }
178}
179
180#[derive(Clone)]
181struct Eq<Var>(Var, Var);
182
183impl<Var, Value> Constraint<Var, Value> for Eq<Var>
184where
185 Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
186 Value: Clone + Debug,
187{
188 fn update(
189 &mut self,
190 assignment: &mut Assignment<Var, Value>,
191 key: fn(&Value) -> &str,
192 ) -> Result<bool, ConstraintError<Var, Value>> {
193 match (assignment.0.get(&self.0), assignment.0.get(&self.1)) {
194 (Some(value), None) => {
195 assignment.insert(self.1.clone(), value.clone());
196 Ok(true)
197 }
198 (None, Some(value)) => {
199 assignment.insert(self.0.clone(), value.clone());
200 Ok(true)
201 }
202 (Some(v1), Some(v2)) => {
203 if key(v1) == key(v2) {
204 Ok(false)
205 } else {
206 Err(ConstraintError::InconsistentConstraint(
207 self.0.clone(),
208 v1.clone(),
209 v2.clone(),
210 ))
211 }
212 }
213 (None, None) => Ok(false),
214 }
215 }
216
217 fn pretty(&self) -> String {
218 format!("{:?} = {:?}", self.0, self.1)
219 }
220}
221
222#[derive(Clone)]
223struct Assign<Var, Value>(Var, Value);
224
225impl<Var, Value> Constraint<Var, Value> for Assign<Var, Value>
226where
227 Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
228 Value: Clone + Debug,
229{
230 fn update(
231 &mut self,
232 assignment: &mut Assignment<Var, Value>,
233 key: fn(&Value) -> &str,
234 ) -> Result<bool, ConstraintError<Var, Value>> {
235 match assignment.0.get(&self.0) {
236 None => {
237 assignment.insert(self.0.clone(), self.1.clone());
238 Ok(true)
239 }
240 Some(value) => {
241 if key(value) == key(&self.1) {
242 Ok(false)
243 } else {
244 Err(ConstraintError::InconsistentConstraint(
245 self.0.clone(),
246 self.1.clone(),
247 value.clone(),
248 ))
249 }
250 }
251 }
252 }
253
254 fn pretty(&self) -> String {
255 format!("{:?} = {:?}", self.0, self.1)
256 }
257}
258
259#[derive(Clone)]
260struct And<Var, Value>(Vec<Box<dyn Constraint<Var, Value>>>);
261
262impl<Var, Value> Constraint<Var, Value> for And<Var, Value>
263where
264 Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
265 Value: Clone + Debug,
266{
267 fn update(
268 &mut self,
269 assignment: &mut Assignment<Var, Value>,
270 key: fn(&Value) -> &str,
271 ) -> Result<bool, ConstraintError<Var, Value>> {
272 let orig_assignment = assignment.clone();
273 let mut updated = false;
274 for c in self.0.iter_mut() {
275 match c.update(assignment, key) {
276 Ok(upd) => updated |= upd,
277 Err(error) => {
278 *assignment = orig_assignment;
281 return Err(error);
282 }
283 }
284 }
285 Ok(updated)
286 }
287
288 fn pretty(&self) -> String {
289 format!(
290 "({})",
291 self.0
292 .iter()
293 .map(|c| c.pretty())
294 .collect::<Vec<_>>()
295 .join(" /\\ ")
296 )
297 }
298}
299
/// Exclusive choice between alternative constraints.
///
/// The choice commits (keeps only one alternative) once exactly one alternative succeeds
/// on the current assignment.
#[derive(Clone)]
struct Xor<Var, Value>(Vec<Box<dyn Constraint<Var, Value>>>);

impl<Var, Value> Constraint<Var, Value> for Xor<Var, Value>
where
    Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
    Value: Clone + Debug,
{
    /// Tries every alternative against the current assignment and acts on how many succeed:
    ///
    /// - exactly one: commit to it and keep its updates;
    /// - more than one: roll everything back and report no progress, so a later pass of the
    ///   solver can retry once more variables are known;
    /// - none: roll back and return `NoConstraintSatisfied` with every alternative's error.
    fn update(
        &mut self,
        assignment: &mut Assignment<Var, Value>,
        key: fn(&Value) -> &str,
    ) -> Result<bool, ConstraintError<Var, Value>> {
        let mut success_count = 0;
        // Snapshots used to roll back the assignment and the alternatives' internal state.
        let orig_assignment = assignment.clone();
        let orig_cs = self.0.clone();
        // Receives the updated assignment of the (first) successful alternative.
        let mut result_assignment = assignment.clone();
        let mut assignment_updated = false;
        let mut errors = vec![];
        let mut result_constraint = None;

        let cs = std::mem::take(&mut self.0);
        for mut c in cs {
            // Each alternative runs against the original assignment. This relies on failing
            // constraints leaving `assignment` untouched: `eq`/`assign` only write on success,
            // and `and`/`xor` roll back on error.
            let result = c.update(assignment, key);
            match result {
                Ok(updated) => {
                    success_count += 1;
                    // A second success already makes the choice ambiguous; stop early.
                    if success_count > 1 {
                        break;
                    }

                    result_constraint = Some(c);
                    if updated {
                        // Move the updated assignment aside and put the untouched copy back,
                        // so the remaining alternatives see the original assignment.
                        swap(&mut result_assignment, assignment);
                    }
                    assignment_updated = updated;
                }
                Err(error) => errors.push(error),
            }
        }

        match success_count.cmp(&1) {
            std::cmp::Ordering::Equal => {
                // Commit: from now on only the surviving alternative is kept.
                self.0 = vec![result_constraint.unwrap()];
                *assignment = result_assignment;
                Ok(assignment_updated)
            }
            std::cmp::Ordering::Greater => {
                self.0 = orig_cs;
                *assignment = orig_assignment;
                Ok(false)
            }
            std::cmp::Ordering::Less => {
                self.0 = orig_cs;
                *assignment = orig_assignment;
                Err(ConstraintError::NoConstraintSatisfied(errors))
            }
        }
    }

    fn pretty(&self) -> String {
        format!(
            "({})",
            self.0
                .iter()
                .map(|c| c.pretty())
                .collect::<Vec<_>>()
                .join(" \\/ ")
        )
    }
}
377
378#[derive(Clone)]
379struct Impossible {
380 constraint: ImpossibleConstraint,
381}
382
383impl<Var, Value> Constraint<Var, Value> for Impossible
384where
385 Var: cmp::Eq + PartialEq + Hash + Clone + Debug,
386 Value: Clone + Debug,
387{
388 fn update(
389 &mut self,
390 _assignment: &mut Assignment<Var, Value>,
391 _key: fn(&Value) -> &str,
392 ) -> Result<bool, ConstraintError<Var, Value>> {
393 Err(ConstraintError::ImpossibleCaseIdentified(
394 self.constraint.clone(),
395 ))
396 }
397
398 fn pretty(&self) -> String {
399 format!("{:?}", self.constraint)
400 }
401}
402
/// Why constraint solving failed.
#[derive(Debug)]
pub enum ConstraintError<Var, Value> {
    /// Two values with different keys were required to agree.
    ///
    /// Carries the variable that is reported and the two conflicting values; `to_type_error`
    /// treats the first value as expected and the second as actual.
    InconsistentConstraint(Var, Value, Value),
    /// A variable in the problem's `range` was never assigned a value.
    UnconstrainedVar(Var),
    /// No alternative of an `xor` succeeded; holds each alternative's error.
    NoConstraintSatisfied(Vec<ConstraintError<Var, Value>>),
    /// The solver reached an `impossible` constraint.
    ImpossibleCaseIdentified(ImpossibleConstraint),
}
416
417impl ConstraintError<AtomTerm, ArcSort> {
418 pub fn to_type_error(&self) -> TypeError {
420 match &self {
421 ConstraintError::InconsistentConstraint(x, v1, v2) => TypeError::Mismatch {
422 expr: x.to_expr(),
423 expected: v1.clone(),
424 actual: v2.clone(),
425 },
426 ConstraintError::UnconstrainedVar(v) => TypeError::InferenceFailure(v.to_expr()),
427 ConstraintError::NoConstraintSatisfied(constraints) => TypeError::AllAlternativeFailed(
428 constraints.iter().map(|c| c.to_type_error()).collect(),
429 ),
430 ConstraintError::ImpossibleCaseIdentified(ImpossibleConstraint::ArityMismatch {
431 atom,
432 expected,
433 }) => TypeError::Arity {
434 expr: atom.to_expr(),
435 expected: *expected - 1,
436 },
437 ConstraintError::ImpossibleCaseIdentified(ImpossibleConstraint::FunctionMismatch {
438 expected_output,
439 expected_input,
440 actual_output,
441 actual_input,
442 }) => TypeError::FunctionTypeMismatch(
443 expected_output.clone(),
444 expected_input.clone(),
445 actual_output.clone(),
446 actual_input.clone(),
447 ),
448 }
449 }
450}
451
452pub struct Problem<Var, Value> {
455 pub constraints: Vec<Box<dyn Constraint<Var, Value>>>,
457 pub range: HashSet<Var>,
459}
460
461impl Debug for Problem<AtomTerm, ArcSort> {
462 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463 f.debug_struct("Problem")
464 .field(
465 "constraints",
466 &self
467 .constraints
468 .iter()
469 .map(|c| c.pretty())
470 .collect::<Vec<_>>(),
471 )
472 .field("range", &self.range)
473 .finish()
474 }
475}
476
477impl<Var, Value> Default for Problem<Var, Value> {
478 fn default() -> Self {
479 Self {
480 constraints: vec![],
481 range: HashSet::default(),
482 }
483 }
484}
485
486#[derive(Clone)]
490pub struct Assignment<Var, Value>(pub HashMap<Var, Value>);
491
492impl<Var, Value> Assignment<Var, Value>
493where
494 Var: Hash + cmp::Eq + PartialEq + Clone,
495 Value: Clone,
496{
497 pub fn insert(&mut self, var: Var, value: Value) -> Option<Value> {
499 self.0.insert(var, value)
500 }
501
502 pub fn get(&self, var: &Var) -> Option<&Value> {
504 self.0.get(var)
505 }
506}
507
508impl Assignment<AtomTerm, ArcSort> {
509 pub(crate) fn annotate_expr(
510 &self,
511 expr: &GenericExpr<CorrespondingVar<String, String>, String>,
512 typeinfo: &TypeInfo,
513 ) -> ResolvedExpr {
514 match &expr {
515 GenericExpr::Lit(span, literal) => ResolvedExpr::Lit(span.clone(), literal.clone()),
516 GenericExpr::Var(span, var) => {
517 let global_sort = typeinfo.get_global_sort(var);
518 let ty = global_sort
519 .or_else(|| self.get(&AtomTerm::Var(Span::Panic, var.clone())))
521 .expect("All variables should be assigned before annotation");
522 ResolvedExpr::Var(
523 span.clone(),
524 ResolvedVar {
525 name: var.clone(),
526 sort: ty.clone(),
527 is_global_ref: global_sort.is_some(),
528 },
529 )
530 }
531 GenericExpr::Call(
532 span,
533 CorrespondingVar {
534 head,
535 to: corresponding_var,
536 },
537 args,
538 ) => {
539 let args: Vec<_> = args
541 .iter()
542 .map(|arg| self.annotate_expr(arg, typeinfo))
543 .collect();
544 let types: Vec<_> = args
545 .iter()
546 .map(|arg| arg.output_type())
547 .chain(once(
548 self.get(&AtomTerm::Var(span.clone(), corresponding_var.clone()))
549 .unwrap()
550 .clone(),
551 ))
552 .collect();
553 let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
554 GenericExpr::Call(span.clone(), resolved_call, args)
555 }
556 }
557 }
558
559 pub(crate) fn annotate_fact(
560 &self,
561 facts: &GenericFact<CorrespondingVar<String, String>, String>,
562 typeinfo: &TypeInfo,
563 ) -> ResolvedFact {
564 match facts {
565 GenericFact::Eq(span, e1, e2) => ResolvedFact::Eq(
566 span.clone(),
567 self.annotate_expr(e1, typeinfo),
568 self.annotate_expr(e2, typeinfo),
569 ),
570 GenericFact::Fact(expr) => ResolvedFact::Fact(self.annotate_expr(expr, typeinfo)),
571 }
572 }
573
574 pub(crate) fn annotate_facts(
575 &self,
576 mapped_facts: &[GenericFact<CorrespondingVar<String, String>, String>],
577 typeinfo: &TypeInfo,
578 ) -> Vec<ResolvedFact> {
579 mapped_facts
580 .iter()
581 .map(|fact| self.annotate_fact(fact, typeinfo))
582 .collect()
583 }
584
585 pub(crate) fn annotate_action(
586 &self,
587 action: &MappedAction,
588 typeinfo: &TypeInfo,
589 ) -> Result<ResolvedAction, TypeError> {
590 match action {
591 GenericAction::Let(span, var, expr) => {
592 let ty = self
593 .get(&AtomTerm::Var(span.clone(), var.clone()))
594 .expect("All variables should be assigned before annotation");
595 Ok(ResolvedAction::Let(
596 span.clone(),
597 ResolvedVar {
598 name: var.clone(),
599 sort: ty.clone(),
600 is_global_ref: false,
601 },
602 self.annotate_expr(expr, typeinfo),
603 ))
604 }
605 GenericAction::Set(
607 span,
608 CorrespondingVar {
609 head,
610 to: _mapped_var,
611 },
612 children,
613 rhs,
614 ) => {
615 let children: Vec<_> = children
616 .iter()
617 .map(|child| self.annotate_expr(child, typeinfo))
618 .collect();
619 let rhs = self.annotate_expr(rhs, typeinfo);
620 let types: Vec<_> = children
621 .iter()
622 .map(|child| child.output_type())
623 .chain(once(rhs.output_type()))
624 .collect();
625 let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
626 if !matches!(resolved_call, ResolvedCall::Func(_)) {
627 return Err(TypeError::UnboundFunction(head.clone(), span.clone()));
628 }
629 Ok(ResolvedAction::Set(
630 span.clone(),
631 resolved_call,
632 children,
633 rhs,
634 ))
635 }
636 GenericAction::Change(
638 span,
639 change,
640 CorrespondingVar {
641 head,
642 to: _mapped_var,
643 },
644 children,
645 ) => {
646 let children: Vec<_> = children
647 .iter()
648 .map(|child| self.annotate_expr(child, typeinfo))
649 .collect();
650 let types: Vec<_> = children.iter().map(|child| child.output_type()).collect();
651 let resolved_call =
652 ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
653 .ok_or_else(|| TypeError::UnboundFunction(head.clone(), span.clone()))?;
654 Ok(ResolvedAction::Change(
655 span.clone(),
656 *change,
657 resolved_call,
658 children.clone(),
659 ))
660 }
661 GenericAction::Union(span, lhs, rhs) => {
662 let lhs = self.annotate_expr(lhs, typeinfo);
663 let rhs = self.annotate_expr(rhs, typeinfo);
664
665 let sort = lhs.output_type();
666 assert_eq!(sort.name(), rhs.output_type().name());
667 if !sort.is_eq_sort() {
668 return Err(TypeError::NonEqsortUnion(sort, span.clone()));
669 }
670
671 Ok(ResolvedAction::Union(span.clone(), lhs, rhs))
672 }
673 GenericAction::Panic(span, msg) => Ok(ResolvedAction::Panic(span.clone(), msg.clone())),
674 GenericAction::Expr(span, expr) => Ok(ResolvedAction::Expr(
675 span.clone(),
676 self.annotate_expr(expr, typeinfo),
677 )),
678 }
679 }
680
681 pub(crate) fn annotate_actions(
682 &self,
683 mapped_actions: &GenericActions<CorrespondingVar<String, String>, String>,
684 typeinfo: &TypeInfo,
685 ) -> Result<ResolvedActions, TypeError> {
686 let actions = mapped_actions
687 .iter()
688 .map(|action| self.annotate_action(action, typeinfo))
689 .collect::<Result<_, _>>()?;
690
691 Ok(ResolvedActions::new(actions))
692 }
693}
694
695impl<Var, Value> Problem<Var, Value>
696where
697 Var: cmp::Eq + PartialEq + Hash + Clone + Debug + 'static,
698 Value: Clone + Debug + 'static,
699{
700 pub(crate) fn solve(
701 mut self,
702 key: fn(&Value) -> &str,
703 ) -> Result<Assignment<Var, Value>, ConstraintError<Var, Value>> {
704 let mut assignment = Assignment(HashMap::default());
705 let mut changed = true;
706 while changed {
707 changed = false;
708 for constraint in self.constraints.iter_mut() {
709 changed |= constraint.update(&mut assignment, key)?;
710 }
711 }
712
713 for v in self.range.iter() {
714 if !assignment.0.contains_key(v) {
715 return Err(ConstraintError::UnconstrainedVar(v.clone()));
716 }
717 }
718 Ok(assignment)
719 }
720
721 pub(crate) fn add_binding(&mut self, var: Var, clone: Value) {
722 self.constraints.push(constraint::assign(var, clone));
723 }
724}
725
726impl Problem<AtomTerm, ArcSort> {
727 pub(crate) fn add_query(
728 &mut self,
729 query: &Query<StringOrEq, String>,
730 typeinfo: &TypeInfo,
731 ) -> Result<(), TypeError> {
732 self.constraints.extend(query.get_constraints(typeinfo)?);
733 self.range.extend(query.atom_terms());
734 Ok(())
735 }
736
737 pub(crate) fn add_actions(
738 &mut self,
739 actions: &GenericCoreActions<String, String>,
740 typeinfo: &TypeInfo,
741 symbol_gen: &mut SymbolGen,
742 ) -> Result<(), TypeError> {
743 for action in actions.0.iter() {
744 self.constraints
745 .extend(action.get_constraints(typeinfo, symbol_gen)?);
746
747 match action {
749 CoreAction::Let(span, var, _, _) => {
750 self.range.insert(AtomTerm::Var(span.clone(), var.clone()));
751 }
752 CoreAction::LetAtomTerm(span, v, _) => {
753 self.range.insert(AtomTerm::Var(span.clone(), v.clone()));
754 }
755 _ => (),
756 }
757 }
758 Ok(())
759 }
760
761 pub(crate) fn add_rule(
762 &mut self,
763 rule: &CoreRule,
764 typeinfo: &TypeInfo,
765 symbol_gen: &mut SymbolGen,
766 ) -> Result<(), TypeError> {
767 let CoreRule {
768 span: _,
769 head,
770 body,
771 } = rule;
772 self.add_query(body, typeinfo)?;
773 self.add_actions(head, typeinfo, symbol_gen)?;
774 Ok(())
775 }
776
777 pub(crate) fn assign_local_var_type(
778 &mut self,
779 var: &str,
780 span: Span,
781 sort: ArcSort,
782 ) -> Result<(), TypeError> {
783 self.add_binding(AtomTerm::Var(span.clone(), var.to_owned()), sort);
784 self.range.insert(AtomTerm::Var(span, var.to_owned()));
785 Ok(())
786 }
787}
788
789impl CoreAction {
790 pub(crate) fn get_constraints(
791 &self,
792 typeinfo: &TypeInfo,
793 symbol_gen: &mut SymbolGen,
794 ) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
795 match self {
796 CoreAction::Let(span, symbol, f, args) => {
797 let mut args = args.clone();
798 args.push(AtomTerm::Var(span.clone(), symbol.clone()));
799
800 Ok(get_literal_and_global_constraints(&args, typeinfo)
801 .chain(get_atom_application_constraints(f, &args, span, typeinfo)?)
802 .collect())
803 }
804 CoreAction::Set(span, head, args, rhs) => {
805 let mut args = args.clone();
806 args.push(rhs.clone());
807
808 Ok(get_literal_and_global_constraints(&args, typeinfo)
809 .chain(get_atom_application_constraints(
810 head, &args, span, typeinfo,
811 )?)
812 .collect())
813 }
814 CoreAction::Change(span, _change, head, args) => {
815 let mut args = args.clone();
816 let var = symbol_gen.fresh(head);
818 args.push(AtomTerm::Var(span.clone(), var));
819
820 Ok(get_literal_and_global_constraints(&args, typeinfo)
821 .chain(get_atom_application_constraints(
822 head, &args, span, typeinfo,
823 )?)
824 .collect())
825 }
826 CoreAction::Union(_ann, lhs, rhs) => Ok(get_literal_and_global_constraints(
827 &[lhs.clone(), rhs.clone()],
828 typeinfo,
829 )
830 .chain(once(constraint::eq(lhs.clone(), rhs.clone())))
831 .collect()),
832 CoreAction::Panic(_ann, _) => Ok(vec![]),
833 CoreAction::LetAtomTerm(span, v, at) => {
834 Ok(get_literal_and_global_constraints(&[at.clone()], typeinfo)
835 .chain(once(constraint::eq(
836 AtomTerm::Var(span.clone(), v.clone()),
837 at.clone(),
838 )))
839 .collect())
840 }
841 }
842 }
843}
844
845impl Atom<StringOrEq> {
846 pub(crate) fn get_constraints(
847 &self,
848 type_info: &TypeInfo,
849 ) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
850 let literal_constraints = get_literal_and_global_constraints(&self.args, type_info);
851 match &self.head {
852 StringOrEq::Eq => {
853 assert_eq!(self.args.len(), 2);
854 let constraints = literal_constraints
855 .chain(once(constraint::eq(
856 self.args[0].clone(),
857 self.args[1].clone(),
858 )))
859 .collect();
860 Ok(constraints)
861 }
862 StringOrEq::Head(head) => Ok(literal_constraints
863 .chain(get_atom_application_constraints(
864 head, &self.args, &self.span, type_info,
865 )?)
866 .collect()),
867 }
868 }
869}
870
871fn get_atom_application_constraints(
872 head: &str,
873 args: &[AtomTerm],
874 span: &Span,
875 type_info: &TypeInfo,
876) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
877 let mut xor_constraints: Vec<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>> = vec![];
884
885 if let Some(typ) = type_info.get_func_type(head) {
887 let mut constraints = vec![];
888 if typ.input.len() + 1 != args.len() {
890 constraints.push(constraint::impossible(
891 ImpossibleConstraint::ArityMismatch {
892 atom: Atom {
893 span: span.clone(),
894 head: head.to_owned(),
895 args: args.to_vec(),
896 },
897 expected: typ.input.len() + 1,
898 },
899 ));
900 } else {
901 for (arg_typ, arg) in typ
902 .input
903 .iter()
904 .cloned()
905 .chain(once(typ.output.clone()))
906 .zip(args.iter().cloned())
907 {
908 constraints.push(constraint::assign(arg, arg_typ));
909 }
910 }
911 xor_constraints.push(constraints);
912 }
913
914 if let Some(primitives) = type_info.get_prims(head) {
916 for p in primitives {
917 let constraints = p.0.get_type_constraints(span).get(args, type_info);
918 xor_constraints.push(constraints);
919 }
920 }
921
922 match xor_constraints.len() {
925 0 => Err(TypeError::UnboundFunction(head.to_owned(), span.clone())),
926 1 => Ok(xor_constraints.pop().unwrap()),
927 _ => Ok(vec![constraint::xor(
928 xor_constraints.into_iter().map(constraint::and).collect(),
929 )]),
930 }
931}
932
933fn get_literal_and_global_constraints<'a>(
934 args: &'a [AtomTerm],
935 type_info: &'a TypeInfo,
936) -> impl Iterator<Item = Box<dyn Constraint<AtomTerm, ArcSort>>> + 'a {
937 args.iter().filter_map(|arg| {
938 match arg {
939 AtomTerm::Var(_, _) => None,
940 AtomTerm::Literal(_, lit) => {
942 let typ = crate::sort::literal_sort(lit);
943 Some(constraint::assign(arg.clone(), typ) as Box<dyn Constraint<AtomTerm, ArcSort>>)
944 }
945 AtomTerm::Global(_, v) => {
946 if let Some(typ) = type_info.get_global_sort(v) {
947 Some(constraint::assign(arg.clone(), typ.clone()))
948 } else {
949 panic!("All global variables should be bound before type checking")
950 }
951 }
952 }
953 })
954}
955
/// Produces typing constraints for an application of a primitive.
pub trait TypeConstraint {
    /// Returns constraints relating `arguments` to their sorts.
    ///
    /// In the implementations in this file, `arguments` includes the output as its last
    /// element. An arity problem is reported as an `impossible` constraint, not an error.
    fn get(
        &self,
        arguments: &[AtomTerm],
        typeinfo: &TypeInfo,
    ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>;
}
967
968pub struct SimpleTypeConstraint {
971 name: String,
972 sorts: Vec<ArcSort>,
973 span: Span,
974}
975
976impl SimpleTypeConstraint {
977 pub fn new(name: &str, sorts: Vec<ArcSort>, span: Span) -> SimpleTypeConstraint {
979 let name = name.to_owned();
980 SimpleTypeConstraint { name, sorts, span }
981 }
982
983 pub fn into_box(self) -> Box<dyn TypeConstraint> {
985 Box::new(self)
986 }
987}
988
989impl TypeConstraint for SimpleTypeConstraint {
990 fn get(
991 &self,
992 arguments: &[AtomTerm],
993 _typeinfo: &TypeInfo,
994 ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
995 if arguments.len() != self.sorts.len() {
996 vec![constraint::impossible(
997 ImpossibleConstraint::ArityMismatch {
998 atom: Atom {
999 span: self.span.clone(),
1000 head: self.name.clone(),
1001 args: arguments.to_vec(),
1002 },
1003 expected: self.sorts.len(),
1004 },
1005 )]
1006 } else {
1007 arguments
1008 .iter()
1009 .cloned()
1010 .zip(self.sorts.iter().cloned())
1011 .map(|(arg, sort)| constraint::assign(arg, sort))
1012 .collect()
1013 }
1014 }
1015}
1016
1017pub struct AllEqualTypeConstraint {
1022 name: String,
1023 sort: Option<ArcSort>,
1024 exact_length: Option<usize>,
1025 output: Option<ArcSort>,
1026 span: Span,
1027}
1028
1029impl AllEqualTypeConstraint {
1030 pub fn new(name: &str, span: Span) -> AllEqualTypeConstraint {
1032 AllEqualTypeConstraint {
1033 name: name.to_owned(),
1034 sort: None,
1035 exact_length: None,
1036 output: None,
1037 span,
1038 }
1039 }
1040
1041 pub fn into_box(self) -> Box<dyn TypeConstraint> {
1043 Box::new(self)
1044 }
1045
1046 pub fn with_all_arguments_sort(mut self, sort: ArcSort) -> Self {
1050 self.sort = Some(sort);
1051 self
1052 }
1053
1054 pub fn with_exact_length(mut self, exact_length: usize) -> Self {
1057 self.exact_length = Some(exact_length);
1058 self
1059 }
1060
1061 pub fn with_output_sort(mut self, output_sort: ArcSort) -> Self {
1063 self.output = Some(output_sort);
1064 self
1065 }
1066}
1067
1068impl TypeConstraint for AllEqualTypeConstraint {
1069 fn get(
1070 &self,
1071 mut arguments: &[AtomTerm],
1072 _typeinfo: &TypeInfo,
1073 ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
1074 if arguments.is_empty() {
1075 panic!("all arguments should have length > 0")
1076 }
1077
1078 match self.exact_length {
1079 Some(exact_length) if exact_length != arguments.len() => {
1080 return vec![constraint::impossible(
1081 ImpossibleConstraint::ArityMismatch {
1082 atom: Atom {
1083 span: self.span.clone(),
1084 head: self.name.clone(),
1085 args: arguments.to_vec(),
1086 },
1087 expected: exact_length,
1088 },
1089 )];
1090 }
1091 _ => (),
1092 }
1093
1094 let mut constraints = vec![];
1095 if let Some(output) = self.output.clone() {
1096 let (out, inputs) = arguments.split_last().unwrap();
1097 constraints.push(constraint::assign(out.clone(), output));
1098 arguments = inputs;
1099 }
1100
1101 if let Some(sort) = self.sort.clone() {
1102 constraints.extend(
1103 arguments
1104 .iter()
1105 .cloned()
1106 .map(|arg| constraint::assign(arg, sort.clone())),
1107 )
1108 } else if let Some((first, rest)) = arguments.split_first() {
1109 constraints.extend(
1110 rest.iter()
1111 .cloned()
1112 .map(|arg| constraint::eq(arg, first.clone())),
1113 );
1114 }
1115 constraints
1116 }
1117}
1118
1119pub(crate) fn grounded_check(
1123 rule: &GenericCoreRule<HeadOrEq<ResolvedCall>, ResolvedCall, ResolvedVar>,
1124) -> Result<(), TypeError> {
1125 use crate::core::ResolvedAtomTerm;
1126 let body = &rule.body;
1127
1128 let range = rule
1129 .body
1130 .get_vars()
1131 .into_iter()
1132 .map(|v| ResolvedAtomTerm::Var(rule.span.clone(), v))
1133 .collect();
1134 let mut problem: Problem<ResolvedAtomTerm, ()> = Problem {
1135 constraints: vec![],
1136 range,
1137 };
1138
1139 for atom in body.atoms.iter() {
1140 let mut add_global_and_literal = false;
1141 match &atom.head {
1142 HeadOrEq::Head(ResolvedCall::Func(_)) => {
1143 for arg in atom.args.iter() {
1144 problem.constraints.push(assign(arg.clone(), ()));
1145 }
1146 }
1147 HeadOrEq::Head(ResolvedCall::Primitive(_)) => {
1148 let (out, inp) = atom.args.split_last().unwrap();
1149 let out = out.clone();
1150 problem.constraints.push(implies(
1151 format!("grounded_{:?}", out),
1152 inp.to_vec(),
1153 Rc::new(move |_| assign(out.clone(), ())),
1154 ));
1155 add_global_and_literal = true;
1156 }
1157 HeadOrEq::Eq => {
1158 assert_eq!(atom.args.len(), 2);
1159 problem
1160 .constraints
1161 .push(eq(atom.args[0].clone(), atom.args[1].clone()));
1162 add_global_and_literal = true;
1163 }
1164 }
1165 if add_global_and_literal {
1166 for arg in atom.args.iter() {
1167 match arg {
1168 ResolvedAtomTerm::Global(..) | ResolvedAtomTerm::Literal(..) => {
1169 problem.constraints.push(assign(arg.clone(), ()));
1170 }
1171 ResolvedAtomTerm::Var(..) => {}
1172 }
1173 }
1174 }
1175 }
1176
1177 let _assignment = problem.solve(|_| "grounded").map_err(|err| match err {
1178 ConstraintError::UnconstrainedVar(ResolvedAtomTerm::Var(span, v)) => {
1179 TypeError::Ungrounded(v.to_string(), span)
1180 }
1181 _ => panic!(
1182 "unexpected constraint error in groundedness check {:?}",
1183 err
1184 ),
1185 })?;
1186
1187 Ok(())
1188}