egglog_union_find/
lib.rs

1//! This crate contains two basic union-find implementations:
2//!
3//! * [`UnionFind`], a basic single-threaded union-find data-structure.
4//! * [`concurrent::UnionFind`], a concurrent union-find data-structure.
5//!
6//! Both structures are fairly rudimentary and are customized to be used in an
7//! egraph-related setting. In particular, they do "union by min id", which is a
8//! strategy that _does not_ guarantee the same asymptotic complexity as the
9//! main techniques in the literature (e.g. union by rank). Union by min is a
10//! heuristic introduced to reduce the number of ids perturbed during congruence
11//! closure. There's likely more to do in this area but for now it seems to work
12//! well enough. It doesn't hurt that it's also simpler to implement.
13use egglog_numeric_id as numeric_id;
14use numeric_id::NumericId;
15use std::cmp;
16
17pub mod concurrent;
18
19#[cfg(test)]
20mod tests;
21
/// A single-threaded union-find (disjoint-set) data-structure over numeric ids.
///
/// `parents[i]` stores the parent of id `i`; an id whose entry points back to
/// itself is the representative of its equivalence class.
#[derive(Clone)]
pub struct UnionFind<Value> {
    parents: Vec<Value>,
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `V: Default`, which an empty union-find does not need.
impl<V> Default for UnionFind<V> {
    /// Create an empty union-find that tracks no ids yet.
    fn default() -> Self {
        UnionFind { parents: vec![] }
    }
}
35
36impl<Value: NumericId> UnionFind<Value> {
37    /// Reset the union-find data-structure to the point where all Ids are their
38    /// own parents.
39    pub fn reset(&mut self) {
40        for (i, v) in self.parents.iter_mut().enumerate() {
41            *v = Value::from_usize(i);
42        }
43    }
44
45    /// Reserve sufficient space for the given value `v`.
46    pub fn reserve(&mut self, v: Value) {
47        if v.index() >= self.parents.len() {
48            for i in self.parents.len()..=v.index() {
49                self.parents.push(Value::from_usize(i));
50            }
51        }
52    }
53
54    /// Merge two equivalence classes.
55    pub fn union(&mut self, a: Value, b: Value) -> (Value /* parent */, Value /* child */) {
56        self.reserve(a);
57        self.reserve(b);
58        let a = self.find(a);
59        let b = self.find(b);
60        if a != b {
61            let parent = cmp::min(a, b);
62            let child = cmp::max(a, b);
63            self.parents[child.index()] = parent;
64            (parent, child)
65        } else {
66            (a, a)
67        }
68    }
69
70    /// Find the representative of an equivalence class.
71    pub fn find(&mut self, id: Value) -> Value {
72        self.reserve(id);
73        let mut cur = id;
74        loop {
75            let parent = self.parents[cur.index()];
76            if cur == parent {
77                break;
78            }
79            let grand = self.parents[parent.index()];
80            self.parents[cur.index()] = grand;
81            cur = grand;
82        }
83        cur
84    }
85
86    /// Find the representative of an equivalence class without using path compression.
87    ///
88    /// The primary advantage of this method is that it allows the ability to answer `find` queries
89    /// without holding a mutable reference to the union-find.
90    pub fn find_naive(&self, id: Value) -> Value {
91        if self.parents.len() <= id.index() {
92            return id;
93        }
94        let mut cur = id;
95        loop {
96            let parent = self.parents[cur.index()];
97            if cur == parent {
98                break;
99            }
100            cur = parent;
101        }
102        cur
103    }
104}