1use std::{
12 any::{Any, TypeId},
13 hash::{Hash, Hasher},
14 ops::Deref,
15};
16
17use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
18use crossbeam_queue::SegQueue;
19use dashmap::SharedValue;
20use rayon::{
21 iter::{ParallelBridge, ParallelIterator},
22 prelude::*,
23};
24use rustc_hash::FxHasher;
25
26use crate::{
27 ColumnId, CounterId, ExecutionState, Offset, SubsetRef, TableId, TaggedRowBuffer, Value,
28 WrappedTable,
29 common::{DashMap, IndexSet, SubsetTracker},
30 parallel_heuristics::{parallelize_inter_container_op, parallelize_intra_container_op},
31 table_spec::Rebuilder,
32};
33
#[cfg(test)]
mod tests;

define_id!(pub ContainerValueId, u32, "an identifier for containers");

/// A callback used to reconcile two `Value`s that turn out to denote the same
/// container (e.g. after a rebuild makes two containers equal).
///
/// Implementations must be cloneable and thread-safe; the blanket impl below
/// covers any suitable closure.
pub trait MergeFn:
    Fn(&mut ExecutionState, Value, Value) -> Value + dyn_clone::DynClone + Send + Sync
{
}
impl<T: Fn(&mut ExecutionState, Value, Value) -> Value + Clone + Send + Sync> MergeFn for T {}

// Allow `Box<dyn MergeFn>` to be cloned (used by `ContainerEnv`'s `Clone`).
dyn_clone::clone_trait_object!(MergeFn);
47
/// Assigns a dense `ContainerValueId` to each registered container `TypeId`.
#[derive(Clone, Default)]
struct ContainerIds {
    // Insertion order in the set determines the id: the index of a `TypeId`
    // in `ids` *is* its `ContainerValueId`.
    ids: IndexSet<TypeId>,
}
52
53impl ContainerIds {
54 fn insert(&mut self, ty: TypeId) -> ContainerValueId {
55 if let Some(idx) = self.ids.get_index_of(&ty) {
56 ContainerValueId::from_usize(idx)
57 } else {
58 let idx = self.ids.len();
59 self.ids.insert(ty);
60 ContainerValueId::from_usize(idx)
61 }
62 }
63
64 fn get(&self, ty: &TypeId) -> Option<ContainerValueId> {
65 self.ids.get_index_of(ty).map(ContainerValueId::from_usize)
66 }
67}
68
/// Top-level registry of container environments, one per registered container
/// type.
#[derive(Clone, Default)]
pub struct ContainerValues {
    // Tracks which table subsets have already been visited, so rebuilds can
    // scan only recent updates (see `rebuild_all`).
    subset_tracker: SubsetTracker,
    // Maps a container's `TypeId` to its dense `ContainerValueId`.
    container_ids: ContainerIds,
    // The type-erased per-container-type environments, indexed by id.
    data: DenseIdMap<ContainerValueId, Box<dyn DynamicContainerEnv + Send + Sync>>,
}
75
impl ContainerValues {
    /// Create an empty registry.
    pub fn new() -> Self {
        Default::default()
    }

    /// Downcast the environment registered for container type `C`, if any.
    fn get<C: ContainerValue>(&self) -> Option<&ContainerEnv<C>> {
        let id = self.container_ids.get(&TypeId::of::<C>())?;
        let res = self.data.get(id)?.as_any();
        // The entry stored under `C`'s id is always a `ContainerEnv<C>`
        // (see `register_type`), so this downcast cannot fail.
        Some(res.downcast_ref::<ContainerEnv<C>>().unwrap())
    }

    /// Invoke `f` on every `(container, id)` pair of type `C`. Does nothing
    /// if `C` was never registered.
    pub fn for_each<C: ContainerValue>(&self, mut f: impl FnMut(&C, Value)) {
        let Some(env) = self.get::<C>() else {
            return;
        };
        for ent in env.to_id.iter() {
            f(ent.key(), *ent.value());
        }
    }

    /// Fetch the container of type `C` registered under `val`, if any. The
    /// returned guard holds a shard read lock for its lifetime.
    pub fn get_val<C: ContainerValue>(&self, val: Value) -> Option<impl Deref<Target = C> + '_> {
        self.get::<C>()?.get_container(val)
    }

    /// Intern `container`, returning its id. Allocates a fresh id if the
    /// container has not been seen before.
    ///
    /// Panics if `C` was not previously registered via `register_type`.
    pub fn register_val<C: ContainerValue>(
        &self,
        container: C,
        exec_state: &mut ExecutionState,
    ) -> Value {
        let env = self
            .get::<C>()
            .expect("must register container type before registering a value");
        env.get_or_insert(&container, exec_state)
    }

    /// Rebuild every registered container environment against `table`,
    /// returning true if anything changed.
    pub fn rebuild_all(
        &mut self,
        table_id: TableId,
        table: &WrappedTable,
        exec_state: &mut ExecutionState,
    ) -> bool {
        let Some(rebuilder) = table.rebuilder(&[]) else {
            return false;
        };
        // Only compute the "recently updated" subset when the rebuilder
        // exposes a hint column; `apply_rebuild` relies on this pairing when
        // it later unwraps `hint_col()`.
        let to_scan = rebuilder.hint_col().map(|_| {
            self.subset_tracker.recent_updates(table_id, table)
        });
        if parallelize_inter_container_op(self.data.next_id().index()) {
            // Rebuild the environments in parallel; each rayon task gets its
            // own clone of `exec_state`.
            self.data
                .iter_mut()
                .zip(std::iter::repeat_with(|| exec_state.clone()))
                .par_bridge()
                .map(|((_, env), mut exec_state)| {
                    env.apply_rebuild(
                        table,
                        &*rebuilder,
                        to_scan.as_ref().map(|x| x.as_ref()),
                        &mut exec_state,
                    )
                })
                // `max()` of bools: true if any environment changed.
                .max()
                .unwrap_or(false)
        } else {
            let mut changed = false;
            for (_, env) in self.data.iter_mut() {
                changed |= env.apply_rebuild(
                    table,
                    &*rebuilder,
                    to_scan.as_ref().map(|x| x.as_ref()),
                    exec_state,
                );
            }
            changed
        }
    }

    /// Register container type `C`, creating its environment on first call.
    /// `id_counter` is the counter used to mint fresh ids; `merge_fn`
    /// reconciles colliding ids. Idempotent: re-registration keeps the
    /// existing environment (the new `merge_fn` is dropped in that case).
    pub fn register_type<C: ContainerValue>(
        &mut self,
        id_counter: CounterId,
        merge_fn: impl MergeFn + 'static,
    ) -> ContainerValueId {
        let id = self.container_ids.insert(TypeId::of::<C>());
        self.data.get_or_insert(id, || {
            Box::new(ContainerEnv::<C>::new(Box::new(merge_fn), id_counter))
        });
        id
    }
}
176
/// A value type that can be interned and tracked as a container of `Value`s.
pub trait ContainerValue: Hash + Eq + Clone + Send + Sync + 'static {
    /// Rewrite the `Value`s stored inside `self` using `rebuilder`, returning
    /// true if any of them changed.
    fn rebuild_contents(&mut self, rebuilder: &dyn Rebuilder) -> bool;

    /// Iterate over the `Value`s stored in this container.
    fn iter(&self) -> impl Iterator<Item = Value> + '_;
}
197
/// Type-erased interface to a `ContainerEnv<C>`, allowing `ContainerValues`
/// to store environments for heterogeneous container types side by side.
pub trait DynamicContainerEnv: Any + dyn_clone::DynClone + Send + Sync {
    /// Expose `self` as `&dyn Any` so callers can downcast back to the
    /// concrete `ContainerEnv<C>`.
    fn as_any(&self) -> &dyn Any;
    /// Rebuild all containers in this environment; `subset`, when provided,
    /// restricts the scan to recently-updated rows. Returns true on change.
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool;
}

// Allow `Box<dyn DynamicContainerEnv>` to be cloned.
dyn_clone::clone_trait_object!(DynamicContainerEnv);
211
212fn hash_container(container: &impl ContainerValue) -> u64 {
213 let mut hasher = FxHasher::default();
214 container.hash(&mut hasher);
215 hasher.finish()
216}
217
/// Interning state for a single container type `C`.
#[derive(Clone)]
struct ContainerEnv<C: Eq + Hash> {
    // Reconciles two ids that turn out to denote the same container.
    merge_fn: Box<dyn MergeFn>,
    // Counter used to mint fresh `Value` ids in `get_or_insert`.
    counter: CounterId,
    // container -> id.
    to_id: DashMap<C, Value>,
    // id -> (container hash, index of the `to_id` shard holding it).
    // Lets us find a container by id without rehashing it.
    to_container: DashMap<Value, (usize , usize )>,
    // inner value -> set of container ids mentioning it; drives the
    // incremental rebuild path.
    val_index: DashMap<Value, IndexSet<Value>>,
}
227
impl<C: ContainerValue> DynamicContainerEnv for ContainerEnv<C> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Rebuild this environment, choosing between the incremental path
    /// (rescan only recently-updated rows) and the full nonincremental sweep
    /// based on the relative sizes involved.
    fn apply_rebuild(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        subset: Option<SubsetRef>,
        exec_state: &mut ExecutionState,
    ) -> bool {
        if let Some(subset) = subset {
            if incremental_rebuild(
                subset.size(),
                self.to_id.len(),
                parallelize_intra_container_op(self.to_id.len()),
            ) {
                return self.apply_rebuild_incremental(
                    table,
                    rebuilder,
                    exec_state,
                    subset,
                    // `subset` is only Some when the caller observed
                    // `rebuilder.hint_col()` return Some (see
                    // `ContainerValues::rebuild_all`), so this unwrap holds.
                    rebuilder.hint_col().unwrap(),
                );
            }
        }
        self.apply_rebuild_nonincremental(rebuilder, exec_state)
    }
}
258
impl<C: ContainerValue> ContainerEnv<C> {
    /// Create an empty environment that mints ids from `counter` and merges
    /// colliding ids with `merge_fn`.
    pub fn new(merge_fn: Box<dyn MergeFn>, counter: CounterId) -> Self {
        Self {
            merge_fn,
            counter,
            to_id: DashMap::default(),
            to_container: DashMap::default(),
            val_index: DashMap::default(),
        }
    }

    /// Return the id for `container`, minting a fresh one if it is not yet
    /// interned. Safe to call concurrently; on a race, the loser's
    /// speculative id is rolled back and the winner's id is returned.
    fn get_or_insert(&self, container: &C, exec_state: &mut ExecutionState) -> Value {
        // Fast path: already interned.
        if let Some(value) = self.to_id.get(container) {
            return *value;
        }

        let value = Value::from_usize(exec_state.inc_counter(self.counter));
        let target_map = self.to_id.determine_map(container);
        // The reverse map's shard choice must agree with `to_id`'s, since the
        // rebuild paths use the stored index to address `to_id`'s shards.
        debug_assert_eq!(
            target_map,
            self.to_container
                .determine_shard(hash_container(container) as usize)
        );
        // Publish the reverse mapping *before* taking the `to_id` entry lock;
        // it is removed below if another thread won the insertion race.
        self.to_container
            .insert(value, (hash_container(container) as usize, target_map));

        match self.to_id.entry(container.clone()) {
            dashmap::Entry::Vacant(vac) => {
                vac.insert(value);
                // Index every inner value so incremental rebuilds can find
                // this container when one of its members changes.
                for val in container.iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                value
            }
            dashmap::Entry::Occupied(occ) => {
                // Lost the race: someone interned an equal container first.
                // Drop the entry lock before touching `to_container` to avoid
                // holding two shard locks at once, then roll back our
                // speculative reverse-map entry.
                let res = *occ.get();
                std::mem::drop(occ);
                self.to_container.remove(&value);
                res
            }
        }
    }

    /// Insert an owned `container` under `value`. If an equal container is
    /// already present, the two ids are merged via `merge_fn` and the reverse
    /// map and value index are patched to point at the merged id.
    fn insert_owned(&self, container: C, value: Value, exec_state: &mut ExecutionState) {
        let hc = hash_container(&container);
        let target_map = self.to_id.determine_map(&container);
        match self.to_id.entry(container) {
            dashmap::Entry::Occupied(mut occ) => {
                let result = (self.merge_fn)(exec_state, *occ.get(), value);
                let old_val = *occ.get();
                if result != old_val {
                    // The merge picked a different id: rewrite the reverse
                    // map and every index entry that mentioned the old id.
                    self.to_container.remove(&old_val);
                    self.to_container.insert(result, (hc as usize, target_map));
                    *occ.get_mut() = result;
                    for val in occ.key().iter() {
                        let mut index = self.val_index.entry(val).or_default();
                        index.swap_remove(&old_val);
                        index.insert(result);
                    }
                }
            }
            dashmap::Entry::Vacant(vacant_entry) => {
                self.to_container.insert(value, (hc as usize, target_map));
                for val in vacant_entry.key().iter() {
                    self.val_index.entry(val).or_default().insert(value);
                }
                vacant_entry.insert(value);
            }
        }
    }
    /// Incremental rebuild: scan only `to_scan` rows of `table`, collect the
    /// ids whose containers mention a changed value, and rebuild just those.
    fn apply_rebuild_incremental(
        &mut self,
        table: &WrappedTable,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
        to_scan: SubsetRef,
        search_col: ColumnId,
    ) -> bool {
        let mut changed = false;
        // Project just the hint column of the recently-updated rows.
        let mut buf = TaggedRowBuffer::new(1);
        table.scan_project(
            to_scan,
            &[search_col],
            Offset::new(0),
            usize::MAX,
            &[],
            &mut buf,
        );
        // Gather candidate ids: the scanned values themselves plus every
        // container id that mentions one of them.
        let mut to_rebuild = IndexSet::<Value>::default();
        for (_, row) in buf.iter() {
            to_rebuild.insert(row[0]);
            let Some(ids) = self.val_index.get(&row[0]) else {
                continue;
            };
            to_rebuild.extend(&*ids);
        }
        for id in to_rebuild {
            // Not every candidate value is actually a container id.
            let Some((hc, target_map)) = self.to_container.get(&id).map(|x| *x) else {
                continue;
            };
            // Remove the container from its `to_id` shard by stored hash and
            // id; we have `&mut self`, so `get_mut` on the shard is safe.
            let shard_mut = self.to_id.shards_mut()[target_map].get_mut();
            let Some((mut container, _)) =
                shard_mut.remove_entry(hc as u64, |(_, v)| *v.get() == id)
            else {
                continue;
            };
            changed |= container.rebuild_contents(rebuilder);
            // Re-insert; this merges with any equal container produced by an
            // earlier rebuild in this loop.
            self.insert_owned(container, id, exec_state);
        }
        changed
    }

    /// Full sweep: rebuild every container and remap every id, serially.
    fn apply_rebuild_nonincremental(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        if parallelize_inter_container_op(self.to_id.len()) {
            return self.apply_rebuild_nonincremental_parallel(rebuilder, exec_state);
        }
        let mut changed = false;
        // Containers whose contents changed must be rehashed, so they are
        // pulled out and re-inserted after the scan.
        let mut to_reinsert = Vec::new();
        let shards = self.to_id.shards_mut();
        for shard in shards.iter_mut() {
            let shard = shard.get_mut();
            // SAFETY: we hold `&mut` access to the shard for the whole
            // iteration, and buckets are only mutated/removed via the raw
            // APIs below.
            for bucket in unsafe { shard.iter() } {
                let (container, val) = unsafe { bucket.as_mut() };
                let old_val = *val.get();
                let new_val = rebuilder.rebuild_val(old_val);
                let container_changed = container.rebuild_contents(rebuilder);
                if !container_changed && new_val == old_val {
                    continue;
                }
                changed = true;
                if container_changed {
                    // Contents (and hence hash) changed: remove and defer
                    // re-insertion under the new hash.
                    let ((container, _), _) = unsafe { shard.remove(bucket) };
                    self.to_container.remove(&old_val);
                    to_reinsert.push((container, new_val));
                } else {
                    // Only the id changed; hash and shard stay valid, so
                    // update in place and move the reverse-map entry.
                    *val.get_mut() = new_val;
                    let prev = self.to_container.remove(&old_val).unwrap().1;
                    self.to_container.insert(new_val, prev);
                }
            }
        }
        for (container, val) in to_reinsert {
            self.insert_owned(container, val, exec_state);
        }
        changed
    }

    /// Full sweep, parallel variant: phase 1 rebuilds shards in parallel and
    /// queues rehashed containers per destination shard; phase 2 re-inserts
    /// each queue, again in parallel, with each task owning one shard.
    fn apply_rebuild_nonincremental_parallel(
        &mut self,
        rebuilder: &dyn Rebuilder,
        exec_state: &mut ExecutionState,
    ) -> bool {
        // One lock-free queue per destination shard so phase 2 can run with
        // one task per shard and no cross-task contention on `to_id`.
        let mut to_reinsert = IdVec::<usize, SegQueue<(C, Value)>>::default();
        to_reinsert.resize_with(self.to_id.shards().len(), Default::default);

        let shards = self.to_id.shards_mut();
        let changed = shards
            .par_iter_mut()
            .map(|shard| {
                let mut changed = false;
                let shard = shard.get_mut();
                // SAFETY: each rayon task has exclusive `&mut` access to its
                // shard; buckets are only touched via the raw APIs below.
                for bucket in unsafe { shard.iter() } {
                    let (container, val) = unsafe { bucket.as_mut() };
                    let old_val = *val.get();
                    let new_val = rebuilder.rebuild_val(old_val);
                    let container_changed = container.rebuild_contents(rebuilder);
                    if !container_changed && new_val == old_val {
                        continue;
                    }
                    changed = true;
                    if container_changed {
                        // Hash changed: queue for phase-2 insertion into the
                        // shard the *new* hash maps to. This relies on
                        // `to_container.determine_shard` agreeing with
                        // `to_id`'s shard choice (checked by the
                        // debug_assert in `get_or_insert`).
                        let ((container, _), _) = unsafe { shard.remove(bucket) };
                        self.to_container.remove(&old_val);
                        let shard = self
                            .to_container
                            .determine_shard(hash_container(&container) as usize);
                        to_reinsert[shard].push((container, new_val));
                    } else {
                        *val.get_mut() = new_val;
                        let prev = self.to_container.remove(&old_val).unwrap().1;
                        self.to_container.insert(new_val, prev);
                    }
                }
                changed
            })
            // `max()` of bools: true if any shard changed.
            .max()
            .unwrap_or(false);

        // Phase 2: each task drains the queue targeting its own shard; each
        // task gets its own `ExecutionState` clone for merge callbacks.
        shards
            .iter_mut()
            .enumerate()
            .map(|(i, shard)| (i, shard, exec_state.clone()))
            .par_bridge()
            .for_each(|(shard_id, shard, mut exec_state)| {
                let shard = shard.get_mut();
                let queue = &to_reinsert[shard_id];
                while let Some((container, val)) = queue.pop() {
                    let hc = hash_container(&container);
                    let target_map = self.to_container.determine_shard(hc as usize);
                    match shard.find_or_find_insert_slot(
                        hc,
                        |(c, _)| c == &container,
                        |(c, _)| hash_container(c),
                    ) {
                        Ok(bucket) => {
                            // An equal container already lives here: merge
                            // ids, mirroring `insert_owned`'s occupied arm.
                            // SAFETY: this task has exclusive access to the
                            // shard; `bucket` came from the lookup above.
                            let (container, val_slot) = unsafe { bucket.as_mut() };
                            let old_val = *val_slot.get();
                            let result = (self.merge_fn)(&mut exec_state, old_val, val);
                            if result != old_val {
                                self.to_container.remove(&old_val);
                                self.to_container.insert(result, (hc as usize, target_map));
                                *val_slot.get_mut() = result;
                                for val in container.iter() {
                                    let mut index = self.val_index.entry(val).or_default();
                                    index.swap_remove(&old_val);
                                    index.insert(result);
                                }
                            }
                        }
                        Err(slot) => {
                            // Fresh entry: record reverse mapping and value
                            // index, then insert into the found slot.
                            self.to_container.insert(val, (hc as usize, target_map));
                            for v in container.iter() {
                                self.val_index.entry(v).or_default().insert(val);
                            }
                            // SAFETY: `slot` was produced by
                            // `find_or_find_insert_slot` on this shard with
                            // the same hash, and no other insertion has
                            // happened on the shard since.
                            unsafe {
                                shard.insert_in_slot(hc, slot, (container, SharedValue::new(val)));
                            }
                        }
                    }
                }
            });
        changed
    }

    /// Look up the container interned under `value`. The returned guard keeps
    /// the shard read-locked (and hence the entry alive) while it is held.
    fn get_container(&self, value: Value) -> Option<impl Deref<Target = C> + '_> {
        let (hc, target_map) = *self.to_container.get(&value)?;
        let shard = &self.to_id.shards()[target_map];
        let read_guard = shard.read();
        // NOTE(review): this takes a *second* read lock on the same shard for
        // the lookup; that is fine for a RwLock (readers don't exclude
        // readers), and `read_guard` above outlives the raw pointer.
        let val_ptr: *const (C, _) = shard
            .read()
            .find(hc as u64, |(_, v)| *v.get() == value)?
            .as_ptr();
        // Smart pointer that bundles the borrowed entry with the read guard
        // keeping it valid.
        struct ValueDeref<'a, T, Guard> {
            _guard: Guard,
            data: &'a T,
        }

        impl<T, Guard> Deref for ValueDeref<'_, T, Guard> {
            type Target = T;

            fn deref(&self) -> &T {
                self.data
            }
        }

        Some(ValueDeref {
            _guard: read_guard,
            // SAFETY: `_guard` holds the shard read lock for the lifetime of
            // the returned value, so the bucket `val_ptr` points at cannot be
            // moved or freed while `data` is borrowed.
            data: unsafe {
                let unwrapped: &(C, _) = &*val_ptr;
                &unwrapped.0
            },
        })
    }
}
579
580fn incremental_rebuild(_uf_size: usize, _table_size: usize, _parallel: bool) -> bool {
581 #[cfg(debug_assertions)]
582 {
583 use rand::Rng;
584 rand::rng().random_bool(0.5)
585 }
586 #[cfg(not(debug_assertions))]
587 {
588 if _parallel {
589 _table_size > 1000 && _uf_size * 512 <= _table_size
590 } else {
591 _table_size > 1000 && _uf_size * 8 <= _table_size
592 }
593 }
594}