1use std::{
12 any::{Any, TypeId},
13 hash::{Hash, Hasher},
14 ops::Deref,
15};
16
17use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
18use crossbeam_queue::SegQueue;
19use dashmap::SharedValue;
20use rayon::{
21 iter::{ParallelBridge, ParallelIterator},
22 prelude::*,
23};
24use rustc_hash::FxHasher;
25
26use crate::{
27 ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
28 WrappedTable,
29 common::{DashMap, IndexSet, SubsetTracker},
30 parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
31 table_spec::Rebuilder,
32};
33
34#[cfg(test)]
35mod tests;
36
// Declares `ContainerValueId`. In this file it identifies a registered container *type* (not an
// individual container): `ContainerIds` hands these out and `ContainerValues::data` is keyed by
// them. NOTE(review): presumably backed by a `u32`, per the macro arguments — confirm against
// `define_id!`.
define_id!(pub ContainerValueId, u32, "an identifier for containers");
38
39pub trait MergeFn:
40 Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
41{
42}
43impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}
44
45dyn_clone::clone_trait_object!(MergeFn);
47
48#[derive(Clone, Default)]
49struct ContainerIds {
50 ids: IndexSet<TypeId>,
51}
52
53impl ContainerIds {
54 fn insert(&mut self, ty: TypeId) -> ContainerValueId {
55 if let Some(idx) = self.ids.get_index_of(&ty) {
56 ContainerValueId::from_usize(idx)
57 } else {
58 let idx = self.ids.len();
59 self.ids.insert(ty);
60 ContainerValueId::from_usize(idx)
61 }
62 }
63
64 fn get(&self, ty: &TypeId) -> Option<ContainerValueId> {
65 self.ids.get_index_of(ty).map(ContainerValueId::from_usize)
66 }
67}
68
/// A registry of container environments, one per registered container type.
///
/// Container types are registered with `register_type`; individual containers are then mapped to
/// [`Value`] ids with `register_val` and looked up again with `get_val`.
#[derive(Clone, Default)]
pub struct ContainerValues {
    /// Consulted by `rebuild_all` (via `recent_updates`) to obtain the subset of the table to scan
    /// when an incremental rebuild may be possible.
    subset_tracker: SubsetTracker,
    /// Maps each registered container type to its [`ContainerValueId`].
    container_ids: ContainerIds,
    /// The type-erased environment for each registered container type, keyed by the id assigned in
    /// `container_ids`.
    data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
}
75
76#[derive(Clone, Default)]
89pub struct ContainerRebuildSummary {
90 changed: bool,
91 dirty_ids: IndexSet<Value>,
94}
95
96impl ContainerRebuildSummary {
97 pub fn changed(&self) -> bool {
99 self.changed
100 }
101
102 pub fn dirty_ids(&self) -> &IndexSet<Value> {
104 &self.dirty_ids
105 }
106
107 fn note_change(&mut self) {
108 self.changed = true;
109 }
110
111 fn note_dirty_id(&mut self, value: Value) {
112 self.changed = true;
113 self.dirty_ids.insert(value);
114 }
115
116 fn extend(&mut self, other: Self) {
117 self.changed |= other.changed;
118 self.dirty_ids.extend(other.dirty_ids);
119 }
120}
121
122impl ContainerValues {
123 pub fn new() -> Self {
124 Default::default()
125 }
126
127 fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
128 let id = self.container_ids.get(&TypeId::of::<C>())?;
129 let res = self.data.get(id)?.as_any();
130 Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
131 }
132
133 pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
135 let Some(env) = self.get::<C>() else {
136 return;
137 };
138 for ent in env.to_id.iter() {
139 f(ent.key(), *ent.value());
140 }
141 }
142
143 pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
149 self.get::<C>()?.get_container(val)
150 }
151
152 pub fn register_val<C: ContainerValue>(
153 &self,
154 container: C,
155 exec_state: &mut ExecutionState,
156 ) -> Value {
157 let env = self
158 .get::<C>()
159 .expect("must register container type before registering a value");
160 env.get_or_insert(&container, exec_state)
161 }
162
163 pub fn rebuild_all(
165 &mut self,
166 table_id: TableId,
167 table: &WrappedTable,
168 exec_state: &mut ExecutionState,
169 ) -> ContainerRebuildSummary {
170 let Some(rebuilder) = table.rebuilder(&[]) else {
171 return Default::default();
172 };
173 let to_scan = rebuilder.hint_col().map(|_| {
174 self.subset_tracker.recent_updates(table_id, table)
176 });
177 if parallelize_inter_container_op(self.data.next_id().index()) {
178 self.data
179 .iter_mut()
180 .zip(std::iter::repeat_with(|| exec_state.clone()))
181 .par_bridge()
182 .map(|((_, env), mut exec_state)| {
183 env.apply_rebuild(
184 table,
185 &*rebuilder,
186 to_scan.as_ref().map(|x| x.as_ref()),
187 &mut exec_state,
188 )
189 })
190 .reduce(ContainerRebuildSummary::default, |mut acc, summary| {
191 acc.extend(summary);
192 acc
193 })
194 } else {
195 let mut summary = ContainerRebuildSummary::default();
196 for (_, env) in self.data.iter_mut() {
197 summary.extend(env.apply_rebuild(
198 table,
199 &*rebuilder,
200 to_scan.as_ref().map(|x| x.as_ref()),
201 exec_state,
202 ));
203 }
204 summary
205 }
206 }
207
208 pub fn register_type<C: ContainerValue>(
213 &mut self,
214 id_counter: CounterId,
215 merge_fn: impl MergeFn + 'static,
216 ) -> ContainerValueId {
217 let id = self.container_ids.insert(TypeId::of::<C>());
218 self.data.get_or_insert(id, || {
219 Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
220 });
221 id
222 }
223}
224
/// A value that can be stored in a container environment and mapped to a [`Value`] id.
///
/// Containers are used as hash-map keys (hence `Hash + Eq`), are cloned on insertion, and are
/// accessed from multiple threads during parallel rebuilds.
pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
    /// Rewrites the values stored in this container according to `rebuilder`, in place, and
    /// returns whether the container changed.
    ///
    /// The rebuild code in this file leaves a container in its current hash-table slot when this
    /// returns `false`, so an implementation must only return `false` if the container still
    /// hashes and compares the same as before the call.
    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;

    /// Iterates over the values mentioned by this container.
    ///
    /// These values populate the index that incremental rebuilds use to find the containers
    /// affected by a changed value; a container is only found that way through the values yielded
    /// here.
    fn iter(&self) -> impl Iterator<Item = Value> + '_;
}
245
/// The type-erased interface to a per-container-type environment.
///
/// This lets environments for different container types be stored side by side as boxed trait
/// objects; in this file it is implemented by `ContainerEnv<C>`.
pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
    /// Exposes the environment as [`Any`] so that callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Rebuilds every container in this environment according to `rebuilder`.
    ///
    /// `subset`, when present, is the part of `table` that an incremental rebuild would scan;
    /// passing `None` requests a full rebuild. `exec_state` is handed to the merge function when
    /// two containers collide.
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> ContainerRebuildSummary;
}

// Provides `Clone` for `Box<dyn DynamicContainerEnv>`, which the derived `Clone` on
// `ContainerValues` relies on.
dyn_clone::clone_trait_object!(DynamicContainerEnv);
259
260fn hash_container(container: &impl ContainerValue) -> u64 {
261 let mut hasher = FxHasher::default();
262 container.hash(&mut hasher);
263 hasher.finish()
264}
265
/// The state kept for a single container type `C`: a bidirectional mapping between containers
/// and their ids, plus an index from values to the containers that mention them.
#[derive(Clone)]
struct ContainerEnv<C: Eq + Hash> {
    /// Chooses the surviving id when a container being inserted is already mapped to an id.
    merge_fn: Box<dyn MergeFn>,
    /// The counter that fresh container ids are drawn from.
    counter: CounterId,
    /// Maps each container to its id.
    to_id: DashMap<C, Value>,
    /// Maps each id back to where its container lives in `to_id`: the container's hash code and
    /// the index of the `to_id` shard that holds it.
    to_container: DashMap<Value, (usize /* hash code */, usize /* shard index */)>,
    /// Maps a value to the ids of containers that mention it (as reported by
    /// `ContainerValue::iter`). Incremental rebuilds skip ids found here that no longer resolve in
    /// `to_container`.
    val_index: DashMap<Value, IndexSet<Value>>,
}
275
276impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
277 fn as_any(&self) -> &dyn Any {
278 self
279 }
280
281 fn apply_rebuild(
282 &mut self,
283 table: &WrappedTable,
284 rebuilder: &dyn Rebuilder,
285 subset: Option<SubsetRef>,
286 exec_state: &mut ExecutionState,
287 ) -> ContainerRebuildSummary {
288 if let Some(subset) = subset
289 && incremental_rebuild(
290 subset.size(),
291 self.to_id.len(),
292 parallelize_intra_container_op(self.to_id.len()),
293 )
294 {
295 return self.apply_rebuild_incremental(
296 table,
297 rebuilder,
298 exec_state,
299 subset,
300 rebuilder.hint_col().unwrap(),
301 );
302 }
303 self.apply_rebuild_nonincremental(rebuilder, exec_state)
304 }
305}
306
307impl<C: ContainerValue> ContainerEnv<C> {
308 pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
309 Self {
310 merge_fn,
311 counter,
312 to_id: DashMap::default(),
313 to_container: DashMap::default(),
314 val_index: DashMap::default(),
315 }
316 }
317
318 fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
319 if let Some(value) = self.to_id.get(container) {
320 return *value;
321 }
322
323 let value = Value::from_usize(exec_state.inc_counter(self.counter));
328 let target_map = self.to_id.determine_map(container);
329 debug_assert_eq!(
333 target_map,
334 self.to_container
335 .determine_shard(hash_container(container) as usize)
336 );
337 self.to_container
338 .insert(value, (hash_container(container) as usize, target_map));
339
340 match self.to_id.entry(container.clone()) {
343 dashmap::Entry::Vacant(vac) => {
344 vac.insert(value);
346 for val in container.iter() {
347 self.val_index.entry(val).or_default().insert(value);
348 }
349 value
350 }
351 dashmap::Entry::Occupied(occ) => {
352 let res = *occ.get();
356 std::mem::drop(occ); self.to_container.remove(&value);
358 res
359 }
360 }
361 }
362
363 fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) -> Value {
364 let hc = hash_container(&container);
365 let target_map = self.to_id.determine_map(&container);
366 match self.to_id.entry(container) {
367 dashmap::Entry::Occupied(mut occ) => {
368 let result = (self.merge_fn)(exec_state, *occ.get(), value);
369 let old_val = *occ.get();
370 if result != old_val {
371 self.to_container.remove(&old_val);
372 self.to_container.insert(result, (hc as usize, target_map));
373 *occ.get_mut() = result;
374 for val in occ.key().iter() {
375 let mut index = self.val_index.entry(val).or_default();
376 index.swap_remove(&old_val);
377 index.insert(result);
378 }
379 }
380 result
381 }
382 dashmap::Entry::Vacant(vacant_entry) => {
383 self.to_container.insert(value, (hc as usize, target_map));
384 for val in vacant_entry.key().iter() {
385 self.val_index.entry(val).or_default().insert(value);
386 }
387 vacant_entry.insert(value);
388 value
389 }
390 }
391 }
392
393 fn reinsert_incremental(
394 &self,
395 container: C,
396 old_id: Value,
397 rebuilt_id: Value,
398 container_changed: bool,
399 exec_state: &mut ExecutionState,
400 summary: &mut ContainerRebuildSummary,
401 ) {
402 if container_changed || rebuilt_id != old_id {
403 summary.note_change();
404 }
405 if rebuilt_id != old_id {
406 self.to_container.remove(&old_id);
409 }
410 let actual = self.insert_owned(container, rebuilt_id, exec_state);
411 if container_changed && rebuilt_id == old_id && actual == old_id {
412 summary.note_dirty_id(old_id);
413 }
414 }
415
416 fn apply_rebuild_incremental(
417 &mut self,
418 table: &WrappedTable,
419 rebuilder: &dyn Rebuilder,
420 exec_state: &mut ExecutionState,
421 to_scan: SubsetRef,
422 search_col: ColumnId,
423 ) -> ContainerRebuildSummary {
424 let mut summary = ContainerRebuildSummary::default();
431 let mut buf = TaggedRowBuffer::new(1);
432 table.scan_project(
433 to_scan,
434 &[search_col],
435 Offset::new(0),
436 usize::MAX,
437 &[],
438 &mut buf,
439 );
440 let mut to_rebuild = IndexSet::<Value>::default();
442 for (_, row) in buf.iter() {
443 to_rebuild.insert(row[0]);
444 let Some(ids) = self.val_index.get(&row[0]) else {
445 continue;
446 };
447 to_rebuild.extend(&*ids);
448 }
449 for id in to_rebuild {
450 let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
451 continue;
452 };
453 let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
454 let Some((mut container, _)) =
455 shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
456 else {
457 continue;
458 };
459 let rebuilt_id = rebuilder.rebuild_val(id);
460 let container_changed = container.rebuild_contents(rebuilder);
461 self.reinsert_incremental(
462 container,
463 id,
464 rebuilt_id,
465 container_changed,
466 exec_state,
467 &mut summary,
468 );
469 }
470 summary
471 }
472
473 fn apply_rebuild_nonincremental(
474 &mut self,
475 rebuilder: &dyn Rebuilder,
476 exec_state: &mut ExecutionState,
477 ) -> ContainerRebuildSummary {
478 if parallelize_inter_container_op(self.to_id.len()) {
479 return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
480 }
481 let mut summary = ContainerRebuildSummary::default();
482 let mut to_reinsert = Vec::new();
483 let shards = self.to_id.shards_mut();
484 for shard in shards.iter_mut() {
485 let shard = shard.get_mut();
486 for bucket in unsafe { shard.iter() } {
488 let (container, val) = unsafe { bucket.as_mut() };
490 let old_val = *val.get();
491 let new_val = rebuilder.rebuild_val(old_val);
492 let container_changed = container.rebuild_contents(rebuilder);
493 if !container_changed && new_val == old_val {
494 continue;
496 }
497 summary.note_change();
498 if container_changed {
499 let ((container, _), _) = unsafe { shard.remove(bucket) };
503 self.to_container.remove(&old_val);
504 to_reinsert.push((container, new_val, new_val == old_val));
505 } else {
506 *val.get_mut() = new_val;
508 let prev = self.to_container.remove(&old_val).unwrap().1;
509 self.to_container.insert(new_val, prev);
510 }
511 }
512 }
513 for (container, val, stable_id) in to_reinsert {
514 let actual = self.insert_owned(container, val, exec_state);
515 if stable_id && actual == val {
519 summary.note_dirty_id(val);
520 }
521 }
522 summary
523 }
524
    /// Parallel counterpart of `apply_rebuild_nonincremental`.
    ///
    /// The work is split into two phases:
    ///
    /// 1. All shards of `to_id` are rebuilt in parallel. Entries whose id alone changed are
    ///    updated in place; containers whose contents changed are removed and pushed onto a
    ///    queue for the shard they now belong to.
    /// 2. Each shard drains its own queue (again in parallel, one task per shard) and inserts the
    ///    queued containers directly into its raw table, merging ids on collision in the same way
    ///    `insert_owned` does.
    ///
    /// Each phase-2 task works on its own clone of `exec_state`.
    fn apply_rebuild_nonincremental_parallel(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> ContainerRebuildSummary {
        // One queue per destination shard, holding (container, rebuilt id, id unchanged?).
        let mut to_reinsert =
            IdVec::<usize /* shard */, SegQueue<(C, Value, bool)>>::default();
        to_reinsert.resize_with(self.to_id.shards().len(), Default::default);

        let shards = self.to_id.shards_mut();
        // Phase 1: rebuild every shard; `changed` ends up true if any shard changed anything.
        let changed = shards
            .par_iter_mut()
            .map(|shard| {
                let mut changed = false;
                let shard = shard.get_mut();
                // SAFETY (as far as this file shows): `shard` is borrowed exclusively by this
                // task and is not dropped while the iterator is alive. NOTE(review): erasing the
                // just-yielded bucket below relies on the raw table iterator permitting that —
                // confirm against hashbrown.
                for bucket in unsafe { shard.iter() } {
                    // SAFETY: `bucket` was just yielded and has not been removed yet.
                    let (container, val) = unsafe { bucket.as_mut() };
                    let old_val = *val.get();
                    let new_val = rebuilder.rebuild_val(old_val);
                    let container_changed = container.rebuild_contents(rebuilder);
                    if !container_changed && new_val == old_val {
                        continue;
                    }
                    changed = true;
                    if container_changed {
                        // SAFETY: `bucket` still refers to a live entry; the borrowed `container`
                        // and `val` are not used after this point.
                        let ((container, _), _) = unsafe { shard.remove(bucket) };
                        self.to_container.remove(&old_val);
                        // `to_id` is mutably borrowed here, so the destination shard is computed
                        // through `to_container`; `get_or_insert` debug-asserts that both maps
                        // agree on the shard for a given hash.
                        let shard = self
                            .to_container
                            .determine_shard(hash_container(&container) as usize);
                        to_reinsert[shard].push((container, new_val, new_val == old_val));
                    } else {
                        // Same container, new id: update in place and move the reverse mapping.
                        *val.get_mut() = new_val;
                        let prev = self.to_container.remove(&old_val).unwrap().1;
                        self.to_container.insert(new_val, prev);
                    }
                }
                changed
            })
            .max()
            .unwrap_or(false);

        // Ids whose container contents changed while the id itself survived.
        let dirty_ids = SegQueue::new();
        // Phase 2: every shard re-inserts the containers that were routed to it.
        shards
            .iter_mut()
            .enumerate()
            .map(|(i, shard)| (i, shard, exec_state.clone()))
            .par_bridge()
            .for_each(|(shard_id, shard, mut exec_state)| {
                let shard = shard.get_mut();
                let queue = &to_reinsert[shard_id];
                while let Some((container, val, stable_id)) = queue.pop() {
                    let hc = hash_container(&container);
                    let target_map = self.to_container.determine_shard(hc as usize);
                    match shard.find_or_find_insert_slot(
                        hc,
                        |(c, _)| c == &container,
                        |(c, _)| hash_container(c),
                    ) {
                        Ok(bucket) => {
                            // An equal container is already present: merge the two ids. Note that
                            // `container` is shadowed by the resident entry's key from here on.
                            // SAFETY: `bucket` was just returned by the lookup on this shard, to
                            // which this task has exclusive access.
                            let (container, val_slot) = unsafe { bucket.as_mut() };
                            let old_val = *val_slot.get();
                            let result = (self.merge_fn)(&mut exec_state, old_val, val);
                            if result != old_val {
                                self.to_container.remove(&old_val);
                                self.to_container.insert(result, (hc as usize, target_map));
                                *val_slot.get_mut() = result;
                                // Re-point the index from the old id to the merged one.
                                for val in container.iter() {
                                    let mut index = self.val_index.entry(val).or_default();
                                    index.swap_remove(&old_val);
                                    index.insert(result);
                                }
                            }
                            if stable_id && result == val {
                                dirty_ids.push(val);
                            }
                        }
                        Err(slot) => {
                            // No equal container yet: record it under `val`.
                            self.to_container.insert(val, (hc as usize, target_map));
                            for v in container.iter() {
                                self.val_index.entry(v).or_default().insert(val);
                            }
                            // SAFETY (as far as this file shows): `slot` was just returned by
                            // `find_or_find_insert_slot` for this hash, and the shard's table has
                            // not been modified since.
                            unsafe {
                                shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
                            }
                            if stable_id {
                                dirty_ids.push(val);
                            }
                        }
                    }
                }
            });
        let mut summary = ContainerRebuildSummary::default();
        if changed {
            summary.note_change();
        }
        while let Some(value) = dirty_ids.pop() {
            summary.note_dirty_id(value);
        }
        summary
    }
652
653 fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
654 let (hc, target_map) = *self.to_container.get(&value)?;
655 let shard = &self.to_id.shards()[target_map];
656 let read_guard = shard.read();
657 let val_ptr: *const (C, _) = shard
658 .read()
659 .find(hc as u64, |(_, v)| *v.get() == value)?
660 .as_ptr();
661 struct ValueDeref<'a, T, Guard> {
662 _guard: Guard,
663 data: &'a T,
664 }
665
666 impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
667 type Target = T;
668
669 fn deref(&self) -> &T {
670 self.data
671 }
672 }
673
674 Some(ValueDeref {
675 _guard: read_guard,
676 data: unsafe {
678 let unwrapped: &(C, _) = &*val_ptr;
679 &unwrapped.0
680 },
681 })
682 }
683}
684
/// Decides whether a rebuild should be incremental rather than a full pass.
///
/// `uf_size` is the amount of work the incremental path would scan (callers in this file pass the
/// size of the recently updated subset), `table_size` is the number of entries a full rebuild
/// would visit, and `parallel` says whether the full rebuild would run in parallel.
///
/// An incremental rebuild is chosen only when there are more than 1000 entries and the scan is at
/// most 1/8 of them (1/512 if the full rebuild would be parallel). NOTE(review): presumably the
/// stricter parallel ratio reflects a parallel full rebuild being comparatively cheaper — confirm.
///
/// A scaled `uf_size` that overflows `usize` is necessarily larger than `table_size`, so it
/// yields `false` instead of overflowing.
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    const MIN_TABLE_SIZE: usize = 1000;
    let factor: usize = if parallel { 512 } else { 8 };
    table_size > MIN_TABLE_SIZE
        && uf_size
            .checked_mul(factor)
            .is_some_and(|scaled| scaled <= table_size)
}