Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit ebcbf58

Browse files
committed
batching wip
1 parent b1961c0 commit ebcbf58

File tree

11 files changed

+322
-18
lines changed

11 files changed

+322
-18
lines changed

algo/src/main/java/org/neo4j/graphalgo/impl/DSSResult.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ public Stream<DisjointSetStruct.Result> resultStream(IdMapping idMapping) {
7373
: hugeStruct.resultStream(((HugeIdMapping) idMapping));
7474
}
7575

76+
public Stream<DisjointSetStruct.InternalResult> internalResultStream(IdMapping idMapping) {
77+
return struct.internalResultStream(idMapping);
78+
}
79+
7680
public void forEach(NodeIterator nodes, IntIntPredicate consumer) {
7781
if (struct != null) {
7882
nodes.forEachNode(nodeId -> consumer.apply(nodeId, struct.find(nodeId)));

algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ public Stream<SimilaritySummaryResult> cosine(
7979
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);
8080

8181
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
82-
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
82+
boolean writeParallel = configuration.get("writeParallel", false);
83+
84+
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
8385
}
8486

8587
private SimilarityComputer<WeightedInput> similarityComputer(Double skipValue) {

algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ public Stream<SimilaritySummaryResult> euclidean(
8080
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);
8181

8282
boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0;
83-
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty );
83+
boolean writeParallel = configuration.get("writeParallel", false);
84+
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel );
8485
}
8586

8687
Stream<SimilarityResult> generateWeightedStream(ProcedureConfiguration configuration, WeightedInput[] inputs,

algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ public Stream<SimilaritySummaryResult> jaccard(
6767
similarityCutoff, getTopK(configuration)), getTopN(configuration));
6868

6969
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
70-
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty );
70+
boolean writeParallel = configuration.get("writeParallel", false);
71+
72+
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
7173
}
7274

7375
private SimilarityComputer<CategoricalInput> similarityComputer() {

algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ public Stream<SimilaritySummaryResult> overlap(
7373
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, () -> null, similarityCutoff, getTopK(configuration)), getTopN(configuration));
7474

7575
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
76-
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
76+
boolean writeParallel = configuration.get("writeParallel", false);
77+
78+
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
7779
}
7880

7981

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
/**
2+
* Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
3+
* <p>
4+
* This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
5+
* <p>
6+
* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License as published by
8+
* the Free Software Foundation, either version 3 of the License, or
9+
* (at your option) any later version.
10+
* <p>
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
* GNU General Public License for more details.
15+
* <p>
16+
* You should have received a copy of the GNU General Public License
17+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18+
*/
19+
package org.neo4j.graphalgo.similarity;
20+
21+
import com.carrotsearch.hppc.IntHashSet;
22+
import com.carrotsearch.hppc.IntSet;
23+
import org.neo4j.graphalgo.core.IdMap;
24+
import org.neo4j.graphalgo.core.WeightMap;
25+
import org.neo4j.graphalgo.core.heavyweight.AdjacencyMatrix;
26+
import org.neo4j.graphalgo.core.heavyweight.HeavyGraph;
27+
import org.neo4j.graphalgo.core.utils.*;
28+
import org.neo4j.graphalgo.core.utils.dss.DisjointSetStruct;
29+
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
30+
import org.neo4j.graphalgo.impl.DSSResult;
31+
import org.neo4j.graphalgo.impl.GraphUnionFind;
32+
import org.neo4j.graphdb.Direction;
33+
import org.neo4j.internal.kernel.api.exceptions.EntityNotFoundException;
34+
import org.neo4j.internal.kernel.api.exceptions.InvalidTransactionTypeKernelException;
35+
import org.neo4j.internal.kernel.api.exceptions.KernelException;
36+
import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException;
37+
import org.neo4j.kernel.api.KernelTransaction;
38+
import org.neo4j.kernel.internal.GraphDatabaseAPI;
39+
import org.neo4j.logging.Log;
40+
import org.neo4j.values.storable.Values;
41+
42+
import java.util.*;
43+
import java.util.concurrent.*;
44+
import java.util.concurrent.atomic.AtomicInteger;
45+
import java.util.concurrent.atomic.LongAdder;
46+
import java.util.stream.Collectors;
47+
import java.util.stream.Stream;
48+
49+
public class ParallelSimilarityExporter extends StatementApi {
50+
51+
private final Log log;
52+
private final int propertyId;
53+
private final int relationshipTypeId;
54+
private final int nodeCount;
55+
56+
public ParallelSimilarityExporter(GraphDatabaseAPI api,
57+
Log log,
58+
String relationshipType,
59+
String propertyName, int nodeCount) {
60+
super(api);
61+
this.log = log;
62+
propertyId = getOrCreatePropertyId(propertyName);
63+
relationshipTypeId = getOrCreateRelationshipId(relationshipType);
64+
this.nodeCount = nodeCount;
65+
}
66+
67+
public void export(Stream<SimilarityResult> similarityPairs, long batchSize) {
68+
IdMap idMap = new IdMap(this.nodeCount);
69+
AdjacencyMatrix adjacencyMatrix = new AdjacencyMatrix(this.nodeCount, false, AllocationTracker.EMPTY);
70+
WeightMap weightMap = new WeightMap(nodeCount, 0, propertyId);
71+
72+
int[] numberOfRelationships = {0};
73+
74+
similarityPairs.forEach(pair -> {
75+
int id1 = idMap.mapOrGet(pair.item1);
76+
int id2 = idMap.mapOrGet(pair.item2);
77+
adjacencyMatrix.addOutgoing(id1, id2);
78+
weightMap.put(RawValues.combineIntInt(id1, id2), pair.similarity);
79+
numberOfRelationships[0]++;
80+
});
81+
82+
idMap.buildMappedIds();
83+
HeavyGraph graph = new HeavyGraph(idMap, adjacencyMatrix, weightMap, Collections.emptyMap());
84+
85+
DSSResult dssResult = computePartitions(graph);
86+
87+
Stream<List<DisjointSetStruct.InternalResult>> stream = dssResult.internalResultStream(graph)
88+
.collect(Collectors.groupingBy(item -> item.setId))
89+
.values()
90+
.stream();
91+
92+
int queueSize = dssResult.getSetCount();
93+
log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], queueSize);
94+
95+
ArrayBlockingQueue<List<SimilarityResult>> outQueue = new ArrayBlockingQueue<>(queueSize);
96+
97+
ExecutorService executor = Executors.newFixedThreadPool(1);
98+
Future<Integer> inQueueBatchCountFuture = executor.submit(() -> {
99+
AtomicInteger inQueueBatchCount = new AtomicInteger(0);
100+
stream.parallel().forEach(partition -> {
101+
IntSet nodesInPartition = new IntHashSet();
102+
for (DisjointSetStruct.InternalResult internalResult : partition) {
103+
nodesInPartition.add(internalResult.internalNodeId);
104+
}
105+
106+
List<SimilarityResult> inPartition = new ArrayList<>();
107+
List<SimilarityResult> outPartition = new ArrayList<>();
108+
109+
for (DisjointSetStruct.InternalResult result : partition) {
110+
int nodeId = result.internalNodeId;
111+
graph.forEachRelationship(nodeId, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId, weight) -> {
112+
SimilarityResult similarityRelationship = new SimilarityResult(idMap.toOriginalNodeId(sourceNodeId), idMap.toOriginalNodeId(targetNodeId), -1, -1, -1, weight);
113+
114+
if (nodesInPartition.contains(targetNodeId)) {
115+
inPartition.add(similarityRelationship);
116+
} else {
117+
outPartition.add(similarityRelationship);
118+
}
119+
120+
return false;
121+
});
122+
}
123+
124+
if (!inPartition.isEmpty()) {
125+
int inQueueBatches = writeSequential(inPartition.stream(), batchSize);
126+
inQueueBatchCount.addAndGet(inQueueBatches);
127+
}
128+
129+
if (!outPartition.isEmpty()) {
130+
put(outQueue, outPartition);
131+
}
132+
});
133+
return inQueueBatchCount.get();
134+
});
135+
136+
Integer inQueueBatches = null;
137+
try {
138+
inQueueBatches = inQueueBatchCountFuture.get();
139+
} catch (InterruptedException | ExecutionException e) {
140+
e.printStackTrace();
141+
}
142+
143+
144+
int outQueueBatches = writeSequential(outQueue.stream().flatMap(Collection::stream), batchSize);
145+
log.info("ParallelSimilarityExporter: Batch Size: %d, Batches written - in parallel: %d, sequentially: %d", batchSize, inQueueBatches, outQueueBatches);
146+
}
147+
148+
private static <T> void put(BlockingQueue<T> queue, T items) {
149+
try {
150+
queue.put(items);
151+
} catch (InterruptedException e) {
152+
// ignore
153+
}
154+
}
155+
156+
private DSSResult computePartitions(HeavyGraph graph) {
157+
GraphUnionFind algo = new GraphUnionFind(graph);
158+
DisjointSetStruct struct = algo.compute();
159+
algo.release();
160+
return new DSSResult(struct);
161+
}
162+
163+
private void export(SimilarityResult similarityResult) {
164+
applyInTransaction(statement -> {
165+
try {
166+
createRelationship(similarityResult, statement);
167+
} catch (KernelException e) {
168+
ExceptionUtil.throwKernelException(e);
169+
}
170+
return null;
171+
});
172+
173+
}
174+
175+
private void export(List<SimilarityResult> similarityResults) {
176+
applyInTransaction(statement -> {
177+
for (SimilarityResult similarityResult : similarityResults) {
178+
try {
179+
createRelationship(similarityResult, statement);
180+
} catch (KernelException e) {
181+
ExceptionUtil.throwKernelException(e);
182+
}
183+
}
184+
return null;
185+
});
186+
187+
}
188+
189+
private void createRelationship(SimilarityResult similarityResult, KernelTransaction statement) throws EntityNotFoundException, InvalidTransactionTypeKernelException, AutoIndexingKernelException {
190+
long node1 = similarityResult.item1;
191+
long node2 = similarityResult.item2;
192+
long relationshipId = statement.dataWrite().relationshipCreate(node1, relationshipTypeId, node2);
193+
194+
statement.dataWrite().relationshipSetProperty(
195+
relationshipId, propertyId, Values.doubleValue(similarityResult.similarity));
196+
}
197+
198+
private int getOrCreateRelationshipId(String relationshipType) {
199+
return applyInTransaction(stmt -> stmt
200+
.tokenWrite()
201+
.relationshipTypeGetOrCreateForName(relationshipType));
202+
}
203+
204+
private int getOrCreatePropertyId(String propertyName) {
205+
return applyInTransaction(stmt -> stmt
206+
.tokenWrite()
207+
.propertyKeyGetOrCreateForName(propertyName));
208+
}
209+
210+
private int writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
211+
int[] counter = {0};
212+
if (batchSize == 1) {
213+
similarityPairs.forEach(similarityResult -> {
214+
export(similarityResult);
215+
counter[0]++;
216+
});
217+
} else {
218+
Iterator<SimilarityResult> iterator = similarityPairs.iterator();
219+
do {
220+
List<SimilarityResult> batch = take(iterator, Math.toIntExact(batchSize));
221+
export(batch);
222+
if(batch.size() > 0) {
223+
counter[0]++;
224+
}
225+
} while (iterator.hasNext());
226+
}
227+
228+
return counter[0];
229+
}
230+
231+
private static List<SimilarityResult> take(Iterator<SimilarityResult> iterator, int batchSize) {
232+
List<SimilarityResult> result = new ArrayList<>(batchSize);
233+
while (iterator.hasNext() && batchSize-- > 0) {
234+
result.add(iterator.next());
235+
}
236+
return result;
237+
}
238+
239+
240+
}

algo/src/main/java/org/neo4j/graphalgo/similarity/PearsonProc.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ public Stream<SimilaritySummaryResult> pearson(
8080
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);
8181

8282
boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
83-
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
83+
boolean writeParallel = configuration.get("writeParallel", false);
84+
85+
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
8486
}
8587

8688
private SimilarityComputer<WeightedInput> similarityComputer(Double skipValue) {

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityExporter.java

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException;
2727
import org.neo4j.kernel.api.KernelTransaction;
2828
import org.neo4j.kernel.internal.GraphDatabaseAPI;
29+
import org.neo4j.logging.Log;
2930
import org.neo4j.values.storable.Values;
3031

3132
import java.util.ArrayList;
@@ -35,19 +36,22 @@
3536

3637
public class SimilarityExporter extends StatementApi {
3738

39+
private final Log log;
3840
private final int propertyId;
3941
private final int relationshipTypeId;
4042

4143
public SimilarityExporter(GraphDatabaseAPI api,
42-
String relationshipType,
44+
Log log, String relationshipType,
4345
String propertyName) {
4446
super(api);
47+
this.log = log;
4548
propertyId = getOrCreatePropertyId(propertyName);
4649
relationshipTypeId = getOrCreateRelationshipId(relationshipType);
4750
}
4851

4952
public void export(Stream<SimilarityResult> similarityPairs, long batchSize) {
50-
writeSequential(similarityPairs, batchSize);
53+
int batches = writeSequential(similarityPairs, batchSize);
54+
log.info("ParallelSimilarityExporter: Batch Size: %d, Batches written - sequentially: %d", batchSize, batches);
5155
}
5256

5357
private void export(SimilarityResult similarityResult) {
@@ -97,17 +101,29 @@ private int getOrCreatePropertyId(String propertyName) {
97101
.propertyKeyGetOrCreateForName(propertyName));
98102
}
99103

100-
private void writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
104+
private int writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
105+
log.info("SimilarityExporter: Writing relationships...");
106+
int[] counter = {0};
101107
if (batchSize == 1) {
102-
similarityPairs.forEach(this::export);
108+
similarityPairs.forEach(similarityResult -> {
109+
export(similarityResult);
110+
counter[0]++;
111+
});
103112
} else {
104113
Iterator<SimilarityResult> iterator = similarityPairs.iterator();
105114
do {
106-
export(take(iterator, Math.toIntExact(batchSize)));
115+
List<SimilarityResult> batch = take(iterator, Math.toIntExact(batchSize));
116+
export(batch);
117+
if(batch.size() > 0) {
118+
counter[0]++;
119+
}
107120
} while (iterator.hasNext());
108121
}
122+
123+
return counter[0];
109124
}
110125

126+
111127
private static List<SimilarityResult> take(Iterator<SimilarityResult> iterator, int batchSize) {
112128
List<SimilarityResult> result = new ArrayList<>(batchSize);
113129
while (iterator.hasNext() && batchSize-- > 0) {

0 commit comments

Comments
 (0)