Skip to content

Commit f92b7ae

Browse files
committed
Add union by rank to UnionFind
1 parent febab04 commit f92b7ae

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed
Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,49 @@
11
package eu.sim642.adventofcodelib
22

3+
import eu.sim642.adventofcodelib.UnionFind.Node
4+
35
import scala.annotation.tailrec
46

5-
class UnionFind[A](val reprs: Map[A, A]) {
6-
// TODO: optimize
7+
class UnionFind[A](val nodes: Map[A, Node[A]]) {
78

8-
def this(items: Seq[A]) = {
9-
this(items.map(x => x -> x).toMap)
9+
def this(xs: Seq[A]) = {
10+
this(xs.map(x => x -> Node(x, 1)).toMap)
1011
}
1112

1213
@tailrec
13-
final def findRepr(x: A): A = {
14-
val repr = reprs(x)
15-
if (x == repr)
16-
repr
14+
final def findReprNode(x: A): Node[A] = {
15+
// TODO: path compression
16+
val node = nodes(x)
17+
if (x == node.parent)
18+
node
1719
else
18-
findRepr(repr)
20+
findReprNode(node.parent)
1921
}
2022

23+
def findRepr(x: A): A = findReprNode(x).parent
24+
2125
def sameRepr(x: A, y: A): Boolean =
2226
findRepr(x) == findRepr(y)
2327

2428
def unioned(x: A, y: A): UnionFind[A] = {
25-
val xRepr = findRepr(x)
26-
val yRepr = findRepr(y)
27-
new UnionFind(reprs + (yRepr -> xRepr))
29+
val xNode = findReprNode(x)
30+
val yNode = findReprNode(y)
31+
val xRepr = xNode.parent
32+
val yRepr = yNode.parent
33+
if (xRepr == yRepr) // sameRepr inlined
34+
this
35+
else if (xNode.size >= yNode.size)
36+
new UnionFind(nodes + (xRepr -> xNode.copy(size = xNode.size + yNode.size)) + (yRepr -> yNode.copy(parent = xRepr)))
37+
else
38+
new UnionFind(nodes + (yRepr -> yNode.copy(size = xNode.size + yNode.size)) + (xRepr -> xNode.copy(parent = yRepr)))
2839
}
2940

3041
def groups(): Seq[Seq[A]] =
31-
reprs.keys.groupBy(findRepr).values.map(_.toSeq).toSeq
42+
nodes.keys.groupBy(findRepr).values.map(_.toSeq).toSeq
43+
44+
override def toString: String = nodes.toString()
45+
}
3246

33-
override def toString: String = reprs.toString()
47+
object UnionFind {
48+
case class Node[A](parent: A, size: Int)
3449
}

0 commit comments

Comments
 (0)