// egglog_core_relations/containers/mod.rs

1//! Support for containers
2//!
3//! Containers behave a lot like base values. They are implemented differently because
4//! their ids share a space with other Ids in the egraph and as a result, their ids need to be
5//! sparse.
6//!
//! This is a relatively "eager" implementation of containers, reflecting egglog's current
8//! semantics. One could imagine a variant of containers in which they behave more like egglog
9//! functions than base values.
10
11use std::{
12    any::{Any, TypeId},
13    hash::{Hash, Hasher},
14    ops::Deref,
15};
16
17use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
18use crossbeam_queue::SegQueue;
19use dashmap::SharedValue;
20use rayon::{
21    iter::{ParallelBridge, ParallelIterator},
22    prelude::*,
23};
24use rustc_hash::FxHasher;
25
26use crate::{
27    ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
28    WrappedTable,
29    common::{DashMap, IndexSet, SubsetTracker},
30    parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
31    table_spec::Rebuilder,
32};
33
#[cfg(test)]
mod tests;

// Note: this id identifies a registered container *type* (one per `TypeId`),
// not an individual container value.
define_id!(pub ContainerValueId, u32, "an identifier for containers");
38
/// The function type used to resolve conflicts when two distinct `Value`s end
/// up mapping to equal containers (see `ContainerEnv::insert_owned`).
///
/// Given the two conflicting values, a merge function returns the value to
/// keep. Implemented automatically for any suitable `Fn` via the blanket impl
/// below.
pub trait MergeFn:
    Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
{
}
impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}

// Implements `Clone` for `Box<dyn MergeFn>`.
dyn_clone::clone_trait_object!(MergeFn);
47
/// Assigns a dense [`ContainerValueId`] to each registered container type.
///
/// The id of a type is its insertion index in the underlying [`IndexSet`].
#[derive(Clone, Default)]
struct ContainerIds {
    ids: IndexSet<TypeId>,
}
52
53impl ContainerIds {
54    fn insert(&mut self, ty: TypeId) -> ContainerValueId {
55        if let Some(idx) = self.ids.get_index_of(&ty) {
56            ContainerValueId::from_usize(idx)
57        } else {
58            let idx = self.ids.len();
59            self.ids.insert(ty);
60            ContainerValueId::from_usize(idx)
61        }
62    }
63
64    fn get(&self, ty: &TypeId) -> Option<ContainerValueId> {
65        self.ids.get_index_of(ty).map(ContainerValueId::from_usize)
66    }
67}
68
/// The top-level registry mapping container values to and from the `Value`s
/// that represent them, with one [`ContainerEnv`] per registered container
/// type.
#[derive(Clone, Default)]
pub struct ContainerValues {
    /// Tracks recently-updated portions of tables so `rebuild_all` can
    /// attempt incremental rebuilds.
    subset_tracker: SubsetTracker,
    /// Maps a container type's `TypeId` to its dense [`ContainerValueId`].
    container_ids: ContainerIds,
    /// Per-type environments, indexed by [`ContainerValueId`].
    data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
}
75
impl ContainerValues {
    /// Create an empty registry with no container types registered.
    pub fn new() -> Self {
        Default::default()
    }

    /// Get the typed environment for container type `C`, if `C` was
    /// previously registered via [`ContainerValues::register_type`].
    fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
        let id = self.container_ids.get(&TypeId::of::<C>())?;
        let res = self.data.get(id)?.as_any();
        // The downcast cannot fail: `container_ids` and `data` are kept in
        // sync by `register_type`.
        Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
    }

    /// Iterate over the containers of the given type.
    pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
        let Some(env) = self.get::<C>() else {
            return;
        };
        for ent in env.to_id.iter() {
            f(ent.key(), *ent.value());
        }
    }

    /// Get the container associated with the value `val` in the database. The caller must know the
    /// type of the container.
    ///
    /// The return type of this function may contain lock guards. Attempts to modify the contents
    /// of the containers database may deadlock if the given guard has not been dropped.
    pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
        self.get::<C>()?.get_container(val)
    }

    /// Intern `container`, returning its existing `Value` or minting a fresh
    /// one via the type's registered counter.
    ///
    /// Panics if container type `C` has not been registered.
    pub fn register_val<C: ContainerValue>(
        &self,
        container: C,
        exec_state: &mut ExecutionState,
    ) -> Value {
        let env = self
            .get::<C>()
            .expect("must register container type before registering a value");
        env.get_or_insert(&container, exec_state)
    }

    /// Apply the given rebuild to the contents of each container.
    ///
    /// Returns `true` if any container changed. Work is spread across
    /// container types in parallel when there are enough of them.
    pub fn rebuild_all(
        &mut self,
        table_id: TableId,
        table: &WrappedTable,
        exec_state: &mut ExecutionState,
    ) -> bool {
        let Some(rebuilder) = table.rebuilder(&[]) else {
            return false;
        };
        let to_scan = rebuilder.hint_col().map(|_| {
            // We may attempt an incremental rebuild.
            self.subset_tracker.recent_updates(table_id, table)
        });
        if parallelize_inter_container_op(self.data.next_id().index()) {
            self.data
                .iter_mut()
                // Give each rayon task its own clone of the execution state.
                .zip(std::iter::repeat_with(|| exec_state.clone()))
                .par_bridge()
                .map(|((_, env), mut exec_state)| {
                    env.apply_rebuild(
                        table,
                        &*rebuilder,
                        to_scan.as_ref().map(|x| x.as_ref()),
                        &mut exec_state,
                    )
                })
                // `max` over booleans: true iff any environment changed.
                .max()
                .unwrap_or(false)
        } else {
            let mut changed = false;
            for (_, env) in self.data.iter_mut() {
                changed |= env.apply_rebuild(
                    table,
                    &*rebuilder,
                    to_scan.as_ref().map(|x| x.as_ref()),
                    exec_state,
                );
            }
            changed
        }
    }

    /// Add a new container type to the given [`ContainerValues`] instance.
    ///
    /// Container types need a means of generating fresh ids (`id_counter`) along with a means of
    /// merging conflicting ids (`merge_fn`). Registering the same type twice
    /// returns the original id and keeps the original environment.
    pub fn register_type<C: ContainerValue>(
        &mut self,
        id_counter: CounterId,
        merge_fn: impl MergeFn + 'static,
    ) -> ContainerValueId {
        let id = self.container_ids.insert(TypeId::of::<C>());
        self.data.get_or_insert(id, || {
            Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
        });
        id
    }
}
176
/// A trait implemented by container types.
///
/// Containers behave a lot like base values, but they include extra trait methods to support
/// rebuilding of container contents and merging containers that become equal after a rebuild pass
/// has taken place.
pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
    /// Rebuild the container in place according to the given [`Rebuilder`].
    ///
    /// If this method returns `false` then the container must not have been modified (i.e. it must
    /// hash to the same value, and compare equal to a copy of itself before the call).
    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;

    /// Iterate over the contents of the container.
    ///
    /// Note that containers can be more structured than just a sequence of values. This iterator
    /// is used to populate an index that in turn is used to speed up rebuilds. If a value in the
    /// container is eligible for a rebuild and it is not mentioned by this iterator, the outer
    /// container registry may skip rebuilding this container.
    fn iter(&self) -> impl Iterator<Item = Value> + '_;
}
197
/// Object-safe interface over [`ContainerEnv`], letting [`ContainerValues`]
/// store environments for heterogeneous container types behind one box.
pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
    /// Upcast to `&dyn Any`, used to downcast back to a concrete
    /// `ContainerEnv<C>`.
    fn as_any(&self) -> &dyn Any;
    /// Rebuild the contents of every container in this environment.
    ///
    /// `subset` optionally restricts the scan used for incremental
    /// rebuilding; returns `true` if any container changed.
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool;
}

// Implements `Clone` for `Box<dyn DynamicContainerEnv>`.
dyn_clone::clone_trait_object!(DynamicContainerEnv);
211
212fn hash_container(container: &impl ContainerValue) -> u64 {
213    let mut hasher = FxHasher::default();
214    container.hash(&mut hasher);
215    hasher.finish()
216}
217
/// Per-container-type state: bidirectional maps between containers and the
/// `Value`s interned for them, plus an index supporting incremental rebuilds.
#[derive(Clone)]
struct ContainerEnv<C: Eq + Hash> {
    /// Resolves conflicts when two values end up mapping to equal containers.
    merge_fn: Box<dyn MergeFn>,
    /// Counter used to mint fresh `Value`s for newly-interned containers.
    counter: CounterId,
    /// Forward map: container -> the `Value` interned for it.
    to_id: DashMap<C, Value>,
    /// Reverse map. Stores the container's hash code and the index of the
    /// `to_id` shard holding it, so the forward entry can be located without
    /// access to the container itself.
    to_container: DashMap<Value, (usize /* hash code */, usize /* map */)>,
    /// Map from a Value to the set of ids of containers that contain that value.
    val_index: DashMap<Value, IndexSet<Value>>,
}
227
impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool {
        // Attempt an incremental rebuild when a subset of recent updates is
        // available and the `incremental_rebuild` heuristic says it is worth
        // it; otherwise fall back to scanning every container.
        if let Some(subset) = subset {
            if incremental_rebuild(
                subset.size(),
                self.to_id.len(),
                parallelize_intra_container_op(self.to_id.len()),
            ) {
                return self.apply_rebuild_incremental(
                    table,
                    rebuilder,
                    exec_state,
                    subset,
                    // `subset` is only `Some` when `hint_col()` returned
                    // `Some` (see `ContainerValues::rebuild_all`).
                    rebuilder.hint_col().unwrap(),
                );
            }
        }
        self.apply_rebuild_nonincremental(rebuilder, exec_state)
    }
}
258
259impl<C: ContainerValue> ContainerEnv<C> {
260    pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
261        Self {
262            merge_fn,
263            counter,
264            to_id: DashMap::default(),
265            to_container: DashMap::default(),
266            val_index: DashMap::default(),
267        }
268    }
269
    /// Return the `Value` interned for `container`, minting a fresh one from
    /// `self.counter` if it is not already present.
    ///
    /// Safe to call concurrently: if two threads race to intern the same
    /// container, exactly one wins and the loser's freshly-minted value is
    /// unregistered again before being returned to anyone.
    fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
        // Fast path: the container is already interned.
        if let Some(value) = self.to_id.get(container) {
            return *value;
        }

        // Time to insert a new mapping. First, insert into `to_container`: the moment that we
        // insert a new value into `to_id`, someone else can return it from another call to
        // `get_or_insert` and then feed that value to `get_container`.

        let value = Value::from_usize(exec_state.inc_counter(self.counter));
        let target_map = self.to_id.determine_map(container);
        // This assertion is here because in parallel rebuilding we use `to_container` to
        // compute the intended shard for to_id, because we have a mutable borrow of
        // `to_container` that means we cannot call `determine_map` on `to_id`.
        debug_assert_eq!(
            target_map,
            self.to_container
                .determine_shard(hash_container(container) as usize)
        );
        self.to_container
            .insert(value, (hash_container(container) as usize, target_map));

        // Now insert into `to_id`, handling the case where a different thread is doing the same
        // thing.
        match self.to_id.entry(container.clone()) {
            dashmap::Entry::Vacant(vac) => {
                // Common case: insert the mapping in to_id and update the index.
                vac.insert(value);
                for val in container.iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                value
            }
            dashmap::Entry::Occupied(occ) => {
                // Someone inserted `container` into the mapping since we looked it up. Remove the
                // mapping that we inserted into `to_container` (we won't use it), and instead
                // return the "winning" value.
                let res = *occ.get();
                std::mem::drop(occ); // drop the lock.
                self.to_container.remove(&value);
                res
            }
        }
    }
314
    /// Insert an owned `container` mapped to `value`, merging with any
    /// existing mapping for an equal container.
    ///
    /// If an equal container is already mapped to a different value, the
    /// environment's `merge_fn` picks the surviving value and all three maps
    /// are updated to agree. Used by the rebuild paths to reinsert containers
    /// whose contents changed.
    fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) {
        let hc = hash_container(&container);
        let target_map = self.to_id.determine_map(&container);
        match self.to_id.entry(container) {
            dashmap::Entry::Occupied(mut occ) => {
                // Conflict: an equal container exists. Ask the merge function
                // which value wins.
                let result = (self.merge_fn)(exec_state, *occ.get(), value);
                let old_val = *occ.get();
                if result != old_val {
                    // Rewire the reverse map and the value index to the
                    // merged value.
                    self.to_container.remove(&old_val);
                    self.to_container.insert(result, (hc as usize, target_map));
                    *occ.get_mut() = result;
                    for val in occ.key().iter() {
                        let mut index = self.val_index.entry(val).or_default();
                        index.swap_remove(&old_val);
                        index.insert(result);
                    }
                }
            }
            dashmap::Entry::Vacant(vacant_entry) => {
                // No conflict: register the reverse mapping and index entries,
                // then fill the vacant slot.
                self.to_container.insert(value, (hc as usize, target_map));
                for val in vacant_entry.key().iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                vacant_entry.insert(value);
            }
        }
    }
    /// Rebuild only the containers that mention a recently-updated value.
    ///
    /// Scans `search_col` of `to_scan` in `table`, uses `val_index` to find
    /// the ids of containers mentioning each scanned value, then removes,
    /// rebuilds, and reinserts each affected container. Returns `true` if any
    /// container changed.
    fn apply_rebuild_incremental(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
        to_scan: SubsetRef,
        search_col: ColumnId,
    ) -> bool {
        // NB: there is no parallel implementation as of now.
        //
        // Implementing one should be straightforward, but we should wait for a real benchmark that
        // requires it. It's possible that incremental rebuilding will only be profitable when the
        // total number of ids to rebuild is small, in which case the overhead of parallelism may
        // not be worth it in the first place.
        let mut changed = false;
        let mut buf = TaggedRowBuffer::new(1);
        table.scan_project(
            to_scan,
            &[search_col],
            Offset::new(0),
            usize::MAX,
            &[],
            &mut buf,
        );
        // For each value in the buffer, rebuild all containers that mention it.
        let mut to_rebuild = IndexSet::<Value>::default();
        for (_, row) in buf.iter() {
            to_rebuild.insert(row[0]);
            let Some(ids) = self.val_index.get(&row[0]) else {
                continue;
            };
            to_rebuild.extend(&*ids);
        }
        for id in to_rebuild {
            // Ids not present in `to_container` (e.g. stale index entries)
            // are skipped.
            let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
                continue;
            };
            // We hold `&mut self`, so we can reach into the shard's raw table
            // without locking.
            let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
            let Some((mut container, _)) =
                shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
            else {
                continue;
            };
            changed |= container.rebuild_contents(rebuilder);
            // Reinsert; this merges if the rebuilt container now collides
            // with an existing one.
            self.insert_owned(container, id, exec_state);
        }
        changed
    }
390
    /// Rebuild every container in the environment, walking the raw shards of
    /// `to_id` directly (we hold `&mut self`, so no per-entry locking is
    /// needed). Dispatches to the parallel variant for large maps.
    ///
    /// Returns `true` if any entry changed.
    fn apply_rebuild_nonincremental(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        if parallelize_inter_container_op(self.to_id.len()) {
            return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
        }
        let mut changed = false;
        // Containers whose contents changed must be rehashed; collect them
        // here and reinsert after the scan.
        let mut to_reinsert = Vec::new();
        let shards = self.to_id.shards_mut();
        for shard in shards.iter_mut() {
            let shard = shard.get_mut();
            // SAFETY: the iterator does not outlive `shard`.
            for bucket in unsafe { shard.iter() } {
                // SAFETY: the bucket is valid; we just got it from the iterator.
                let (container, val) = unsafe { bucket.as_mut() };
                let old_val = *val.get();
                let new_val = rebuilder.rebuild_val(old_val);
                let container_changed = container.rebuild_contents(rebuilder);
                if !container_changed && new_val == old_val {
                    // Nothing changed about this entry. Leave it in place.
                    continue;
                }
                changed = true;
                if container_changed {
                    // The container changed. Remove both map entries then reinsert.
                    // SAFETY: This is a valid bucket. Furthermore, iterators remain valid if
                    // buckets they have already yielded have been removed.
                    let ((container, _), _) = unsafe { shard.remove(bucket) };
                    self.to_container.remove(&old_val);
                    to_reinsert.push((container, new_val));
                } else {
                    // Just the value changed. Leave the container in place.
                    // NOTE(review): `val_index` entries for this container's
                    // contents still reference `old_val` after this branch.
                    // Incremental rebuilds tolerate stale ids (they skip ids
                    // missing from `to_container`), but may then fail to find
                    // this container under `new_val` — verify this is intended.
                    *val.get_mut() = new_val;
                    let prev = self.to_container.remove(&old_val).unwrap().1;
                    self.to_container.insert(new_val, prev);
                }
            }
        }
        for (container, val) in to_reinsert {
            self.insert_owned(container, val, exec_state);
        }
        changed
    }
436
    /// Parallel counterpart of `apply_rebuild_nonincremental`: shards of
    /// `to_id` are scanned in parallel, and changed containers are reinserted
    /// in a second parallel pass, one queue per destination shard.
    ///
    /// Returns `true` if any entry changed.
    fn apply_rebuild_nonincremental_parallel(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        // This is very similar to the serial variant. The main difference is that
        // `to_reinsert` isn't a flat vector. It's instead a vector of queues - one per
        // destination map shard. This lets us do a bulk insertion in parallel without having
        // to grab a lock per container.
        let mut to_reinsert = IdVec::<usize /* to_id shard */, SegQueue<(C, Value)>>::default();
        to_reinsert.resize_with(self.to_id.shards().len(), Default::default);

        let shards = self.to_id.shards_mut();
        let changed = shards
            .par_iter_mut()
            .map(|shard| {
                let mut changed = false;
                let shard = shard.get_mut();
                // SAFETY: the iterator does not outlive `shard`.
                for bucket in unsafe { shard.iter() } {
                    // SAFETY: the bucket is valid; we just got it from the iterator.
                    let (container, val) = unsafe { bucket.as_mut() };
                    let old_val = *val.get();
                    let new_val = rebuilder.rebuild_val(old_val);
                    let container_changed = container.rebuild_contents(rebuilder);
                    if !container_changed && new_val == old_val {
                        // Nothing changed about this entry. Leave it in place.
                        continue;
                    }
                    changed = true;
                    if container_changed {
                        // The container changed. Remove both map entries then reinsert.
                        // SAFETY: This is a valid bucket. Furthermore, iterators remain valid if
                        // buckets they have already yielded have been removed.
                        let ((container, _), _) = unsafe { shard.remove(bucket) };
                        self.to_container.remove(&old_val);
                        // Spooky: we're using `to_container` to determine the shard for
                        // `to_id`. We are assuming that the # shards determination is
                        // deterministic here. There is a debug assertion in `get_or_insert`
                        // that attempts to verify this.
                        let shard = self
                            .to_container
                            .determine_shard(hash_container(&container) as usize);
                        to_reinsert[shard].push((container, new_val));
                    } else {
                        // Just the value changed. Leave the container in place.
                        *val.get_mut() = new_val;
                        let prev = self.to_container.remove(&old_val).unwrap().1;
                        self.to_container.insert(new_val, prev);
                    }
                }
                changed
            })
            // `max` over booleans: true iff any shard changed.
            .max()
            .unwrap_or(false);

        // Second pass: drain each shard's reinsertion queue, merging on
        // collision. This mirrors `insert_owned`, but operates on the raw
        // shard so that all shards can proceed in parallel.
        shards
            .iter_mut()
            .enumerate()
            .map(|(i, shard)| (i, shard, exec_state.clone()))
            .par_bridge()
            .for_each(|(shard_id, shard, mut exec_state)| {
                // This bit is a real slog. Once Dashmap updates from RawTable to HashTable for
                // the underlying shard, this will get a little better.
                //
                // NB: We are probably leaving some parallelism on the floor with these calls
                // to `to_container` and `val_index`.
                let shard = shard.get_mut();
                let queue = &to_reinsert[shard_id];
                while let Some((container, val)) = queue.pop() {
                    let hc = hash_container(&container);
                    let target_map = self.to_container.determine_shard(hc as usize);
                    match shard.find_or_find_insert_slot(
                        hc,
                        |(c, _)| c == &container,
                        |(c, _)| hash_container(c),
                    ) {
                        Ok(bucket) => {
                            // Collision with an existing container: merge.
                            // SAFETY: the bucket is valid; we just got it from the shard and
                            // we have not done any operations that can invalidate the bucket.
                            let (container, val_slot) = unsafe { bucket.as_mut() };
                            let old_val = *val_slot.get();
                            let result = (self.merge_fn)(&mut exec_state, old_val, val);
                            if result != old_val {
                                self.to_container.remove(&old_val);
                                self.to_container.insert(result, (hc as usize, target_map));
                                *val_slot.get_mut() = result;
                                for val in container.iter() {
                                    let mut index = self.val_index.entry(val).or_default();
                                    index.swap_remove(&old_val);
                                    index.insert(result);
                                }
                            }
                        }
                        Err(slot) => {
                            // No collision: register the reverse mapping and
                            // index entries, then insert into the found slot.
                            self.to_container.insert(val, (hc as usize, target_map));
                            for v in container.iter() {
                                self.val_index.entry(v).or_default().insert(val);
                            }
                            // SAFETY: We just got this slot from `find_or_find_insert_slot`
                            // and we have not mutated the map at all since then.
                            unsafe {
                                shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
                            }
                        }
                    }
                }
            });
        changed
    }
547
548    fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
549        let (hc, target_map) = *self.to_container.get(&value)?;
550        let shard = &self.to_id.shards()[target_map];
551        let read_guard = shard.read();
552        let val_ptr: *const (C, _) = shard
553            .read()
554            .find(hc as u64, |(_, v)| *v.get() == value)?
555            .as_ptr();
556        struct ValueDeref<'a, T, Guard> {
557            _guard: Guard,
558            data: &'a T,
559        }
560
561        impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
562            type Target = T;
563
564            fn deref(&self) -> &T {
565                self.data
566            }
567        }
568
569        Some(ValueDeref {
570            _guard: read_guard,
571            // SAFETY: the value will remain valid for as long as `read_guard` is in scope.
572            data: unsafe {
573                let unwrapped: &(C, _) = &*val_ptr;
574                &unwrapped.0
575            },
576        })
577    }
578}
579
580fn incremental_rebuild(_uf_size: usize, _table_size: usize, _parallel: bool) -> bool {
581    #[cfg(debug_assertions)]
582    {
583        use rand::Rng;
584        rand::rng().random_bool(0.5)
585    }
586    #[cfg(not(debug_assertions))]
587    {
588        if _parallel {
589            _table_size > 1000 && _uf_size * 512 <= _table_size
590        } else {
591            _table_size > 1000 && _uf_size * 8 <= _table_size
592        }
593    }
594}