Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 642ec6b

Browse files
committed
Similarity refactoring (#803)
* Delegate the Pearson function to the array-based computation
* Delegate the cosine function to the array-based computation
* Delegate the Euclidean function to the array-based computation
* Add a similarity vector aggregation function so users no longer have to write the boring collect call themselves
* Add more tests and push the NaN logic into Intersections
* Nicer switch config name
* Better name for the aggregation function
* Use category instead of id
1 parent 238a19a commit 642ec6b

File tree

7 files changed

+286
-41
lines changed

7 files changed

+286
-41
lines changed

algo/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@
101101
<version>${neo4j.version}</version>
102102
<scope>provided</scope>
103103
</dependency>
104+
105+
<dependency>
106+
<groupId>org.mockito</groupId>
107+
<artifactId>mockito-core</artifactId>
108+
<version>2.23.4</version>
109+
<scope>test</scope>
110+
</dependency>
111+
104112
</dependencies>
105113

106114
<build>

algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,23 @@
1818
*/
1919
package org.neo4j.graphalgo.similarity;
2020

21+
import com.carrotsearch.hppc.LongDoubleHashMap;
22+
import com.carrotsearch.hppc.LongDoubleMap;
23+
import com.carrotsearch.hppc.LongHashSet;
24+
import com.carrotsearch.hppc.LongSet;
25+
import org.neo4j.graphalgo.core.ProcedureConfiguration;
2126
import org.neo4j.graphalgo.core.utils.Intersections;
2227
import org.neo4j.procedure.Description;
2328
import org.neo4j.procedure.Name;
29+
import org.neo4j.procedure.UserAggregationFunction;
2430
import org.neo4j.procedure.UserFunction;
2531

2632
import java.util.HashSet;
2733
import java.util.List;
34+
import java.util.Map;
35+
36+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.CATEGORY_KEY;
37+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.WEIGHT_KEY;
2838

2939
public class Similarities {
3040

@@ -50,43 +60,82 @@ public double cosineSimilarity(@Name("vector1") List<Number> vector1, @Name("vec
5060
throw new RuntimeException("Vectors must be non-empty and of the same size");
5161
}
5262

53-
double dotProduct = 0d;
54-
double xLength = 0d;
55-
double yLength = 0d;
56-
for (int i = 0; i < vector1.size(); i++) {
57-
double weight1 = vector1.get(i).doubleValue();
58-
double weight2 = vector2.get(i).doubleValue();
63+
int len = Math.min(vector1.size(), vector2.size());
64+
double[] weights1 = new double[len];
65+
double[] weights2 = new double[len];
5966

60-
dotProduct += weight1 * weight2;
61-
xLength += weight1 * weight1;
62-
yLength += weight2 * weight2;
67+
for (int i = 0; i < len; i++) {
68+
weights1[i] = vector1.get(i).doubleValue();
69+
weights2[i] = vector2.get(i).doubleValue();
6370
}
6471

65-
xLength = Math.sqrt(xLength);
66-
yLength = Math.sqrt(yLength);
72+
return Math.sqrt(Intersections.cosineSquare(weights1, weights2, len));
73+
}
6774

68-
return dotProduct / (xLength * yLength);
75+
@UserAggregationFunction("algo.similarity.asVector")
76+
@Description("algo.similarity.asVector - builds a vector of maps containing items and weights")
77+
public SimilarityVectorAggregator asVector() {
78+
return new SimilarityVectorAggregator();
6979
}
7080

7181
@UserFunction("algo.similarity.pearson")
7282
@Description("algo.similarity.pearson([vector1], [vector2]) " +
7383
"given two collection vectors, calculate pearson similarity")
74-
public double pearsonSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
75-
if (vector1.size() != vector2.size() || vector1.size() == 0) {
76-
throw new RuntimeException("Vectors must be non-empty and of the same size");
77-
}
78-
79-
int len = Math.min(vector1.size(), vector2.size());
80-
81-
double[] weights1 = new double[len];
82-
double[] weights2 = new double[len];
83-
84-
for (int i = 0; i < len; i++) {
85-
weights1[i] = vector1.get(i).doubleValue();
86-
weights2[i] = vector2.get(i).doubleValue();
84+
public double pearsonSimilarity(@Name("vector1") Object rawVector1, @Name("vector2") Object rawVector2, @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
85+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
86+
87+
String listType = configuration.get("vectorType", "numbers");
88+
89+
if (listType.equalsIgnoreCase("maps")) {
90+
List<Map<String, Object>> vector1 = (List<Map<String, Object>>) rawVector1;
91+
List<Map<String, Object>> vector2 = (List<Map<String, Object>>) rawVector2;
92+
93+
LongSet ids = new LongHashSet();
94+
95+
LongDoubleMap v1Mappings = new LongDoubleHashMap();
96+
for (Map<String, Object> entry : vector1) {
97+
Long id = (Long) entry.get(CATEGORY_KEY);
98+
ids.add(id);
99+
v1Mappings.put(id, (Double) entry.get(WEIGHT_KEY));
100+
}
101+
102+
LongDoubleMap v2Mappings = new LongDoubleHashMap();
103+
for (Map<String, Object> entry : vector2) {
104+
Long id = (Long) entry.get(CATEGORY_KEY);
105+
ids.add(id);
106+
v2Mappings.put(id, (Double) entry.get(WEIGHT_KEY));
107+
}
108+
109+
double[] weights1 = new double[ids.size()];
110+
double[] weights2 = new double[ids.size()];
111+
112+
double skipValue = Double.NaN;
113+
int index = 0;
114+
for (long id : ids.toArray()) {
115+
weights1[index] = v1Mappings.getOrDefault(id, skipValue);
116+
weights2[index] = v2Mappings.getOrDefault(id, skipValue);
117+
index++;
118+
}
119+
120+
return Intersections.pearsonSkip(weights1, weights2, ids.size(), skipValue);
121+
} else {
122+
List<Number> vector1 = (List<Number>) rawVector1;
123+
List<Number> vector2 = (List<Number>) rawVector2;
124+
125+
if (vector1.size() != vector2.size() || vector1.size() == 0) {
126+
throw new RuntimeException("Vectors must be non-empty and of the same size");
127+
}
128+
129+
int len = vector1.size();
130+
double[] weights1 = new double[len];
131+
double[] weights2 = new double[len];
132+
133+
for (int i = 0; i < len; i++) {
134+
weights1[i] = vector1.get(i).doubleValue();
135+
weights2[i] = vector2.get(i).doubleValue();
136+
}
137+
return Intersections.pearson(weights1, weights2, len);
87138
}
88-
89-
return Intersections.pearson(weights1, weights2, len);
90139
}
91140

92141
@UserFunction("algo.similarity.euclideanDistance")
@@ -97,15 +146,16 @@ public double euclideanDistance(@Name("vector1") List<Number> vector1, @Name("ve
97146
throw new RuntimeException("Vectors must be non-empty and of the same size");
98147
}
99148

100-
double distance = 0.0;
101-
for (int i = 0; i < vector1.size(); i++) {
102-
double sqOfDiff = vector1.get(i).doubleValue() - vector2.get(i).doubleValue();
103-
sqOfDiff *= sqOfDiff;
104-
distance += sqOfDiff;
149+
int len = Math.min(vector1.size(), vector2.size());
150+
double[] weights1 = new double[len];
151+
double[] weights2 = new double[len];
152+
153+
for (int i = 0; i < len; i++) {
154+
weights1[i] = vector1.get(i).doubleValue();
155+
weights2[i] = vector2.get(i).doubleValue();
105156
}
106-
distance = Math.sqrt(distance);
107157

108-
return distance;
158+
return Math.sqrt(Intersections.sumSquareDelta(weights1, weights2, len));
109159
}
110160

111161
@UserFunction("algo.similarity.euclidean")
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.neo4j.graphalgo.similarity;
2+
3+
import org.neo4j.graphdb.Node;
4+
import org.neo4j.helpers.collection.MapUtil;
5+
import org.neo4j.procedure.Name;
6+
import org.neo4j.procedure.UserAggregationResult;
7+
import org.neo4j.procedure.UserAggregationUpdate;
8+
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
public class SimilarityVectorAggregator {
14+
private List<Map<String, Object>> vector = new ArrayList<>();
15+
public static String CATEGORY_KEY = "category";
16+
public static String WEIGHT_KEY = "weight";
17+
18+
@UserAggregationUpdate
19+
public void next(
20+
@Name("node") Node node, @Name("weight") double weight) {
21+
vector.add(MapUtil.map(CATEGORY_KEY, node.getId(), WEIGHT_KEY, weight));
22+
}
23+
24+
@UserAggregationResult
25+
public List<Map<String, Object>> result() {
26+
return vector;
27+
}
28+
}

algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ public SimilarityResult pearson(RleDecoder decoder, double similarityCutoff, Wei
127127

128128
int len = Math.min(thisWeights.length, otherWeights.length);
129129
double pearson = Intersections.pearson(thisWeights, otherWeights, len);
130-
pearson = Double.isNaN(pearson) ? 0 : pearson;
131130

132131
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
133132

@@ -145,7 +144,6 @@ public SimilarityResult pearsonSkip(RleDecoder decoder, double similarityCutoff,
145144

146145
int len = Math.min(thisWeights.length, otherWeights.length);
147146
double pearson = Intersections.pearsonSkip(thisWeights, otherWeights, len, skipValue);
148-
pearson = Double.isNaN(pearson) ? 0 : pearson;
149147

150148
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
151149

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package org.neo4j.graphalgo.similarity;
2+
3+
import org.junit.Test;
4+
import org.neo4j.graphdb.Node;
5+
import org.neo4j.helpers.collection.MapUtil;
6+
7+
import java.util.Arrays;
8+
import java.util.Collections;
9+
import java.util.List;
10+
import java.util.Map;
11+
12+
import static org.hamcrest.Matchers.is;
13+
import static org.junit.Assert.*;
14+
import static org.mockito.Mockito.mock;
15+
import static org.mockito.Mockito.when;
16+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.CATEGORY_KEY;
17+
import static org.neo4j.graphalgo.similarity.SimilarityVectorAggregator.WEIGHT_KEY;
18+
19+
public class SimilarityVectorAggregatorTest {
20+
21+
@Test
22+
public void singleItem() {
23+
SimilarityVectorAggregator aggregator = new SimilarityVectorAggregator();
24+
25+
Node node = mock(Node.class);
26+
when(node.getId()).thenReturn(1L);
27+
28+
aggregator.next(node, 3.0);
29+
30+
List<Map<String, Object>> expected = Collections.singletonList(
31+
MapUtil.map(CATEGORY_KEY, 1L, WEIGHT_KEY, 3.0)
32+
);
33+
34+
assertThat(aggregator.result(), is(expected));
35+
}
36+
37+
@Test
38+
public void multipleItems() {
39+
SimilarityVectorAggregator aggregator = new SimilarityVectorAggregator();
40+
41+
Node node = mock(Node.class);
42+
when(node.getId()).thenReturn(1L, 2L, 3L);
43+
44+
aggregator.next(node, 3.0);
45+
aggregator.next(node, 2.0);
46+
aggregator.next(node, 1.0);
47+
48+
List<Map<String, Object>> expected = Arrays.asList(
49+
MapUtil.map(CATEGORY_KEY, 1L, WEIGHT_KEY, 3.0),
50+
MapUtil.map(CATEGORY_KEY, 2L, WEIGHT_KEY, 2.0),
51+
MapUtil.map(CATEGORY_KEY, 3L, WEIGHT_KEY, 1.0)
52+
);
53+
54+
assertThat(aggregator.result(), is(expected));
55+
}
56+
57+
}

core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ public static double pearson(double[] vector1, double[] vector2, int len) {
200200
yLength += vector2Delta * vector2Delta;
201201
}
202202

203-
return dotProductMinusMean / Math.sqrt(xLength * yLength);
203+
double result = dotProductMinusMean / Math.sqrt(xLength * yLength);
204+
return Double.isNaN(result) ? 0 : result;
204205
}
205206

206207
public static double pearsonSkip(double[] vector1, double[] vector2, int len, double skipValue) {
@@ -246,7 +247,8 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
246247
yLength += vector2Delta * vector2Delta;
247248
}
248249

249-
return dotProductMinusMean / Math.sqrt(xLength * yLength);
250+
double result = dotProductMinusMean / Math.sqrt(xLength * yLength);
251+
return Double.isNaN(result) ? 0 : result;
250252
}
251253

252254
private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {

0 commit comments

Comments
 (0)