Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 3b45068

Browse files
knutwalkerjexp
authored andcommitted
Prevent UnionFind deadlock (#866)
* Add failing test * Add test timeouts * Test all other implementations as well * Catch wider test exception * Better exception handling in UF/FJMerge * Guard UF/Queue implementations against runtime failures * Add same guard logic to merging as well
1 parent 81c8bfb commit 3b45068

File tree

6 files changed

+359
-98
lines changed

6 files changed

+359
-98
lines changed

algo/src/main/java/org/neo4j/graphalgo/impl/HugeParallelUnionFindFJMerge.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.neo4j.graphalgo.api.HugeGraph;
2222
import org.neo4j.graphalgo.api.HugeRelationshipIterator;
23+
import org.neo4j.graphalgo.core.utils.ExceptionUtil;
2324
import org.neo4j.graphalgo.core.utils.ParallelUtil;
2425
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
2526
import org.neo4j.graphalgo.core.utils.paged.PagedDisjointSetStruct;
@@ -148,9 +149,7 @@ public void run() {
148149
return true;
149150
});
150151
} catch (Exception e) {
151-
System.out.println("exception for nodeid:" + node);
152-
e.printStackTrace();
153-
return;
152+
throw ExceptionUtil.asUnchecked(e);
154153
}
155154
}
156155
getProgressLogger().logProgress((end - 1) / (nodeCount - 1));

algo/src/main/java/org/neo4j/graphalgo/impl/HugeParallelUnionFindQueue.java

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
import java.util.concurrent.BlockingQueue;
3232
import java.util.concurrent.ExecutorService;
3333
import java.util.concurrent.Future;
34+
import java.util.concurrent.atomic.AtomicInteger;
35+
36+
import static org.neo4j.graphalgo.core.utils.ParallelUtil.awaitTermination;
37+
import static org.neo4j.graphalgo.impl.ParallelUnionFindQueue.mergeTask;
3438

3539
/**
3640
* parallel UnionFind using ExecutorService only.
@@ -48,10 +52,10 @@
4852
public class HugeParallelUnionFindQueue extends GraphUnionFindAlgo<HugeGraph, PagedDisjointSetStruct, HugeParallelUnionFindQueue> {
4953

5054
private final ExecutorService executor;
55+
private final AllocationTracker tracker;
5156
private final long nodeCount;
5257
private final long batchSize;
5358
private final int stepSize;
54-
private final AllocationTracker tracker;
5559

5660
/**
5761
* initialize parallel UF
@@ -64,14 +68,13 @@ public class HugeParallelUnionFindQueue extends GraphUnionFindAlgo<HugeGraph, Pa
6468
AllocationTracker tracker) {
6569
super(graph);
6670
this.executor = executor;
67-
nodeCount = graph.nodeCount();
6871
this.tracker = tracker;
72+
this.nodeCount = graph.nodeCount();
6973
this.batchSize = ParallelUtil.adjustBatchSize(
7074
nodeCount,
7175
concurrency,
7276
minBatchSize,
7377
Integer.MAX_VALUE);
74-
7578
long targetSteps = ParallelUtil.threadSize(batchSize, nodeCount);
7679
if (targetSteps > Integer.MAX_VALUE) {
7780
throw new IllegalArgumentException(String.format(
@@ -80,88 +83,90 @@ public class HugeParallelUnionFindQueue extends GraphUnionFindAlgo<HugeGraph, Pa
8083
concurrency,
8184
batchSize));
8285
}
83-
stepSize = (int) targetSteps;
86+
this.stepSize = (int) targetSteps;
8487
}
8588

8689
@Override
8790
public PagedDisjointSetStruct compute() {
88-
final List<Future<?>> futures = new ArrayList<>(stepSize);
91+
final List<Future<?>> futures = new ArrayList<>(2 * stepSize);
8992
final BlockingQueue<PagedDisjointSetStruct> queue = new ArrayBlockingQueue<>(stepSize);
93+
AtomicInteger expectedStructs = new AtomicInteger();
9094

91-
int steps = 0;
9295
for (long i = 0L; i < nodeCount; i += batchSize) {
93-
futures.add(executor.submit(new HugeUnionFindTask(queue, i)));
94-
++steps;
96+
futures.add(executor.submit(new HugeUnionFindTask(queue, i, expectedStructs)));
9597
}
98+
int steps = futures.size();
9699

97100
for (int i = 1; i < steps; ++i) {
98-
futures.add(executor.submit(() -> {
99-
try {
100-
final PagedDisjointSetStruct a = queue.take();
101-
final PagedDisjointSetStruct b = queue.take();
102-
queue.add(a.merge(b));
103-
} catch (InterruptedException e) {
104-
Thread.currentThread().interrupt();
105-
throw new RuntimeException(e);
106-
}
107-
}));
101+
futures.add(executor.submit(() -> mergeTask(queue, expectedStructs, PagedDisjointSetStruct::merge)));
108102
}
109103

110-
await(futures);
104+
awaitTermination(futures);
111105
return getStruct(queue);
112106
}
113107

108+
@Override
114109
public PagedDisjointSetStruct compute(double threshold) {
115-
throw new IllegalArgumentException("Not yet implemented");
116-
}
117-
118-
private void await(final List<Future<?>> futures) {
119-
ParallelUtil.awaitTermination(futures);
110+
throw new IllegalArgumentException(
111+
"Parallel UnionFind with threshold not implemented, please use either `concurrency:1` or one of the exp* variants of UnionFind");
120112
}
121113

122114
private PagedDisjointSetStruct getStruct(final BlockingQueue<PagedDisjointSetStruct> queue) {
123-
try {
124-
return queue.take();
125-
} catch (InterruptedException e) {
126-
Thread.currentThread().interrupt();
127-
throw new RuntimeException(e);
115+
PagedDisjointSetStruct set = queue.poll();
116+
if (set == null) {
117+
set = new PagedDisjointSetStruct(nodeCount, tracker);
128118
}
119+
return set;
129120
}
130121

131122
private class HugeUnionFindTask implements Runnable {
132123

133124
private final HugeRelationshipIterator rels;
134125
private final BlockingQueue<PagedDisjointSetStruct> queue;
126+
private final AtomicInteger expectedStructs;
135127
private final long offset;
136128
private final long end;
137129

138-
HugeUnionFindTask(BlockingQueue<PagedDisjointSetStruct> queue, long offset) {
130+
HugeUnionFindTask(
131+
BlockingQueue<PagedDisjointSetStruct> queue,
132+
long offset,
133+
AtomicInteger expectedStructs) {
139134
this.rels = graph.concurrentCopy();
140135
this.queue = queue;
136+
this.expectedStructs = expectedStructs;
141137
this.offset = offset;
142138
this.end = Math.min(offset + batchSize, nodeCount);
139+
expectedStructs.incrementAndGet();
143140
}
144141

145142
@Override
146143
public void run() {
147-
final PagedDisjointSetStruct struct = new PagedDisjointSetStruct(
148-
nodeCount,
149-
tracker).reset();
150-
for (long node = offset; node < end; node++) {
151-
rels.forEachRelationship(
152-
node,
153-
Direction.OUTGOING,
154-
(sourceNodeId, targetNodeId) -> {
155-
struct.union(sourceNodeId, targetNodeId);
156-
return true;
157-
});
158-
}
159-
getProgressLogger().logProgress((end - 1.0) / (nodeCount - 1.0));
144+
boolean pushed = false;
160145
try {
161-
queue.put(struct);
162-
} catch (InterruptedException e) {
163-
Thread.currentThread().interrupt();
164-
throw new RuntimeException(e);
146+
final PagedDisjointSetStruct struct = new PagedDisjointSetStruct(
147+
nodeCount,
148+
tracker).reset();
149+
for (long node = offset; node < end; node++) {
150+
rels.forEachRelationship(
151+
node,
152+
Direction.OUTGOING,
153+
(sourceNodeId, targetNodeId) -> {
154+
struct.union(sourceNodeId, targetNodeId);
155+
return true;
156+
});
157+
}
158+
getProgressLogger().logProgress((end - 1.0) / (nodeCount - 1.0));
159+
try {
160+
queue.put(struct);
161+
pushed = true;
162+
} catch (InterruptedException e) {
163+
Thread.currentThread().interrupt();
164+
throw new RuntimeException(e);
165+
}
166+
} finally {
167+
if (!pushed) {
168+
expectedStructs.decrementAndGet();
169+
}
165170
}
166171
}
167172
}

algo/src/main/java/org/neo4j/graphalgo/impl/ParallelUnionFindFJMerge.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.neo4j.graphalgo.impl;
2020

2121
import org.neo4j.graphalgo.api.Graph;
22+
import org.neo4j.graphalgo.core.utils.ExceptionUtil;
2223
import org.neo4j.graphalgo.core.utils.ParallelUtil;
2324
import org.neo4j.graphalgo.core.utils.dss.DisjointSetStruct;
2425
import org.neo4j.graphdb.Direction;
@@ -136,9 +137,7 @@ public void run() {
136137
return true;
137138
});
138139
} catch (Exception e) {
139-
System.out.println("exception for nodeid:" + node);
140-
e.printStackTrace();
141-
return;
140+
throw ExceptionUtil.asUnchecked(e);
142141
}
143142
}
144143
getProgressLogger().logProgress((end - 1) / (nodeCount - 1));

algo/src/main/java/org/neo4j/graphalgo/impl/ParallelUnionFindQueue.java

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@
2525

2626
import java.util.*;
2727
import java.util.concurrent.*;
28+
import java.util.concurrent.atomic.AtomicInteger;
29+
import java.util.function.BinaryOperator;
2830
import java.util.function.Function;
2931

32+
import static org.neo4j.graphalgo.core.utils.ParallelUtil.awaitTermination;
33+
3034
/**
3135
* parallel UnionFind using ExecutorService only.
3236
* <p>
@@ -45,7 +49,6 @@ public class ParallelUnionFindQueue extends GraphUnionFindAlgo<Graph, DisjointSe
4549
private final ExecutorService executor;
4650
private final int nodeCount;
4751
private final int batchSize;
48-
private final int stepSize;
4952

5053
public static Function<Graph, ParallelUnionFindQueue> of(ExecutorService executor, int minBatchSize, int concurrency) {
5154
return graph -> new ParallelUnionFindQueue(
@@ -61,43 +64,62 @@ public static Function<Graph, ParallelUnionFindQueue> of(ExecutorService executo
6164
public ParallelUnionFindQueue(Graph graph, ExecutorService executor, int minBatchSize, int concurrency) {
6265
super(graph);
6366
this.executor = executor;
64-
nodeCount = Math.toIntExact(graph.nodeCount());
67+
this.nodeCount = Math.toIntExact(graph.nodeCount());
6568
this.batchSize = ParallelUtil.adjustBatchSize(nodeCount, concurrency, minBatchSize);
66-
stepSize = ParallelUtil.threadSize(batchSize, nodeCount);
6769
}
6870

6971
@Override
7072
public DisjointSetStruct compute() {
71-
final List<Future<?>> futures = new ArrayList<>(stepSize);
73+
int stepSize = ParallelUtil.threadSize(batchSize, nodeCount);
74+
final List<Future<?>> futures = new ArrayList<>(2 * stepSize);
7275
final BlockingQueue<DisjointSetStruct> queue = new ArrayBlockingQueue<>(stepSize);
76+
AtomicInteger expectedStructs = new AtomicInteger();
7377

74-
Phaser phaser = new Phaser();
75-
int steps = 0;
7678
for (int i = 0; i < nodeCount; i += batchSize) {
77-
futures.add(executor.submit(new UnionFindTask(queue, i, phaser)));
78-
++steps;
79+
futures.add(executor.submit(new UnionFindTask(queue, i, expectedStructs)));
7980
}
80-
phaser.awaitAdvance(phaser.getPhase());
81+
int steps = futures.size();
8182

8283
for (int i = 1; i < steps; ++i) {
83-
futures.add(executor.submit(() -> {
84-
try {
85-
final DisjointSetStruct a = queue.take();
86-
final DisjointSetStruct b = queue.take();
87-
queue.add(a.merge(b));
88-
} catch (InterruptedException e) {
89-
Thread.currentThread().interrupt();
90-
throw new RuntimeException(e);
91-
}
92-
}));
84+
futures.add(executor.submit(() -> mergeTask(queue, expectedStructs, DisjointSetStruct::merge)));
9385
}
9486

95-
await(futures);
87+
awaitTermination(futures);
9688
return getStruct(queue);
9789
}
9890

99-
private void await(final List<Future<?>> futures) {
100-
ParallelUtil.awaitTermination(futures);
91+
static <T> void mergeTask(
92+
final BlockingQueue<T> queue,
93+
final AtomicInteger expected,
94+
BinaryOperator<T> merge) {
95+
// basically a decrement operation, but we don't decrement in case there's not
96+
// enough sets for us to operate on
97+
int available, afterMerge;
98+
do {
99+
available = expected.get();
100+
// see if there are at least two sets to take, so we don't wait for a set that will never come
101+
if (available < 2) {
102+
return;
103+
}
104+
// decrease by one, as we're pushing a new set onto the queue
105+
afterMerge = available - 1;
106+
} while (!expected.compareAndSet(available, afterMerge));
107+
108+
boolean pushed = false;
109+
try {
110+
final T a = queue.take();
111+
final T b = queue.take();
112+
final T next = merge.apply(a, b);
113+
queue.add(next);
114+
pushed = true;
115+
} catch (InterruptedException e) {
116+
Thread.currentThread().interrupt();
117+
throw new RuntimeException(e);
118+
} finally {
119+
if (!pushed) {
120+
expected.decrementAndGet();
121+
}
122+
}
101123
}
102124

103125
@Override
@@ -106,48 +128,55 @@ public DisjointSetStruct compute(double threshold) {
106128
}
107129

108130
private DisjointSetStruct getStruct(final BlockingQueue<DisjointSetStruct> queue) {
109-
try {
110-
return queue.take();
111-
} catch (InterruptedException e) {
112-
Thread.currentThread().interrupt();
113-
throw new RuntimeException(e);
131+
DisjointSetStruct set = queue.poll();
132+
if (set == null) {
133+
set = new DisjointSetStruct(nodeCount);
114134
}
135+
return set;
115136
}
116137

117138
private class UnionFindTask implements Runnable {
118139

119140
private final BlockingQueue<DisjointSetStruct> queue;
120-
private final Phaser phaser;
141+
private final AtomicInteger expectedStructs;
121142
private final int offset;
122143
private final int end;
123144

145+
124146
UnionFindTask(
125147
BlockingQueue<DisjointSetStruct> queue,
126148
int offset,
127-
Phaser phaser) {
149+
AtomicInteger expectedStructs) {
128150
this.queue = queue;
151+
this.expectedStructs = expectedStructs;
129152
this.offset = offset;
130153
this.end = Math.min(offset + batchSize, nodeCount);
131-
this.phaser = phaser;
132-
phaser.register();
154+
expectedStructs.incrementAndGet();
133155
}
134156

135157
@Override
136158
public void run() {
137-
phaser.arriveAndDeregister();
138-
final DisjointSetStruct struct = new DisjointSetStruct(nodeCount).reset();
139-
for (int node = offset; node < end; node++) {
140-
graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId) -> {
141-
struct.union(sourceNodeId, targetNodeId);
142-
return true;
143-
});
144-
}
145-
getProgressLogger().logProgress((end - 1.0) / (nodeCount - 1.0));
159+
boolean pushed = false;
146160
try {
147-
queue.put(struct);
148-
} catch (InterruptedException e) {
149-
Thread.currentThread().interrupt();
150-
throw new RuntimeException(e);
161+
final DisjointSetStruct struct = new DisjointSetStruct(nodeCount).reset();
162+
for (int node = offset; node < end; node++) {
163+
graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId) -> {
164+
struct.union(sourceNodeId, targetNodeId);
165+
return true;
166+
});
167+
}
168+
getProgressLogger().logProgress((end - 1.0) / (nodeCount - 1.0));
169+
try {
170+
queue.put(struct);
171+
pushed = true;
172+
} catch (InterruptedException e) {
173+
Thread.currentThread().interrupt();
174+
throw new RuntimeException(e);
175+
}
176+
} finally {
177+
if (!pushed) {
178+
expectedStructs.decrementAndGet();
179+
}
151180
}
152181
}
153182
}

0 commit comments

Comments
 (0)