egglog/sort/
fn.rs

1//! Sort to represent functions as values.
2//!
3//! To declare the sort, you must specify the exact number of arguments and the sort of each, followed by the output sort:
4//! `(sort IntToString (UnstableFn (i64) String))`
5//!
6//! To create a function value, use the `(unstable-fn "name" [<partial args>])` primitive and to apply it use the `(unstable-app function arg1 arg2 ...)` primitive.
7//! The number of args must match the number of arguments in the function sort.
8//!
9//!
//! Like the `vec` sort, a function value is stored as a container value. Each container is a
//! `FunctionContainer` holding the resolved function id, the list of partially applied
//! arguments as `(ArcSort, Value)` pairs, and the function name.
13use std::sync::Mutex;
14
15use super::*;
16
17#[derive(Clone, Debug)]
18pub struct FunctionContainer(ResolvedFunctionId, pub Vec<(ArcSort, Value)>, pub String);
19
20// implement hash and equality based on values only not arcsorts, since
21// arcsorts are not comparable and any two values that are equal must have the same sort
22
23impl PartialEq for FunctionContainer {
24    fn eq(&self, other: &Self) -> bool {
25        self.0 == other.0
26            && self.1.iter().map(|(_, v)| *v).collect::<Vec<_>>()
27                == other.1.iter().map(|(_, v)| *v).collect::<Vec<_>>()
28            && self.2 == other.2
29    }
30}
31
32impl Eq for FunctionContainer {}
33
34impl Hash for FunctionContainer {
35    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
36        self.0.hash(state);
37        for (_, v) in &self.1 {
38            v.hash(state);
39        }
40        self.2.hash(state);
41    }
42}
43
44impl ContainerValue for FunctionContainer {
45    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
46        let mut changed = false;
47        for (s, old) in &mut self.1 {
48            if s.is_eq_sort() || s.is_eq_container_sort() {
49                let new = rebuilder.rebuild_val(*old);
50                changed |= *old != new;
51                *old = new;
52            }
53        }
54        changed
55    }
56    fn iter(&self) -> impl Iterator<Item = Value> + '_ {
57        self.1.iter().map(|(_, v)| v).copied()
58    }
59}
60
61#[derive(Debug)]
62pub struct FunctionSort {
63    name: String,
64    inputs: Vec<ArcSort>,
65    output: ArcSort,
66    // store all the arcsorts for functions that were added as partial args to this function sort
67    // so that we can retrieve them during extraction
68    partial_arcsorts: Arc<Mutex<Vec<ArcSort>>>,
69}
70
71impl FunctionSort {
72    pub fn name(&self) -> &str {
73        &self.name
74    }
75
76    pub fn inputs(&self) -> &[ArcSort] {
77        &self.inputs
78    }
79
80    pub fn output(&self) -> ArcSort {
81        self.output.clone()
82    }
83}
84
85impl Presort for FunctionSort {
86    fn presort_name() -> &'static str {
87        "UnstableFn"
88    }
89
90    fn reserved_primitives() -> Vec<&'static str> {
91        vec!["unstable-fn", "unstable-app"]
92    }
93
94    fn make_sort(
95        typeinfo: &mut TypeInfo,
96        name: String,
97        args: &[Expr],
98    ) -> Result<ArcSort, TypeError> {
99        if let [inputs, Expr::Var(span, output)] = args {
100            let output_sort = typeinfo
101                .get_sort_by_name(output)
102                .ok_or(TypeError::UndefinedSort(output.clone(), span.clone()))?;
103
104            let input_sorts = match inputs {
105                Expr::Call(_, first, rest_args) => {
106                    let all_args = once(first).chain(rest_args.iter().map(|arg| {
107                        if let Expr::Var(_, arg) = arg {
108                            arg
109                        } else {
110                            panic!("function sort must be called with list of input sorts");
111                        }
112                    }));
113                    all_args
114                        .map(|arg| {
115                            typeinfo
116                                .get_sort_by_name(arg)
117                                .ok_or(TypeError::UndefinedSort(arg.clone(), span.clone()))
118                                .cloned()
119                        })
120                        .collect::<Result<Vec<_>, _>>()?
121                }
122                // an empty list of inputs args is parsed as a unit literal
123                Expr::Lit(_, Literal::Unit) => vec![],
124                _ => panic!("function sort must be called with list of input sorts"),
125            };
126
127            Ok(Arc::new(Self {
128                name,
129                inputs: input_sorts,
130                output: output_sort.clone(),
131                partial_arcsorts: Arc::new(Mutex::new(vec![])),
132            }))
133        } else {
134            panic!("function sort must be called with list of input args and output sort");
135        }
136    }
137}
138
139impl Sort for FunctionSort {
140    fn name(&self) -> &str {
141        &self.name
142    }
143
144    fn column_ty(&self, _backend: &egglog_bridge::EGraph) -> ColumnTy {
145        ColumnTy::Id
146    }
147
148    fn register_type(&self, backend: &mut egglog_bridge::EGraph) {
149        backend.register_container_ty::<FunctionContainer>();
150        backend
151            .base_values_mut()
152            .register_type::<ResolvedFunction>();
153    }
154
155    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
156        self
157    }
158
159    fn is_container_sort(&self) -> bool {
160        true
161    }
162
163    fn is_eq_container_sort(&self) -> bool {
164        self.inputs.iter().any(|s| s.is_eq_sort())
165    }
166
167    fn serialized_name(&self, container_values: &ContainerValues, value: Value) -> String {
168        let val = container_values
169            .get_val::<FunctionContainer>(value)
170            .unwrap();
171        val.2.clone()
172    }
173
174    fn inner_sorts(&self) -> Vec<ArcSort> {
175        self.partial_arcsorts.lock().unwrap().clone()
176    }
177
178    fn inner_values(
179        &self,
180        container_values: &ContainerValues,
181        value: Value,
182    ) -> Vec<(ArcSort, Value)> {
183        let val = container_values
184            .get_val::<FunctionContainer>(value)
185            .unwrap();
186        val.1.clone()
187    }
188
189    fn register_primitives(self: Arc<Self>, eg: &mut EGraph) {
190        eg.add_primitive(Ctor {
191            name: "unstable-fn".into(),
192            function: self.clone(),
193        });
194        eg.add_primitive(Apply {
195            name: "unstable-app".into(),
196            function: self.clone(),
197        });
198    }
199
200    fn value_type(&self) -> Option<TypeId> {
201        Some(TypeId::of::<FunctionContainer>())
202    }
203
204    fn reconstruct_termdag_container(
205        &self,
206        container_values: &ContainerValues,
207        value: Value,
208        termdag: &mut TermDag,
209        mut element_terms: Vec<Term>,
210    ) -> Term {
211        let name = &container_values
212            .get_val::<FunctionContainer>(value)
213            .unwrap()
214            .2;
215        let head = termdag.lit(Literal::String(name.clone()));
216        element_terms.insert(0, head);
217        termdag.app("unstable-fn".to_owned(), element_terms)
218    }
219}
220
221/// Takes a string and any number of partially applied args of any sort and returns a function
222struct FunctionCTorTypeConstraint {
223    name: String,
224    function: Arc<FunctionSort>,
225    span: Span,
226}
227
228impl TypeConstraint for FunctionCTorTypeConstraint {
229    fn get(
230        &self,
231        arguments: &[AtomTerm],
232        typeinfo: &TypeInfo,
233    ) -> Vec<Box<dyn Constraint<AtomTerm, ArcSort>>> {
234        // Must have at least one arg (plus the return value)
235        if arguments.len() < 2 {
236            return vec![constraint::impossible(
237                constraint::ImpossibleConstraint::ArityMismatch {
238                    atom: core::Atom {
239                        span: self.span.clone(),
240                        head: self.name.clone(),
241                        args: arguments.to_vec(),
242                    },
243                    expected: 2,
244                },
245            )];
246        }
247        let output_sort_constraint: Box<dyn Constraint<_, ArcSort>> = constraint::assign(
248            arguments[arguments.len() - 1].clone(),
249            self.function.clone(),
250        );
251        // If first arg is a literal string and we know the name of the function and can use that to know what
252        // types to expect
253        if let AtomTerm::Literal(_, Literal::String(ref name)) = arguments[0] {
254            if let Some(func_type) = typeinfo.get_func_type(name) {
255                // The arguments contains the return sort as well as the function name
256                let n_partial_args = arguments.len() - 2;
257                // the number of partial args must match the number of inputs from the func type minus the number from
258                // this function sort
259                if self.function.inputs.len() + n_partial_args != func_type.input.len() {
260                    return vec![constraint::impossible(
261                        constraint::ImpossibleConstraint::ArityMismatch {
262                            atom: core::Atom {
263                                span: self.span.clone(),
264                                head: self.name.clone(),
265                                args: arguments.to_vec(),
266                            },
267                            expected: self.function.inputs.len() + func_type.input.len() + 1,
268                        },
269                    )];
270                }
271                // the output type and input types (starting after the partial args) must match between these functions
272                let expected_output = self.function.output.clone();
273                let expected_input = self.function.inputs.clone();
274                let actual_output = func_type.output.clone();
275                let actual_input: Vec<ArcSort> = func_type
276                    .input
277                    .iter()
278                    .skip(n_partial_args)
279                    .cloned()
280                    .collect();
281                if expected_output.name() != actual_output.name()
282                    || expected_input
283                        .iter()
284                        .map(|s| s.name())
285                        .ne(actual_input.iter().map(|s| s.name()))
286                {
287                    return vec![constraint::impossible(
288                        constraint::ImpossibleConstraint::FunctionMismatch {
289                            expected_output,
290                            expected_input,
291                            actual_output,
292                            actual_input,
293                        },
294                    )];
295                }
296                // if they match, then just make sure the partial args match as well
297                return func_type
298                    .input
299                    .iter()
300                    .take(n_partial_args)
301                    .zip(arguments.iter().skip(1))
302                    .map(|(expected_sort, actual_term)| {
303                        constraint::assign(actual_term.clone(), expected_sort.clone())
304                    })
305                    .chain(once(output_sort_constraint))
306                    .collect();
307            }
308        }
309
310        // Otherwise we just try assuming it's this function, we don't know if it is or not
311        vec![
312            constraint::assign(arguments[0].clone(), StringSort.to_arcsort()),
313            output_sort_constraint,
314        ]
315    }
316}
317
318// (unstable-fn "name" [<arg1>, <arg2>, ...])
319#[derive(Clone)]
320struct Ctor {
321    name: String,
322    function: Arc<FunctionSort>,
323}
324
325impl Primitive for Ctor {
326    fn name(&self) -> &str {
327        &self.name
328    }
329
330    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
331        Box::new(FunctionCTorTypeConstraint {
332            name: self.name.clone(),
333            function: self.function.clone(),
334            span: span.clone(),
335        })
336    }
337
338    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
339        let (rf, args) = args.split_first().unwrap();
340        let ResolvedFunction {
341            id,
342            partial_arcsorts,
343            name,
344        } = exec_state.base_values().unwrap(*rf);
345        self.function
346            .partial_arcsorts
347            .lock()
348            .unwrap()
349            .extend(partial_arcsorts.iter().cloned());
350        let args = partial_arcsorts
351            .iter()
352            .zip(args)
353            .map(|(b, x)| (b.clone(), *x))
354            .collect();
355        let y = FunctionContainer(id, args, name);
356        Some(
357            exec_state
358                .clone()
359                .container_values()
360                .register_val(y, exec_state),
361        )
362    }
363}
364
365#[derive(Clone, Debug)]
366pub struct ResolvedFunction {
367    pub id: ResolvedFunctionId,
368    pub partial_arcsorts: Vec<ArcSort>,
369    pub name: String,
370}
371// implement equality and hash based on id and  arcsort names, since arcsorts are not comparable
372
373impl PartialEq for ResolvedFunction {
374    fn eq(&self, other: &Self) -> bool {
375        self.id == other.id
376            && self
377                .partial_arcsorts
378                .iter()
379                .map(|s| s.name())
380                .collect::<Vec<_>>()
381                == other
382                    .partial_arcsorts
383                    .iter()
384                    .map(|s| s.name())
385                    .collect::<Vec<_>>()
386    }
387}
388
389impl Eq for ResolvedFunction {}
390
391impl Hash for ResolvedFunction {
392    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
393        self.id.hash(state);
394        for s in &self.partial_arcsorts {
395            s.name().hash(state);
396        }
397    }
398}
399
400impl BaseValue for ResolvedFunction {}
401
402#[derive(Clone, Debug, PartialEq, Eq, Hash)]
403pub enum ResolvedFunctionId {
404    Lookup(egglog_bridge::TableAction),
405    Prim(ExternalFunctionId),
406}
407
408// (unstable-app <function> [<arg1>, <arg2>, ...])
409#[derive(Clone)]
410struct Apply {
411    name: String,
412    function: Arc<FunctionSort>,
413}
414
415impl Primitive for Apply {
416    fn name(&self) -> &str {
417        &self.name
418    }
419
420    fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
421        let mut sorts: Vec<ArcSort> = vec![self.function.clone()];
422        sorts.extend(self.function.inputs.clone());
423        sorts.push(self.function.output.clone());
424        SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
425    }
426
427    fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
428        let (fc, args) = args.split_first().unwrap();
429        let fc = exec_state
430            .container_values()
431            .get_val::<FunctionContainer>(*fc)
432            .unwrap()
433            .clone();
434        fc.apply(exec_state, args)
435    }
436}
437
438impl FunctionContainer {
439    /// Call function (primitive or table) `name` with value args `args` and return the value.
440    ///
441    /// Public so that other primitive sorts (external or internal) have access.
442    pub fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
443        let args: Vec<_> = self.1.iter().map(|(_, x)| x).chain(args).copied().collect();
444        match &self.0 {
445            ResolvedFunctionId::Lookup(action) => action.lookup(exec_state, &args),
446            ResolvedFunctionId::Prim(prim) => exec_state.call_external_func(*prim, &args),
447        }
448    }
449}