1use std::{
12 any::{Any, TypeId},
13 hash::{Hash, Hasher},
14 ops::Deref,
15};
16
17use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
18use crossbeam_queue::SegQueue;
19use dashmap::SharedValue;
20use rayon::{
21 iter::{ParallelBridge, ParallelIterator},
22 prelude::*,
23};
24use rustc_hash::FxHasher;
25
26use crate::{
27 ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
28 WrappedTable,
29 common::{DashMap, IndexSet, SubsetTracker},
30 parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
31 table_spec::Rebuilder,
32};
33
34#[cfg(test)]
35mod tests;
36
37define_id!(pub ContainerValueId, u32, "an identifier for containers");
38
39pub trait MergeFn:
40 Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
41{
42}
43impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}
44
45dyn_clone::clone_trait_object!(MergeFn);
47
48#[derive(Clone, Default)]
49struct ContainerIds {
50 ids: IndexSet<TypeId>,
51}
52
53impl ContainerIds {
54 fn insert(&mut self, ty: TypeId) -> ContainerValueId {
55 if let Some(idx) = self.ids.get_index_of(&ty) {
56 ContainerValueId::from_usize(idx)
57 } else {
58 let idx = self.ids.len();
59 self.ids.insert(ty);
60 ContainerValueId::from_usize(idx)
61 }
62 }
63
64 fn get(&self, ty: &TypeId) -> Option<ContainerValueId> {
65 self.ids.get_index_of(ty).map(ContainerValueId::from_usize)
66 }
67}
68
69#[derive(Clone, Default)]
70pub struct ContainerValues {
71 subset_tracker: SubsetTracker,
72 container_ids: ContainerIds,
73 data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
74}
75
76#[derive(Clone, Default)]
91pub struct ContainerRebuildSummary {
92 changed: bool,
93 dirty_ids: IndexSet<Value>,
96}
97
98impl ContainerRebuildSummary {
99 pub fn changed(&self) -> bool {
101 self.changed
102 }
103
104 pub fn dirty_ids(&self) -> &IndexSet<Value> {
106 &self.dirty_ids
107 }
108
109 fn note_change(&mut self) {
110 self.changed = true;
111 }
112
113 fn note_dirty_id(&mut self, value: Value) {
114 self.changed = true;
115 self.dirty_ids.insert(value);
116 }
117
118 fn extend(&mut self, other: Self) {
119 self.changed |= other.changed;
120 self.dirty_ids.extend(other.dirty_ids);
121 }
122}
123
124impl ContainerValues {
125 pub fn new() -> Self {
126 Default::default()
127 }
128
129 fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
130 let id = self.container_ids.get(&TypeId::of::<C>())?;
131 let res = self.data.get(id)?.as_any();
132 Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
133 }
134
135 pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
137 let Some(env) = self.get::<C>() else {
138 return;
139 };
140 for ent in env.to_id.iter() {
141 f(ent.key(), *ent.value());
142 }
143 }
144
145 pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
151 self.get::<C>()?.get_container(val)
152 }
153
154 pub fn register_val<C: ContainerValue>(
155 &self,
156 container: C,
157 exec_state: &mut ExecutionState,
158 ) -> Value {
159 let env = self
160 .get::<C>()
161 .expect("must register container type before registering a value");
162 env.get_or_insert(&container, exec_state)
163 }
164
165 pub fn rebuild_all(
167 &mut self,
168 table_id: TableId,
169 table: &WrappedTable,
170 exec_state: &mut ExecutionState,
171 ) -> ContainerRebuildSummary {
172 let Some(rebuilder) = table.rebuilder(&[]) else {
173 return Default::default();
174 };
175 let to_scan = rebuilder.hint_col().map(|_| {
176 self.subset_tracker.recent_updates(table_id, table)
178 });
179 let mut summary = if parallelize_inter_container_op(self.data.next_id().index()) {
180 self.data
181 .iter_mut()
182 .zip(std::iter::repeat_with(|| exec_state.clone()))
183 .par_bridge()
184 .map(|((_, env), mut exec_state)| {
185 env.apply_rebuild(
186 table,
187 &*rebuilder,
188 to_scan.as_ref().map(|x| x.as_ref()),
189 &mut exec_state,
190 )
191 })
192 .reduce(ContainerRebuildSummary::default, |mut acc, summary| {
193 acc.extend(summary);
194 acc
195 })
196 } else {
197 let mut summary = ContainerRebuildSummary::default();
198 for (_, env) in self.data.iter_mut() {
199 summary.extend(env.apply_rebuild(
200 table,
201 &*rebuilder,
202 to_scan.as_ref().map(|x| x.as_ref()),
203 exec_state,
204 ));
205 }
206 summary
207 };
208 self.expand_dirty_id_closure(&mut summary);
209 summary
210 }
211
212 fn expand_dirty_id_closure(&self, summary: &mut ContainerRebuildSummary) {
224 let mut frontier = summary.dirty_ids.clone();
225 let mut seen = frontier.iter().copied().collect::<IndexSet<_>>();
226
227 while !frontier.is_empty() {
228 let mut next = IndexSet::default();
229 for (_, env) in self.data.iter() {
230 env.extend_containers_containing(&frontier, &mut next);
231 }
232 frontier.clear();
233 for value in next {
234 if seen.insert(value) {
235 summary.note_dirty_id(value);
236 frontier.insert(value);
237 }
238 }
239 }
240 }
241
242 pub fn register_type<C: ContainerValue>(
247 &mut self,
248 id_counter: CounterId,
249 merge_fn: impl MergeFn + 'static,
250 ) -> ContainerValueId {
251 let id = self.container_ids.insert(TypeId::of::<C>());
252 self.data.get_or_insert(id, || {
253 Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
254 });
255 id
256 }
257}
258
259pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
265 fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;
270
271 fn iter(&self) -> impl Iterator<Item = Value> + '_;
278}
279
280pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
281 fn as_any(&self) -> &dyn Any;
282 fn apply_rebuild(
283 &mut self,
284 table: &WrappedTable,
285 rebuilder: &dyn Rebuilder,
286 subset: Option<SubsetRef>,
287 exec_state: &mut ExecutionState,
288 ) -> ContainerRebuildSummary;
289 fn extend_containers_containing(&self, values: &IndexSet<Value>, out: &mut IndexSet<Value>);
295}
296
297dyn_clone::clone_trait_object!(DynamicContainerEnv);
299
300fn hash_container(container: &impl ContainerValue) -> u64 {
301 let mut hasher = FxHasher::default();
302 container.hash(&mut hasher);
303 hasher.finish()
304}
305
306#[derive(Clone)]
307struct ContainerEnv<C: Eq + Hash> {
308 merge_fn: Box<dyn MergeFn>,
309 counter: CounterId,
310 to_id: DashMap<C, Value>,
311 to_container: DashMap<Value, (usize , usize )>,
312 val_index: DashMap<Value, IndexSet<Value>>,
314}
315
316impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
317 fn as_any(&self) -> &dyn Any {
318 self
319 }
320
321 fn apply_rebuild(
322 &mut self,
323 table: &WrappedTable,
324 rebuilder: &dyn Rebuilder,
325 subset: Option<SubsetRef>,
326 exec_state: &mut ExecutionState,
327 ) -> ContainerRebuildSummary {
328 if let Some(subset) = subset
329 && incremental_rebuild(
330 subset.size(),
331 self.to_id.len(),
332 parallelize_intra_container_op(self.to_id.len()),
333 )
334 {
335 return self.apply_rebuild_incremental(
336 table,
337 rebuilder,
338 exec_state,
339 subset,
340 rebuilder.hint_col().unwrap(),
341 );
342 }
343 self.apply_rebuild_nonincremental(rebuilder, exec_state)
344 }
345
346 fn extend_containers_containing(&self, values: &IndexSet<Value>, out: &mut IndexSet<Value>) {
347 for value in values {
348 if let Some(containers) = self.val_index.get(value) {
349 out.extend(containers.iter().copied());
350 }
351 }
352 }
353}
354
355impl<C: ContainerValue> ContainerEnv<C> {
356 pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
357 Self {
358 merge_fn,
359 counter,
360 to_id: DashMap::default(),
361 to_container: DashMap::default(),
362 val_index: DashMap::default(),
363 }
364 }
365
366 fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
367 if let Some(value) = self.to_id.get(container) {
368 return *value;
369 }
370
371 let value = Value::from_usize(exec_state.inc_counter(self.counter));
376 let target_map = self.to_id.determine_map(container);
377 debug_assert_eq!(
381 target_map,
382 self.to_container
383 .determine_shard(hash_container(container) as usize)
384 );
385 self.to_container
386 .insert(value, (hash_container(container) as usize, target_map));
387
388 match self.to_id.entry(container.clone()) {
391 dashmap::Entry::Vacant(vac) => {
392 vac.insert(value);
394 for val in container.iter() {
395 self.val_index.entry(val).or_default().insert(value);
396 }
397 value
398 }
399 dashmap::Entry::Occupied(occ) => {
400 let res = *occ.get();
404 std::mem::drop(occ); self.to_container.remove(&value);
406 res
407 }
408 }
409 }
410
411 fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) -> Value {
412 let hc = hash_container(&container);
413 let target_map = self.to_id.determine_map(&container);
414 match self.to_id.entry(container) {
415 dashmap::Entry::Occupied(mut occ) => {
416 let result = (self.merge_fn)(exec_state, *occ.get(), value);
417 let old_val = *occ.get();
418 if result != old_val {
419 self.to_container.remove(&old_val);
420 self.to_container.insert(result, (hc as usize, target_map));
421 *occ.get_mut() = result;
422 for val in occ.key().iter() {
423 let mut index = self.val_index.entry(val).or_default();
424 index.swap_remove(&old_val);
425 index.insert(result);
426 }
427 }
428 result
429 }
430 dashmap::Entry::Vacant(vacant_entry) => {
431 self.to_container.insert(value, (hc as usize, target_map));
432 for val in vacant_entry.key().iter() {
433 self.val_index.entry(val).or_default().insert(value);
434 }
435 vacant_entry.insert(value);
436 value
437 }
438 }
439 }
440
441 fn reinsert_incremental(
442 &self,
443 container: C,
444 old_id: Value,
445 rebuilt_id: Value,
446 container_changed: bool,
447 exec_state: &mut ExecutionState,
448 summary: &mut ContainerRebuildSummary,
449 ) {
450 if container_changed || rebuilt_id != old_id {
451 summary.note_change();
452 }
453 if rebuilt_id != old_id {
454 self.to_container.remove(&old_id);
457 }
458 let actual = self.insert_owned(container, rebuilt_id, exec_state);
459 if container_changed && rebuilt_id == old_id && actual == old_id {
460 summary.note_dirty_id(old_id);
461 }
462 }
463
464 fn apply_rebuild_incremental(
465 &mut self,
466 table: &WrappedTable,
467 rebuilder: &dyn Rebuilder,
468 exec_state: &mut ExecutionState,
469 to_scan: SubsetRef,
470 search_col: ColumnId,
471 ) -> ContainerRebuildSummary {
472 let mut summary = ContainerRebuildSummary::default();
479 let mut buf = TaggedRowBuffer::new(1);
480 table.scan_project(
481 to_scan,
482 &[search_col],
483 Offset::new(0),
484 usize::MAX,
485 &[],
486 &mut buf,
487 );
488 let mut to_rebuild = IndexSet::<Value>::default();
490 for (_, row) in buf.iter() {
491 to_rebuild.insert(row[0]);
492 let Some(ids) = self.val_index.get(&row[0]) else {
493 continue;
494 };
495 to_rebuild.extend(&*ids);
496 }
497 for id in to_rebuild {
498 let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
499 continue;
500 };
501 let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
502 let Some((mut container, _)) =
503 shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
504 else {
505 continue;
506 };
507 let rebuilt_id = rebuilder.rebuild_val(id);
508 let container_changed = container.rebuild_contents(rebuilder);
509 self.reinsert_incremental(
510 container,
511 id,
512 rebuilt_id,
513 container_changed,
514 exec_state,
515 &mut summary,
516 );
517 }
518 summary
519 }
520
521 fn apply_rebuild_nonincremental(
522 &mut self,
523 rebuilder: &dyn Rebuilder,
524 exec_state: &mut ExecutionState,
525 ) -> ContainerRebuildSummary {
526 if parallelize_inter_container_op(self.to_id.len()) {
527 return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
528 }
529 let mut summary = ContainerRebuildSummary::default();
530 let mut to_reinsert = Vec::new();
531 let shards = self.to_id.shards_mut();
532 for shard in shards.iter_mut() {
533 let shard = shard.get_mut();
534 for bucket in unsafe { shard.iter() } {
536 let (container, val) = unsafe { bucket.as_mut() };
538 let old_val = *val.get();
539 let new_val = rebuilder.rebuild_val(old_val);
540 let container_changed = container.rebuild_contents(rebuilder);
541 if !container_changed && new_val == old_val {
542 continue;
544 }
545 summary.note_change();
546 if container_changed {
547 let ((container, _), _) = unsafe { shard.remove(bucket) };
551 self.to_container.remove(&old_val);
552 to_reinsert.push((container, new_val, new_val == old_val));
553 } else {
554 *val.get_mut() = new_val;
556 let prev = self.to_container.remove(&old_val).unwrap().1;
557 self.to_container.insert(new_val, prev);
558 }
559 }
560 }
561 for (container, val, stable_id) in to_reinsert {
562 let actual = self.insert_owned(container, val, exec_state);
563 if stable_id && actual == val {
567 summary.note_dirty_id(val);
568 }
569 }
570 summary
571 }
572
573 fn apply_rebuild_nonincremental_parallel(
574 &mut self,
575 rebuilder: &dyn Rebuilder,
576 exec_state: &mut ExecutionState,
577 ) -> ContainerRebuildSummary {
578 let mut to_reinsert =
583 IdVec::<usize , SegQueue<(C, Value, bool)>>::default();
584 to_reinsert.resize_with(self.to_id.shards().len(), Default::default);
585
586 let shards = self.to_id.shards_mut();
587 let changed = shards
588 .par_iter_mut()
589 .map(|shard| {
590 let mut changed = false;
591 let shard = shard.get_mut();
592 for bucket in unsafe { shard.iter() } {
594 let (container, val) = unsafe { bucket.as_mut() };
596 let old_val = *val.get();
597 let new_val = rebuilder.rebuild_val(old_val);
598 let container_changed = container.rebuild_contents(rebuilder);
599 if !container_changed && new_val == old_val {
600 continue;
602 }
603 changed = true;
604 if container_changed {
605 let ((container, _), _) = unsafe { shard.remove(bucket) };
609 self.to_container.remove(&old_val);
610 let shard = self
615 .to_container
616 .determine_shard(hash_container(&container) as usize);
617 to_reinsert[shard].push((container, new_val, new_val == old_val));
618 } else {
619 *val.get_mut() = new_val;
621 let prev = self.to_container.remove(&old_val).unwrap().1;
622 self.to_container.insert(new_val, prev);
623 }
624 }
625 changed
626 })
627 .max()
628 .unwrap_or(false);
629
630 let dirty_ids = SegQueue::new();
631 shards
632 .iter_mut()
633 .enumerate()
634 .map(|(i, shard)| (i, shard, exec_state.clone()))
635 .par_bridge()
636 .for_each(|(shard_id, shard, mut exec_state)| {
637 let shard = shard.get_mut();
643 let queue = &to_reinsert[shard_id];
644 while let Some((container, val, stable_id)) = queue.pop() {
645 let hc = hash_container(&container);
646 let target_map = self.to_container.determine_shard(hc as usize);
647 match shard.find_or_find_insert_slot(
648 hc,
649 |(c, _)| c == &container,
650 |(c, _)| hash_container(c),
651 ) {
652 Ok(bucket) => {
653 let (container, val_slot) = unsafe { bucket.as_mut() };
656 let old_val = *val_slot.get();
657 let result = (self.merge_fn)(&mut exec_state, old_val, val);
658 if result != old_val {
659 self.to_container.remove(&old_val);
660 self.to_container.insert(result, (hc as usize, target_map));
661 *val_slot.get_mut() = result;
662 for val in container.iter() {
663 let mut index = self.val_index.entry(val).or_default();
664 index.swap_remove(&old_val);
665 index.insert(result);
666 }
667 }
668 if stable_id && result == val {
671 dirty_ids.push(val);
672 }
673 }
674 Err(slot) => {
675 self.to_container.insert(val, (hc as usize, target_map));
676 for v in container.iter() {
677 self.val_index.entry(v).or_default().insert(val);
678 }
679 unsafe {
682 shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
683 }
684 if stable_id {
685 dirty_ids.push(val);
686 }
687 }
688 }
689 }
690 });
691 let mut summary = ContainerRebuildSummary::default();
692 if changed {
693 summary.note_change();
694 }
695 while let Some(value) = dirty_ids.pop() {
696 summary.note_dirty_id(value);
697 }
698 summary
699 }
700
701 fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
702 let (hc, target_map) = *self.to_container.get(&value)?;
703 let shard = &self.to_id.shards()[target_map];
704 let read_guard = shard.read();
705 let val_ptr: *const (C, _) = shard
706 .read()
707 .find(hc as u64, |(_, v)| *v.get() == value)?
708 .as_ptr();
709 struct ValueDeref<'a, T, Guard> {
710 _guard: Guard,
711 data: &'a T,
712 }
713
714 impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
715 type Target = T;
716
717 fn deref(&self) -> &T {
718 self.data
719 }
720 }
721
722 Some(ValueDeref {
723 _guard: read_guard,
724 data: unsafe {
726 let unwrapped: &(C, _) = &*val_ptr;
727 &unwrapped.0
728 },
729 })
730 }
731}
732
/// Decides whether an incremental rebuild is worthwhile.
///
/// Returns `true` when the table holds more than 1000 entries and
/// `uf_size` is at most `table_size / 512` (when `parallel` is set) or
/// `table_size / 8` (otherwise).
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    const MIN_TABLE_SIZE: usize = 1000;
    const PARALLEL_RATIO: usize = 512;
    const SERIAL_RATIO: usize = 8;

    let ratio = if parallel {
        PARALLEL_RATIO
    } else {
        SERIAL_RATIO
    };
    // `uf_size * ratio <= table_size` is equivalent to
    // `uf_size <= table_size / ratio` for integers, and the division form
    // cannot overflow.
    table_size > MIN_TABLE_SIZE && uf_size <= table_size / ratio
}