Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 277874c

Browse files
knutwalker authored and jexp committed
Improve randomness for LPA (#858)
* Enable LPA to control its randomness
* Add Random Iterators with proper random iteration
* Add random iteration benchmark
* Use proper randomness in LPA
1 parent 90b9b9b commit 277874c

File tree

16 files changed

+884
-344
lines changed

16 files changed

+884
-344
lines changed

algo/src/main/java/org/neo4j/graphalgo/impl/BaseLabelPropagation.java

Lines changed: 164 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,38 @@
1818
*/
1919
package org.neo4j.graphalgo.impl;
2020

21+
import com.carrotsearch.hppc.HashOrderMixing;
22+
import com.carrotsearch.hppc.LongDoubleHashMap;
23+
import com.carrotsearch.hppc.LongDoubleScatterMap;
24+
import com.carrotsearch.hppc.cursors.LongDoubleCursor;
25+
import org.neo4j.collection.primitive.PrimitiveIntIterator;
26+
import org.neo4j.collection.primitive.PrimitiveLongIterator;
2127
import org.neo4j.graphalgo.api.Graph;
2228
import org.neo4j.graphalgo.api.WeightMapping;
29+
import org.neo4j.graphalgo.core.utils.LazyBatchCollection;
2330
import org.neo4j.graphalgo.core.utils.ParallelUtil;
31+
import org.neo4j.graphalgo.core.utils.ProgressLogger;
32+
import org.neo4j.graphalgo.core.utils.RandomLongIterable;
2433
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
34+
import org.neo4j.graphalgo.core.utils.paged.BitUtil;
2535
import org.neo4j.graphdb.Direction;
2636

37+
import java.util.ArrayList;
38+
import java.util.Collection;
2739
import java.util.List;
40+
import java.util.Random;
2841
import java.util.concurrent.ExecutorService;
2942

43+
import static com.carrotsearch.hppc.Containers.DEFAULT_EXPECTED_ELEMENTS;
44+
import static com.carrotsearch.hppc.HashContainers.DEFAULT_LOAD_FACTOR;
45+
3046
abstract class BaseLabelPropagation<
3147
G extends Graph,
3248
W extends WeightMapping,
33-
L extends LabelPropagationAlgorithm.Labels,
34-
Self extends BaseLabelPropagation<G, W, L, Self>
49+
Self extends BaseLabelPropagation<G, W, Self>
3550
> extends LabelPropagationAlgorithm<Self> {
3651

37-
static final int[] EMPTY_INTS = new int[0];
38-
static final long[] EMPTY_LONGS = new long[0];
52+
private static final long[] EMPTY_LONGS = new long[0];
3953

4054
private G graph;
4155
private final long nodeCount;
@@ -46,14 +60,14 @@ abstract class BaseLabelPropagation<
4660
final int concurrency;
4761
final ExecutorService executor;
4862

49-
private L labels;
63+
private Labels labels;
5064
private long ranIterations;
5165
private boolean didConverge;
5266

5367
BaseLabelPropagation(
5468
G graph,
55-
W nodeProperties,
56-
W nodeWeights,
69+
W nodeProperties,
70+
W nodeWeights,
5771
int batchSize,
5872
int concurrency,
5973
ExecutorService executor,
@@ -68,22 +82,24 @@ abstract class BaseLabelPropagation<
6882
this.nodeWeights = nodeWeights;
6983
}
7084

71-
abstract L initialLabels(long nodeCount, AllocationTracker tracker);
72-
85+
abstract Labels initialLabels(long nodeCount, AllocationTracker tracker);
7386

74-
abstract List<BaseStep> baseSteps(
87+
abstract Initialization initStep(
7588
final G graph,
76-
final L labels,
89+
final Labels labels,
7790
final W nodeProperties,
7891
final W nodeWeights,
7992
final Direction direction,
80-
final boolean randomizeOrder);
93+
final ProgressLogger progressLogger,
94+
final RandomProvider randomProvider,
95+
final RandomLongIterable nodes
96+
);
8197

8298
@Override
83-
final Self compute(
99+
Self compute(
84100
Direction direction,
85101
long maxIterations,
86-
boolean randomizeOrder) {
102+
RandomProvider random) {
87103
if (maxIterations <= 0L) {
88104
throw new IllegalArgumentException("Must iterate at least 1 time");
89105
}
@@ -95,10 +111,12 @@ final Self compute(
95111
ranIterations = 0L;
96112
didConverge = false;
97113

98-
List<BaseStep> baseSteps = baseSteps(graph, labels, nodeProperties, nodeWeights, direction, randomizeOrder);
114+
List<BaseStep> baseSteps = baseSteps(direction, random);
99115

100-
for (long i = 0L; i < maxIterations; i++) {
101-
ParallelUtil.runWithConcurrency(concurrency, baseSteps, executor);
116+
long currentIteration = 0L;
117+
while (running() && currentIteration < maxIterations) {
118+
ParallelUtil.runWithConcurrency(concurrency, baseSteps, terminationFlag, executor);
119+
++currentIteration;
102120
}
103121

104122
long maxIteration = 0L;
@@ -121,10 +139,6 @@ final Self compute(
121139
return me();
122140
}
123141

124-
final BaseStep asStep(Initialization initialization) {
125-
return new BaseStep(initialization);
126-
}
127-
128142
@Override
129143
public final long ranIterations() {
130144
return ranIterations;
@@ -146,6 +160,47 @@ public Self release() {
146160
return me();
147161
}
148162

163+
/**
 * Builds one {@code BaseStep} per node batch and runs all of them once.
 *
 * The batch size is first normalized by {@code adjustBatchSize}. Each batch
 * produced by {@code LazyBatchCollection.of} is represented by a
 * {@code RandomLongIterable} over {@code [start, start + length)} that is
 * given a {@code Random} obtained from {@code random.randomForNewIteration()}.
 * For every batch an {@code Initialization} is created through
 * {@code initStep} and wrapped in a {@code BaseStep}.
 *
 * @param direction passed through to {@code initStep}
 * @param random    provider handed to each batch iterable and to {@code initStep}
 * @return the created tasks, after they have been submitted once through
 *         {@code ParallelUtil.runWithConcurrency}
 */
private List<BaseStep> baseSteps(Direction direction, RandomProvider random) {
164+
165+
long nodeCount = graph.nodeCount();
166+
long batchSize = adjustBatchSize(nodeCount, (long) this.batchSize);
167+
168+
Collection<RandomLongIterable> nodeBatches = LazyBatchCollection.of(
169+
nodeCount,
170+
batchSize,
171+
(start, length) -> new RandomLongIterable(start, start + length, random.randomForNewIteration()));
172+
173+
// one task per batch; the list is pre-sized to the number of batches
int threads = nodeBatches.size();
174+
List<BaseStep> tasks = new ArrayList<>(threads);
175+
for (RandomLongIterable iter : nodeBatches) {
176+
Initialization initStep = initStep(
177+
graph,
178+
labels,
179+
nodeProperties,
180+
nodeWeights,
181+
direction,
182+
getProgressLogger(),
183+
random,
184+
iter
185+
);
186+
BaseStep task = new BaseStep(initStep);
187+
tasks.add(task);
188+
}
189+
// first run of the tasks happens here, before they are returned to the caller
ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, executor);
190+
return tasks;
191+
}
192+
193+
/**
 * Normalizes a requested batch size.
 *
 * Non-positive values are replaced by 1, the value is then passed through
 * {@code BitUtil.nextHighestPowerOfTwo}, and finally doubled until
 * {@code (nodeCount + batchSize + 1) / batchSize} no longer exceeds
 * {@code Integer.MAX_VALUE}.
 *
 * @param nodeCount number of nodes that will be split into batches
 * @param batchSize requested batch size; may be zero or negative
 * @return the adjusted batch size
 */
private long adjustBatchSize(long nodeCount, long batchSize) {
194+
if (batchSize <= 0L) {
195+
batchSize = 1L;
196+
}
197+
batchSize = BitUtil.nextHighestPowerOfTwo(batchSize);
198+
// NOTE(review): presumably keeps the number of batches within int range,
// since the caller reads the batch collection's size() as an int - confirm
while (((nodeCount + batchSize + 1L) / batchSize) > (long) Integer.MAX_VALUE) {
199+
batchSize = batchSize << 1;
200+
}
201+
return batchSize;
202+
}
203+
149204
static abstract class Initialization implements Step {
150205
abstract void setExistingLabels();
151206

@@ -164,12 +219,40 @@ public final Step next() {
164219

165220
static abstract class Computation implements Step {
166221

222+
final RandomProvider randomProvider;
223+
private final Labels existingLabels;
224+
private final ProgressLogger progressLogger;
225+
private final double maxNode;
226+
private final LongDoubleHashMap votes;
227+
167228
private boolean didChange = true;
168-
private long iteration = 0L;
229+
long iteration = 0L;
230+
231+
/**
 * Creates a computation step and allocates its vote map.
 *
 * When {@code randomProvider.isRandom()} is true, votes are kept in a
 * {@code LongDoubleHashMap} whose hash-order mixing is a constant drawn from
 * a {@code Random} supplied by the provider; otherwise a
 * {@code LongDoubleScatterMap} is used.
 *
 * @param existingLabels labels that are read and updated by this step
 * @param progressLogger receives a progress call after each processed node
 * @param maxNode        upper bound passed to the progress logger, stored as a double
 * @param randomProvider decides which vote map implementation is used
 */
Computation(
232+
final Labels existingLabels,
233+
final ProgressLogger progressLogger,
234+
final long maxNode,
235+
final RandomProvider randomProvider) {
236+
this.randomProvider = randomProvider;
237+
this.existingLabels = existingLabels;
238+
this.progressLogger = progressLogger;
239+
this.maxNode = (double) maxNode;
240+
if (randomProvider.isRandom()) {
241+
Random random = randomProvider.randomForNewIteration();
242+
// NOTE(review): the mixing seed presumably randomizes the map's iteration
// order, which decides ties between equally weighted labels - confirm against HPPC
this.votes = new LongDoubleHashMap(
243+
DEFAULT_EXPECTED_ELEMENTS,
244+
(double) DEFAULT_LOAD_FACTOR,
245+
HashOrderMixing.constant(random.nextLong()));
246+
} else {
247+
this.votes = new LongDoubleScatterMap();
248+
}
249+
}
169250

170251
abstract boolean computeAll();
171252

172-
abstract void release();
253+
abstract void forEach(long nodeId);
254+
255+
abstract double weightOf(long nodeId, long candidate);
173256

174257
@Override
175258
public final void run() {
@@ -182,10 +265,68 @@ public final void run() {
182265
}
183266
}
184267

268+
/**
 * Runs {@code compute} for every node id supplied by the int iterator and
 * reports progress to the progress logger after each node.
 *
 * @param nodeIds node ids to process; each value is widened to a long
 * @return true if at least one processed node changed its label
 */
final boolean iterateAll(PrimitiveIntIterator nodeIds) {
269+
boolean didChange = false;
270+
while (nodeIds.hasNext()) {
271+
long nodeId = (long) nodeIds.next();
272+
didChange = compute(nodeId, didChange);
273+
progressLogger.logProgress((double) nodeId, maxNode);
274+
}
275+
return didChange;
276+
}
277+
278+
/**
 * Runs {@code compute} for every node id supplied by the long iterator and
 * reports progress to the progress logger after each node.
 *
 * @param nodeIds node ids to process
 * @return true if at least one processed node changed its label
 */
final boolean iterateAll(PrimitiveLongIterator nodeIds) {
279+
boolean didChange = false;
280+
while (nodeIds.hasNext()) {
281+
long nodeId = nodeIds.next();
282+
didChange = compute(nodeId, didChange);
283+
progressLogger.logProgress((double) nodeId, maxNode);
284+
}
285+
return didChange;
286+
}
287+
288+
/**
 * Recomputes the label of a single node from the current votes.
 *
 * The vote map is cleared, {@code forEach(nodeId)} is invoked (presumably
 * the subclass hook that fills the map through {@code castVote}), and the
 * label with the strictly greatest accumulated weight is selected. If the
 * map stays empty the node keeps its current label.
 *
 * @param nodeId    node whose label is recomputed
 * @param didChange change flag carried over from previously processed nodes
 * @return true if this node's label was changed, otherwise {@code didChange}
 */
final boolean compute(long nodeId, boolean didChange) {
289+
votes.clear();
290+
long partition = existingLabels.labelFor(nodeId);
291+
long previous = partition;
292+
forEach(nodeId);
293+
double weight = Double.NEGATIVE_INFINITY;
294+
for (LongDoubleCursor vote : votes) {
295+
// strict comparison: on equal weights the label seen first in the
// map's iteration order is kept
if (weight < vote.value) {
296+
weight = vote.value;
297+
partition = vote.key;
298+
}
299+
}
300+
if (partition != previous) {
301+
existingLabels.setLabelFor(nodeId, partition);
302+
return true;
303+
}
304+
return didChange;
305+
}
306+
307+
/**
 * Adds the weight returned by {@code weightOf(nodeId, candidate)} to the
 * vote tally of the label currently assigned to {@code candidate}.
 *
 * @param nodeId    node for which votes are being collected
 * @param candidate node whose current label receives the vote
 */
final void castVote(long nodeId, long candidate) {
308+
double weight = weightOf(nodeId, candidate);
309+
long partition = existingLabels.labelFor(candidate);
310+
votes.addTo(partition, weight);
311+
}
312+
185313
/**
 * Returns this same instance as the next step.
 */
@Override
186314
public final Step next() {
187315
return this;
188316
}
317+
318+
/**
 * Drops the vote map's backing arrays so they can be garbage collected.
 * Does nothing when the key array is already null.
 */
final void release() {
319+
// The HPPC release() method allocates new arrays, and
320+
// the clear() method overwrites the existing keys with the default value.
321+
// We want to throw away all data to allow for garbage collection instead.
322+
323+
if (votes.keys != null) {
324+
// NOTE(review): presumably swaps in an empty key array so that the
// following clear() has no key slots to overwrite - confirm against HPPC
votes.keys = EMPTY_LONGS;
325+
votes.clear();
326+
votes.keys = null;
327+
votes.values = null;
328+
}
329+
}
189330
}
190331

191332
interface Step extends Runnable {

0 commit comments

Comments
 (0)