|
1 | 1 | package eu.sim642.adventofcodelib |
2 | 2 |
|
| 3 | +import eu.sim642.adventofcodelib.UnionFind.Node |
| 4 | + |
3 | 5 | import scala.annotation.tailrec |
4 | 6 |
|
5 | | -class UnionFind[A](val reprs: Map[A, A]) { |
6 | | - // TODO: optimize |
| 7 | +class UnionFind[A](val nodes: Map[A, Node[A]]) { |
7 | 8 |
|
8 | | - def this(items: Seq[A]) = { |
9 | | - this(items.map(x => x -> x).toMap) |
| 9 | + def this(xs: Seq[A]) = { |
| 10 | + this(xs.map(x => x -> Node(x, 1)).toMap) |
10 | 11 | } |
11 | 12 |
|
12 | 13 | @tailrec |
13 | | - final def findRepr(x: A): A = { |
14 | | - val repr = reprs(x) |
15 | | - if (x == repr) |
16 | | - repr |
| 14 | + final def findReprNode(x: A): Node[A] = { |
| 15 | + // TODO: path compression |
| 16 | + val node = nodes(x) |
| 17 | + if (x == node.parent) |
| 18 | + node |
17 | 19 | else |
18 | | - findRepr(repr) |
| 20 | + findReprNode(node.parent) |
19 | 21 | } |
20 | 22 |
|
| 23 | + def findRepr(x: A): A = findReprNode(x).parent |
| 24 | + |
21 | 25 | def sameRepr(x: A, y: A): Boolean = |
22 | 26 | findRepr(x) == findRepr(y) |
23 | 27 |
|
24 | 28 | def unioned(x: A, y: A): UnionFind[A] = { |
25 | | - val xRepr = findRepr(x) |
26 | | - val yRepr = findRepr(y) |
27 | | - new UnionFind(reprs + (yRepr -> xRepr)) |
| 29 | + val xNode = findReprNode(x) |
| 30 | + val yNode = findReprNode(y) |
| 31 | + val xRepr = xNode.parent |
| 32 | + val yRepr = yNode.parent |
| 33 | + if (xRepr == yRepr) // sameRepr inlined |
| 34 | + this |
| 35 | + else if (xNode.size >= yNode.size) |
| 36 | + new UnionFind(nodes + (xRepr -> xNode.copy(size = xNode.size + yNode.size)) + (yRepr -> yNode.copy(parent = xRepr))) |
| 37 | + else |
| 38 | + new UnionFind(nodes + (yRepr -> yNode.copy(size = xNode.size + yNode.size)) + (xRepr -> xNode.copy(parent = yRepr))) |
28 | 39 | } |
29 | 40 |
|
30 | 41 | def groups(): Seq[Seq[A]] = |
31 | | - reprs.keys.groupBy(findRepr).values.map(_.toSeq).toSeq |
| 42 | + nodes.keys.groupBy(findRepr).values.map(_.toSeq).toSeq |
| 43 | + |
| 44 | + override def toString: String = nodes.toString() |
| 45 | +} |
32 | 46 |
|
33 | | - override def toString: String = reprs.toString() |
| 47 | +object UnionFind { |
| 48 | + case class Node[A](parent: A, size: Int) |
34 | 49 | } |
0 commit comments