1use super::*;
2use inner::MultiSet;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash)]
5pub struct MultiSetContainer {
6 pub do_rebuild: bool,
7 pub data: MultiSet<Value>,
8}
9
10impl ContainerValue for MultiSetContainer {
11 fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool {
12 if self.do_rebuild {
13 let mut xs: Vec<_> = self.data.iter().copied().collect();
14 let changed = rebuilder.rebuild_slice(&mut xs);
15 self.data = xs.into_iter().collect();
16 changed
17 } else {
18 false
19 }
20 }
21 fn iter(&self) -> impl Iterator<Item = Value> + '_ {
22 self.data.iter().copied()
23 }
24}
25
26#[derive(Clone, Debug)]
27pub struct MultiSetSort {
28 name: String,
29 element: ArcSort,
30}
31
32impl MultiSetSort {
33 pub fn element(&self) -> ArcSort {
34 self.element.clone()
35 }
36}
37
38impl Presort for MultiSetSort {
39 fn presort_name() -> &'static str {
40 "MultiSet"
41 }
42
43 fn reserved_primitives() -> Vec<&'static str> {
44 vec![
45 "multiset-of",
46 "multiset-insert",
47 "multiset-contains",
48 "multiset-not-contains",
49 "multiset-remove",
50 "multiset-length",
51 "multiset-sum",
52 "unstable-multiset-map",
53 ]
54 }
55
56 fn make_sort(
57 typeinfo: &mut TypeInfo,
58 name: String,
59 args: &[Expr],
60 ) -> Result<ArcSort, TypeError> {
61 if let [Expr::Var(span, e)] = args {
62 let e = typeinfo
63 .get_sort_by_name(e)
64 .ok_or(TypeError::UndefinedSort(e.clone(), span.clone()))?;
65
66 if e.is_eq_container_sort() {
67 return Err(TypeError::DisallowedSort(
68 name,
69 "Multisets nested with other EqSort containers are not allowed".into(),
70 span.clone(),
71 ));
72 }
73
74 let out = Self {
75 name,
76 element: e.clone(),
77 };
78 Ok(out.to_arcsort())
79 } else {
80 panic!()
81 }
82 }
83}
84
85impl ContainerSort for MultiSetSort {
86 type Container = MultiSetContainer;
87
88 fn name(&self) -> &str {
89 &self.name
90 }
91
92 fn inner_sorts(&self) -> Vec<ArcSort> {
93 vec![self.element.clone()]
94 }
95
96 fn is_eq_container_sort(&self) -> bool {
97 self.element.is_eq_sort()
98 }
99
100 fn inner_values(
101 &self,
102 container_values: &ContainerValues,
103 value: Value,
104 ) -> Vec<(ArcSort, Value)> {
105 let val = container_values
106 .get_val::<MultiSetContainer>(value)
107 .unwrap()
108 .clone();
109 val.data
110 .iter()
111 .map(|k| (self.element.clone(), *k))
112 .collect()
113 }
114
115 fn register_primitives(&self, eg: &mut EGraph) {
116 let arc = self.clone().to_arcsort();
117
118 add_primitive!(eg, "multiset-of" = {self.clone(): MultiSetSort} [xs: # (self.element())] -> @MultiSetContainer (arc) { MultiSetContainer {
119 do_rebuild: self.ctx.element.is_eq_sort(),
120 data: xs.collect()
121 } });
122
123 add_primitive!(eg, "multiset-pick" = |xs: @MultiSetContainer (arc)| -> # (self.element()) { *xs.data.pick().expect("Cannot pick from an empty multiset") });
124 add_primitive!(eg, "multiset-insert" = |mut xs: @MultiSetContainer (arc), x: # (self.element())| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.insert( x) , ..xs } });
125 add_primitive!(eg, "multiset-remove" = |mut xs: @MultiSetContainer (arc), x: # (self.element())| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.remove(&x)?, ..xs } });
126
127 add_primitive!(eg, "multiset-length" = |xs: @MultiSetContainer (arc)| -> i64 { xs.data.len() as i64 });
128 add_primitive!(eg, "multiset-contains" = |xs: @MultiSetContainer (arc), x: # (self.element())| -?> () { ( xs.data.contains(&x)).then_some(()) });
129 add_primitive!(eg, "multiset-not-contains" = |xs: @MultiSetContainer (arc), x: # (self.element())| -?> () { (!xs.data.contains(&x)).then_some(()) });
130
131 add_primitive!(eg, "multiset-sum" = |xs: @MultiSetContainer (arc), ys: @MultiSetContainer (arc)| -> @MultiSetContainer (arc) { MultiSetContainer { data: xs.data.sum(ys.data), ..xs } });
132
133 let fn_sorts = eg.type_info.get_sorts_by(|s: &Arc<FunctionSort>| {
135 (s.inputs().len() == 1)
136 && (s.inputs()[0].name() == self.element.name())
137 && (s.output().name() == self.element.name())
138 });
139 match fn_sorts.len() {
140 0 => {}
141 1 => eg.add_primitive(Map {
142 name: "unstable-multiset-map".into(),
143 multiset: arc,
144 fn_: fn_sorts.into_iter().next().unwrap(),
145 }),
146 _ => panic!("too many applicable function sorts"),
147 }
148 }
149
150 fn reconstruct_termdag(
151 &self,
152 _container_values: &ContainerValues,
153 _value: Value,
154 termdag: &mut TermDag,
155 element_terms: Vec<Term>,
156 ) -> Term {
157 termdag.app("multiset-of".into(), element_terms)
158 }
159
160 fn serialized_name(&self, _container_values: &ContainerValues, _: Value) -> String {
161 "multiset-of".to_owned()
162 }
163}
164
165#[derive(Clone)]
166struct Map {
167 name: String,
168 multiset: ArcSort,
169 fn_: Arc<FunctionSort>,
170}
171
172impl Primitive for Map {
173 fn name(&self) -> &str {
174 &self.name
175 }
176
177 fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
178 SimpleTypeConstraint::new(
179 self.name(),
180 vec![
181 self.fn_.clone(),
182 self.multiset.clone(),
183 self.multiset.clone(),
184 ],
185 span.clone(),
186 )
187 .into_box()
188 }
189
190 fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value> {
191 let fc = exec_state
192 .container_values()
193 .get_val::<FunctionContainer>(args[0])
194 .unwrap()
195 .clone();
196 let multiset = exec_state
197 .container_values()
198 .get_val::<MultiSetContainer>(args[1])
199 .unwrap()
200 .clone();
201 let multiset = MultiSetContainer {
202 data: multiset
203 .data
204 .iter()
205 .map(|e| fc.apply(exec_state, &[*e]))
206 .collect::<Option<_>>()?,
207 ..multiset
208 };
209 Some(
210 exec_state
211 .clone()
212 .container_values()
213 .register_val(multiset, exec_state),
214 )
215 }
216}
217
mod inner {
    use std::collections::BTreeMap;
    use std::hash::Hash;

    /// A multiset (bag) backed by a sorted map from element to multiplicity.
    ///
    /// Invariants:
    /// - every multiplicity stored in the map is at least 1, so the presence
    ///   of a key means "occurs at least once";
    /// - the second field always equals the sum of all multiplicities.
    ///
    /// Because the map is a `BTreeMap`, iteration order, `pick`, and the
    /// derived `Eq`/`Hash` do not depend on insertion order. `insert`,
    /// `remove`, and `sum` take `self` by value and return the updated set.
    #[derive(Debug, Default, Hash, Eq, PartialEq, Clone)]
    pub struct MultiSet<T: Clone + Hash + Ord>(
        /// Element -> number of occurrences (always >= 1).
        BTreeMap<T, usize>,
        /// Total number of occurrences across all elements.
        usize,
    );

    impl<T: Clone + Hash + Ord> MultiSet<T> {
        /// Creates an empty multiset.
        pub fn new() -> Self {
            MultiSet(BTreeMap::new(), 0)
        }

        /// Returns `true` if `value` occurs at least once.
        pub fn contains(&self, value: &T) -> bool {
            self.0.contains_key(value)
        }

        /// Returns the total number of elements, counting repeats.
        pub fn len(&self) -> usize {
            self.1
        }

        /// Iterates over all elements in ascending order, yielding each one as
        /// many times as it occurs.
        pub fn iter(&self) -> impl Iterator<Item = &T> {
            self.0.iter().flat_map(|(k, v)| std::iter::repeat_n(k, *v))
        }

        /// Returns the smallest element, or `None` if the multiset is empty.
        pub fn pick(&self) -> Option<&T> {
            self.0.keys().next()
        }

        /// Returns the multiset with one more occurrence of `value`.
        pub fn insert(mut self, value: T) -> MultiSet<T> {
            self.insert_multiple_mut(value, 1);
            self
        }

        /// Removes one occurrence of `value`.
        ///
        /// Returns `None` if `value` does not occur in the multiset.
        pub fn remove(mut self, value: &T) -> Option<MultiSet<T>> {
            // One lookup; no need to clone the key to write back a new count.
            let count = self.0.get_mut(value)?;
            if *count == 1 {
                // Drop the key entirely so multiplicities stay >= 1.
                self.0.remove(value);
            } else {
                *count -= 1;
            }
            self.1 -= 1;
            Some(self)
        }

        /// Adds `n` occurrences of `value` in place.
        fn insert_multiple_mut(&mut self, value: T, n: usize) {
            // Storing a zero count would make `contains`/`pick` report an
            // element that does not actually occur.
            if n == 0 {
                return;
            }
            *self.0.entry(value).or_insert(0) += n;
            self.1 += n;
        }

        /// Returns the sum of two multisets: multiplicities are added.
        pub fn sum(mut self, MultiSet(other_map, other_count): Self) -> Self {
            let target_count = self.1 + other_count;
            for (k, v) in other_map {
                self.insert_multiple_mut(k, v);
            }
            assert_eq!(self.1, target_count);
            self
        }
    }

    impl<T: Clone + Hash + Ord> FromIterator<T> for MultiSet<T> {
        /// Builds a multiset; repeated items increase the multiplicity.
        fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
            let mut multiset = MultiSet::new();
            for value in iter {
                multiset.insert_multiple_mut(value, 1);
            }
            multiset
        }
    }
}
309}