1818 */
1919package org .neo4j .graphalgo .impl ;
2020
21+ import com .carrotsearch .hppc .HashOrderMixing ;
22+ import com .carrotsearch .hppc .LongDoubleHashMap ;
23+ import com .carrotsearch .hppc .LongDoubleScatterMap ;
24+ import com .carrotsearch .hppc .cursors .LongDoubleCursor ;
25+ import org .neo4j .collection .primitive .PrimitiveIntIterator ;
26+ import org .neo4j .collection .primitive .PrimitiveLongIterator ;
2127import org .neo4j .graphalgo .api .Graph ;
2228import org .neo4j .graphalgo .api .WeightMapping ;
29+ import org .neo4j .graphalgo .core .utils .LazyBatchCollection ;
2330import org .neo4j .graphalgo .core .utils .ParallelUtil ;
31+ import org .neo4j .graphalgo .core .utils .ProgressLogger ;
32+ import org .neo4j .graphalgo .core .utils .RandomLongIterable ;
2433import org .neo4j .graphalgo .core .utils .paged .AllocationTracker ;
34+ import org .neo4j .graphalgo .core .utils .paged .BitUtil ;
2535import org .neo4j .graphdb .Direction ;
2636
37+ import java .util .ArrayList ;
38+ import java .util .Collection ;
2739import java .util .List ;
40+ import java .util .Random ;
2841import java .util .concurrent .ExecutorService ;
2942
43+ import static com .carrotsearch .hppc .Containers .DEFAULT_EXPECTED_ELEMENTS ;
44+ import static com .carrotsearch .hppc .HashContainers .DEFAULT_LOAD_FACTOR ;
45+
3046abstract class BaseLabelPropagation <
3147 G extends Graph ,
3248 W extends WeightMapping ,
33- L extends LabelPropagationAlgorithm .Labels ,
34- Self extends BaseLabelPropagation <G , W , L , Self >
49+ Self extends BaseLabelPropagation <G , W , Self >
3550 > extends LabelPropagationAlgorithm <Self > {
3651
37- static final int [] EMPTY_INTS = new int [0 ];
38- static final long [] EMPTY_LONGS = new long [0 ];
52+ private static final long [] EMPTY_LONGS = new long [0 ];
3953
4054 private G graph ;
4155 private final long nodeCount ;
@@ -46,14 +60,14 @@ abstract class BaseLabelPropagation<
4660 final int concurrency ;
4761 final ExecutorService executor ;
4862
49- private L labels ;
63+ private Labels labels ;
5064 private long ranIterations ;
5165 private boolean didConverge ;
5266
5367 BaseLabelPropagation (
5468 G graph ,
55- W nodeProperties ,
56- W nodeWeights ,
69+ W nodeProperties ,
70+ W nodeWeights ,
5771 int batchSize ,
5872 int concurrency ,
5973 ExecutorService executor ,
@@ -68,22 +82,24 @@ abstract class BaseLabelPropagation<
6882 this .nodeWeights = nodeWeights ;
6983 }
7084
71- abstract L initialLabels (long nodeCount , AllocationTracker tracker );
72-
85+ abstract Labels initialLabels (long nodeCount , AllocationTracker tracker );
7386
74- abstract List < BaseStep > baseSteps (
87+ abstract Initialization initStep (
7588 final G graph ,
76- final L labels ,
89+ final Labels labels ,
7790 final W nodeProperties ,
7891 final W nodeWeights ,
7992 final Direction direction ,
80- final boolean randomizeOrder );
93+ final ProgressLogger progressLogger ,
94+ final RandomProvider randomProvider ,
95+ final RandomLongIterable nodes
96+ );
8197
8298 @ Override
83- final Self compute (
99+ Self compute (
84100 Direction direction ,
85101 long maxIterations ,
86- boolean randomizeOrder ) {
102+ RandomProvider random ) {
87103 if (maxIterations <= 0L ) {
88104 throw new IllegalArgumentException ("Must iterate at least 1 time" );
89105 }
@@ -95,10 +111,12 @@ final Self compute(
95111 ranIterations = 0L ;
96112 didConverge = false ;
97113
98- List <BaseStep > baseSteps = baseSteps (graph , labels , nodeProperties , nodeWeights , direction , randomizeOrder );
114+ List <BaseStep > baseSteps = baseSteps (direction , random );
99115
100- for (long i = 0L ; i < maxIterations ; i ++) {
101- ParallelUtil .runWithConcurrency (concurrency , baseSteps , executor );
116+ long currentIteration = 0L ;
117+ while (running () && currentIteration < maxIterations ) {
118+ ParallelUtil .runWithConcurrency (concurrency , baseSteps , terminationFlag , executor );
119+ ++currentIteration ;
102120 }
103121
104122 long maxIteration = 0L ;
@@ -121,10 +139,6 @@ final Self compute(
121139 return me ();
122140 }
123141
124- final BaseStep asStep (Initialization initialization ) {
125- return new BaseStep (initialization );
126- }
127-
128142 @ Override
129143 public final long ranIterations () {
130144 return ranIterations ;
@@ -146,6 +160,47 @@ public Self release() {
146160 return me ();
147161 }
148162
163+ private List <BaseStep > baseSteps (Direction direction , RandomProvider random ) {
164+
165+ long nodeCount = graph .nodeCount ();
166+ long batchSize = adjustBatchSize (nodeCount , (long ) this .batchSize );
167+
168+ Collection <RandomLongIterable > nodeBatches = LazyBatchCollection .of (
169+ nodeCount ,
170+ batchSize ,
171+ (start , length ) -> new RandomLongIterable (start , start + length , random .randomForNewIteration ()));
172+
173+ int threads = nodeBatches .size ();
174+ List <BaseStep > tasks = new ArrayList <>(threads );
175+ for (RandomLongIterable iter : nodeBatches ) {
176+ Initialization initStep = initStep (
177+ graph ,
178+ labels ,
179+ nodeProperties ,
180+ nodeWeights ,
181+ direction ,
182+ getProgressLogger (),
183+ random ,
184+ iter
185+ );
186+ BaseStep task = new BaseStep (initStep );
187+ tasks .add (task );
188+ }
189+ ParallelUtil .runWithConcurrency (concurrency , tasks , terminationFlag , executor );
190+ return tasks ;
191+ }
192+
193+ private long adjustBatchSize (long nodeCount , long batchSize ) {
194+ if (batchSize <= 0L ) {
195+ batchSize = 1L ;
196+ }
197+ batchSize = BitUtil .nextHighestPowerOfTwo (batchSize );
198+ while (((nodeCount + batchSize + 1L ) / batchSize ) > (long ) Integer .MAX_VALUE ) {
199+ batchSize = batchSize << 1 ;
200+ }
201+ return batchSize ;
202+ }
203+
149204 static abstract class Initialization implements Step {
150205 abstract void setExistingLabels ();
151206
@@ -164,12 +219,40 @@ public final Step next() {
164219
165220 static abstract class Computation implements Step {
166221
222+ final RandomProvider randomProvider ;
223+ private final Labels existingLabels ;
224+ private final ProgressLogger progressLogger ;
225+ private final double maxNode ;
226+ private final LongDoubleHashMap votes ;
227+
167228 private boolean didChange = true ;
168- private long iteration = 0L ;
229+ long iteration = 0L ;
230+
231+ Computation (
232+ final Labels existingLabels ,
233+ final ProgressLogger progressLogger ,
234+ final long maxNode ,
235+ final RandomProvider randomProvider ) {
236+ this .randomProvider = randomProvider ;
237+ this .existingLabels = existingLabels ;
238+ this .progressLogger = progressLogger ;
239+ this .maxNode = (double ) maxNode ;
240+ if (randomProvider .isRandom ()) {
241+ Random random = randomProvider .randomForNewIteration ();
242+ this .votes = new LongDoubleHashMap (
243+ DEFAULT_EXPECTED_ELEMENTS ,
244+ (double ) DEFAULT_LOAD_FACTOR ,
245+ HashOrderMixing .constant (random .nextLong ()));
246+ } else {
247+ this .votes = new LongDoubleScatterMap ();
248+ }
249+ }
169250
170251 abstract boolean computeAll ();
171252
172- abstract void release ();
253+ abstract void forEach (long nodeId );
254+
255+ abstract double weightOf (long nodeId , long candidate );
173256
174257 @ Override
175258 public final void run () {
@@ -182,10 +265,68 @@ public final void run() {
182265 }
183266 }
184267
268+ final boolean iterateAll (PrimitiveIntIterator nodeIds ) {
269+ boolean didChange = false ;
270+ while (nodeIds .hasNext ()) {
271+ long nodeId = (long ) nodeIds .next ();
272+ didChange = compute (nodeId , didChange );
273+ progressLogger .logProgress ((double ) nodeId , maxNode );
274+ }
275+ return didChange ;
276+ }
277+
278+ final boolean iterateAll (PrimitiveLongIterator nodeIds ) {
279+ boolean didChange = false ;
280+ while (nodeIds .hasNext ()) {
281+ long nodeId = nodeIds .next ();
282+ didChange = compute (nodeId , didChange );
283+ progressLogger .logProgress ((double ) nodeId , maxNode );
284+ }
285+ return didChange ;
286+ }
287+
288+ final boolean compute (long nodeId , boolean didChange ) {
289+ votes .clear ();
290+ long partition = existingLabels .labelFor (nodeId );
291+ long previous = partition ;
292+ forEach (nodeId );
293+ double weight = Double .NEGATIVE_INFINITY ;
294+ for (LongDoubleCursor vote : votes ) {
295+ if (weight < vote .value ) {
296+ weight = vote .value ;
297+ partition = vote .key ;
298+ }
299+ }
300+ if (partition != previous ) {
301+ existingLabels .setLabelFor (nodeId , partition );
302+ return true ;
303+ }
304+ return didChange ;
305+ }
306+
307+ final void castVote (long nodeId , long candidate ) {
308+ double weight = weightOf (nodeId , candidate );
309+ long partition = existingLabels .labelFor (candidate );
310+ votes .addTo (partition , weight );
311+ }
312+
185313 @ Override
186314 public final Step next () {
187315 return this ;
188316 }
317+
318+ final void release () {
319+ // the HPPC release() method allocates new arrays
320+ // the clear() method overwrite the existing keys with the default value
321+ // we want to throw away all data to allow for GC collection instead.
322+
323+ if (votes .keys != null ) {
324+ votes .keys = EMPTY_LONGS ;
325+ votes .clear ();
326+ votes .keys = null ;
327+ votes .values = null ;
328+ }
329+ }
189330 }
190331
191332 interface Step extends Runnable {
0 commit comments