1use std::sync::Mutex;
14
15use super::*;
16
/// Runtime value of an `UnstableFn` sort: a function reference plus its captured
/// (partially applied) arguments.
///
/// - `.0`: which function to call (a table lookup or an external primitive).
/// - `.1`: captured arguments, each paired with its sort; they are placed before the
///   call-site arguments when the function is applied.
/// - `.2`: the function's name, used for serialization and term reconstruction.
///
/// Equality and hashing look only at the values in `.1`, not their sorts.
#[derive(Clone, Debug)]
pub struct FunctionContainer(ResolvedFunctionId, pub Vec<(ArcSort, Value)>, pub String);
19
20impl PartialEq for FunctionContainer {
24 fn eq(&self, other: &Self) -> bool {
25 self.0 == other.0
26 && self.1.iter().map(|(_, v)| *v).collect::<Vec<_>>()
27 == other.1.iter().map(|(_, v)| *v).collect::<Vec<_>>()
28 && self.2 == other.2
29 }
30}
31
32impl Eq for FunctionContainer {}
33
34impl Hash for FunctionContainer {
35 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
36 self.0.hash(state);
37 for (_, v) in &self.1 {
38 v.hash(state);
39 }
40 self.2.hash(state);
41 }
42}
43
44impl ContainerValue for FunctionContainer {
45 fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
46 let mut changed = false;
47 for (s, old) in &mut self.1 {
48 if s.is_eq_sort() || s.is_eq_container_sort() {
49 let new = rebuilder.rebuild_val(*old);
50 changed |= *old != new;
51 *old = new;
52 }
53 }
54 changed
55 }
56 fn iter(&self) -> impl Iterator<Item = Value> + '_ {
57 self.1.iter().map(|(_, v)| v).copied()
58 }
59}
60
/// The `UnstableFn` sort: a first-class function type described by the input sorts
/// still to be supplied at application time and the result sort.
#[derive(Debug)]
pub struct FunctionSort {
    /// Declared name of this sort.
    name: String,
    /// Input sorts that remain to be supplied when the value is applied with
    /// `unstable-app` (captured arguments are not included).
    inputs: Vec<ArcSort>,
    /// Sort of the application result.
    output: ArcSort,
    /// Sorts of captured arguments recorded by `unstable-fn` at runtime; returned by
    /// `inner_sorts`. Mutated through `&self` in `Ctor::apply`, hence the `Mutex`.
    partial_arcsorts: Arc<Mutex<Vec<ArcSort>>>,
}
70
71impl FunctionSort {
72 pub fn name(&self) -> &str {
73 &self.name
74 }
75
76 pub fn inputs(&self) -> &[ArcSort] {
77 &self.inputs
78 }
79
80 pub fn output(&self) -> ArcSort {
81 self.output.clone()
82 }
83}
84
85impl Presort for FunctionSort {
86 fn presort_name() -> &'static str {
87 "UnstableFn"
88 }
89
90 fn reserved_primitives() -> Vec<&'static str> {
91 vec!["unstable-fn", "unstable-app"]
92 }
93
94 fn make_sort(
95 typeinfo: &mut TypeInfo,
96 name: String,
97 args: &[Expr],
98 ) -> Result<ArcSort, TypeError> {
99 if let [inputs, Expr::Var(span, output)] = args {
100 let output_sort = typeinfo
101 .get_sort_by_name(output)
102 .ok_or(TypeError::UndefinedSort(output.clone(), span.clone()))?;
103
104 let input_sorts = match inputs {
105 Expr::Call(_, first, rest_args) => {
106 let all_args = once(first).chain(rest_args.iter().map(|arg| {
107 if let Expr::Var(_, arg) = arg {
108 arg
109 } else {
110 panic!("function sort must be called with list of input sorts");
111 }
112 }));
113 all_args
114 .map(|arg| {
115 typeinfo
116 .get_sort_by_name(arg)
117 .ok_or(TypeError::UndefinedSort(arg.clone(), span.clone()))
118 .cloned()
119 })
120 .collect::<Result<Vec<_>, _>>()?
121 }
122 Expr::Lit(_, Literal::Unit) => vec![],
124 _ => panic!("function sort must be called with list of input sorts"),
125 };
126
127 Ok(Arc::new(Self {
128 name,
129 inputs: input_sorts,
130 output: output_sort.clone(),
131 partial_arcsorts: Arc::new(Mutex::new(vec![])),
132 }))
133 } else {
134 panic!("function sort must be called with list of input args and output sort");
135 }
136 }
137}
138
139impl Sort for FunctionSort {
140 fn name(&self) -> &str {
141 &self.name
142 }
143
144 fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
145 ColumnTy::Id
146 }
147
148 fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
149 backend.register_container_ty::<FunctionContainer>();
150 backend
151 .base_values_mut()
152 .register_type::<ResolvedFunction>();
153 }
154
155 fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
156 self
157 }
158
159 fn is_container_sort(&self) -> bool {
160 true
161 }
162
163 fn is_eq_container_sort(&self) -> bool {
164 self.inputs.iter().any(|s| s.is_eq_sort())
165 }
166
167 fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
168 let val = container_values
169 .get_val::<FunctionContainer>(value)
170 .unwrap();
171 val.2.clone()
172 }
173
174 fn inner_sorts(&self) -> Vec<ArcSort> {
175 self.partial_arcsorts.lock().unwrap().clone()
176 }
177
178 fn inner_values(
179 &self,
180 container_values: &ContainerValues,
181 value: Value,
182 ) -> Vec<(ArcSort, Value)> {
183 let val = container_values
184 .get_val::<FunctionContainer>(value)
185 .unwrap();
186 val.1.clone()
187 }
188
189 fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
190 eg.add_primitive(Ctor {
191 name: "unstable-fn".into(),
192 function: self.clone(),
193 });
194 eg.add_primitive(Apply {
195 name: "unstable-app".into(),
196 function: self.clone(),
197 });
198 }
199
200 fn value_type(&self) -> Option<TypeId> {
201 Some(TypeId::of::<FunctionContainer>())
202 }
203
204 fn reconstruct_termdag_container(
205 &self,
206 container_values: &ContainerValues,
207 value: Value,
208 termdag: &mut TermDag,
209 mut element_terms: Vec<Term>,
210 ) -> Term {
211 let name = &container_values
212 .get_val::<FunctionContainer>(value)
213 .unwrap()
214 .2;
215 let head = termdag.lit(Literal::String(name.clone()));
216 element_terms.insert(0, head);
217 termdag.app("unstable-fn".to_owned(), element_terms)
218 }
219}
220
/// Type constraint for the `unstable-fn` constructor. It needs the named function's
/// signature to type the partial arguments, so it cannot be a simple fixed-sort constraint.
struct FunctionCTorTypeConstraint {
    /// Primitive name, reported as the head of the atom in arity errors.
    name: String,
    /// Function sort that the constructor call produces.
    function: Arc<FunctionSort>,
    /// Span of the call site, reported in arity errors.
    span: Span,
}
227
228impl TypeConstraint for FunctionCTorTypeConstraint {
229 fn get(
230 &self,
231 arguments: &[AtomTerm],
232 typeinfo: &TypeInfo,
233 ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
234 if arguments.len() < 2 {
236 return vec![constraint::impossible(
237 constraint::ImpossibleConstraint::ArityMismatch {
238 atom: core::Atom {
239 span: self.span.clone(),
240 head: self.name.clone(),
241 args: arguments.to_vec(),
242 },
243 expected: 2,
244 },
245 )];
246 }
247 let output_sort_constraint: Box<dyn Constraint<_, ArcSort>> = constraint::assign(
248 arguments[arguments.len() - 1].clone(),
249 self.function.clone(),
250 );
251 if let AtomTerm::Literal(_, Literal::String(ref name)) = arguments[0] {
254 if let Some(func_type) = typeinfo.get_func_type(name) {
255 let n_partial_args = arguments.len() - 2;
257 if self.function.inputs.len() + n_partial_args != func_type.input.len() {
260 return vec![constraint::impossible(
261 constraint::ImpossibleConstraint::ArityMismatch {
262 atom: core::Atom {
263 span: self.span.clone(),
264 head: self.name.clone(),
265 args: arguments.to_vec(),
266 },
267 expected: self.function.inputs.len() + func_type.input.len() + 1,
268 },
269 )];
270 }
271 let expected_output = self.function.output.clone();
273 let expected_input = self.function.inputs.clone();
274 let actual_output = func_type.output.clone();
275 let actual_input: Vec<ArcSort> = func_type
276 .input
277 .iter()
278 .skip(n_partial_args)
279 .cloned()
280 .collect();
281 if expected_output.name() != actual_output.name()
282 || expected_input
283 .iter()
284 .map(|s| s.name())
285 .ne(actual_input.iter().map(|s| s.name()))
286 {
287 return vec![constraint::impossible(
288 constraint::ImpossibleConstraint::FunctionMismatch {
289 expected_output,
290 expected_input,
291 actual_output,
292 actual_input,
293 },
294 )];
295 }
296 return func_type
298 .input
299 .iter()
300 .take(n_partial_args)
301 .zip(arguments.iter().skip(1))
302 .map(|(expected_sort, actual_term)| {
303 constraint::assign(actual_term.clone(), expected_sort.clone())
304 })
305 .chain(once(output_sort_constraint))
306 .collect();
307 }
308 }
309
310 vec![
312 constraint::assign(arguments[0].clone(), StringSort.to_arcsort()),
313 output_sort_constraint,
314 ]
315 }
316}
317
/// The `unstable-fn` primitive. It builds a function value ([`FunctionContainer`])
/// from a resolved function and its partial arguments.
#[derive(Clone)]
struct Ctor {
    /// Primitive name (registered as `"unstable-fn"`).
    name: String,
    /// Function sort of the values this constructor produces.
    function: Arc<FunctionSort>,
}
324
325impl Primitive for Ctor {
326 fn name(&self) -> &str {
327 &self.name
328 }
329
330 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
331 Box::new(FunctionCTorTypeConstraint {
332 name: self.name.clone(),
333 function: self.function.clone(),
334 span: span.clone(),
335 })
336 }
337
338 fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
339 let (rf, args) = args.split_first().unwrap();
340 let ResolvedFunction {
341 id,
342 partial_arcsorts,
343 name,
344 } = exec_state.base_values().unwrap(*rf);
345 self.function
346 .partial_arcsorts
347 .lock()
348 .unwrap()
349 .extend(partial_arcsorts.iter().cloned());
350 let args = partial_arcsorts
351 .iter()
352 .zip(args)
353 .map(|(b, x)| (b.clone(), *x))
354 .collect();
355 let y = FunctionContainer(id, args, name);
356 Some(
357 exec_state
358 .clone()
359 .container_values()
360 .register_val(y, exec_state),
361 )
362 }
363}
364
/// Base value read from the first argument of `unstable-fn` at runtime (see
/// `Ctor::apply`). It identifies the target function and the sorts of the
/// arguments it captures.
///
/// Equality and hashing use `id` and the *names* of `partial_arcsorts`; `name` is not
/// compared.
#[derive(Clone, Debug)]
pub struct ResolvedFunction {
    /// How to invoke the function.
    pub id: ResolvedFunctionId,
    /// Sorts of the captured (partially applied) arguments, in order.
    pub partial_arcsorts: Vec<ArcSort>,
    /// Function name, copied into the resulting [`FunctionContainer`].
    pub name: String,
}
371impl PartialEq for ResolvedFunction {
374 fn eq(&self, other: &Self) -> bool {
375 self.id == other.id
376 && self
377 .partial_arcsorts
378 .iter()
379 .map(|s| s.name())
380 .collect::<Vec<_>>()
381 == other
382 .partial_arcsorts
383 .iter()
384 .map(|s| s.name())
385 .collect::<Vec<_>>()
386 }
387}
388
389impl Eq for ResolvedFunction {}
390
391impl Hash for ResolvedFunction {
392 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
393 self.id.hash(state);
394 for s in &self.partial_arcsorts {
395 s.name().hash(state);
396 }
397 }
398}
399
400impl BaseValue for ResolvedFunction {}
401
/// How a function value is invoked by `FunctionContainer::apply`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ResolvedFunctionId {
    /// A table-backed function, invoked with `TableAction::lookup`.
    Lookup(egglog_bridge::TableAction),
    /// An external primitive, invoked with `ExecutionState::call_external_func`.
    Prim(ExternalFunctionId),
}
407
/// The `unstable-app` primitive. It applies a function value to the remaining inputs.
#[derive(Clone)]
struct Apply {
    /// Primitive name (registered as `"unstable-app"`).
    name: String,
    /// Function sort of the values this primitive applies.
    function: Arc<FunctionSort>,
}
414
415impl Primitive for Apply {
416 fn name(&self) -> &str {
417 &self.name
418 }
419
420 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
421 let mut sorts: Vec<ArcSort> = vec![self.function.clone()];
422 sorts.extend(self.function.inputs.clone());
423 sorts.push(self.function.output.clone());
424 SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
425 }
426
427 fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
428 let (fc, args) = args.split_first().unwrap();
429 let fc = exec_state
430 .container_values()
431 .get_val::<FunctionContainer>(*fc)
432 .unwrap()
433 .clone();
434 fc.apply(exec_state, args)
435 }
436}
437
438impl FunctionContainer {
439 pub fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
443 let args: Vec<_> = self.1.iter().map(|(_, x)| x).chain(args).copied().collect();
444 match &self.0 {
445 ResolvedFunctionId::Lookup(action) => action.lookup(exec_state, &args),
446 ResolvedFunctionId::Prim(prim) => exec_state.call_external_func(*prim, &args),
447 }
448 }
449}