Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 4390229

Browse files
committed
Fixing bug with pearson average computation when skipping values
1 parent 199dc3a commit 4390229

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,19 @@ public void pearsonNaNRespectsSimilarityCutOff() {
8686
assertNull(input1.pearson(null, 0.1, input2));
8787
}
8888

89+
@Test
90+
public void pearsonWithNonOverlappingValues() {
91+
double[] weights1 = new double[]{1, 2,3, Double.NaN, 4}; // ave = 10/4 = 2.5
92+
double[] weights2 = new double[]{Double.NaN, 2,3, 1, 4}; // ave = 10/4 = 2.5
93+
94+
WeightedInput input1 = new WeightedInput(1, weights1);
95+
WeightedInput input2 = new WeightedInput(2, weights2);
96+
97+
SimilarityResult similarityResult = input1.pearsonSkip(null, -1.0, input2, Double.NaN);
98+
99+
assertEquals(1.0, similarityResult.similarity, 0.01);
100+
}
101+
89102
@Test
90103
public void cosineNoCompression() {
91104
double[] weights1 = new double[]{1, 2, 3, 4, 4, 4, 4, 5, 6};

core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ public static double sumSquareDeltaSkip(double[] vector1, double[] vector2, int
109109
double result = 0;
110110
for (int i = 0; i < len; i++) {
111111
double weight1 = vector1[i];
112-
if (weight1 == skipValue || (skipNan && Double.isNaN(weight1))) continue;
112+
if (shouldSkip(weight1, skipValue, skipNan)) continue;
113113

114114
double weight2 = vector2[i];
115-
if (weight2 == skipValue || (skipNan && Double.isNaN(weight2))) continue;
115+
if (shouldSkip(weight2, skipValue, skipNan)) continue;
116116

117117
double delta = weight1 - weight2;
118118
result += delta * delta;
@@ -164,9 +164,9 @@ public static double cosineSquareSkip(double[] vector1, double[] vector2, int le
164164
double yLength = 0d;
165165
for (int i = 0; i < len; i++) {
166166
double weight1 = vector1[i];
167-
if (weight1 == skipValue || (skipNan && Double.isNaN(weight1))) continue;
167+
if (shouldSkip(weight1, skipValue, skipNan)) continue;
168168
double weight2 = vector2[i];
169-
if (weight2 == skipValue || (skipNan && Double.isNaN(weight2))) continue;
169+
if (shouldSkip(weight2, skipValue, skipNan)) continue;
170170

171171
dotProduct += weight1 * weight2;
172172
xLength += weight1 * weight1;
@@ -207,32 +207,36 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
207207
boolean skipNan = Double.isNaN(skipValue);
208208

209209
double vector1Sum = 0.0;
210+
int vector1Count = 0;
210211
double vector2Sum = 0.0;
211-
int count =0;
212+
int vector2Count = 0;
212213
for (int i = 0; i < len; i++) {
213214
double weight1 = vector1[i];
214215
double weight2 = vector2[i];
215216

216-
if (weight1 == skipValue || (skipNan && Double.isNaN(weight1))) continue;
217-
if (weight2 == skipValue || (skipNan && Double.isNaN(weight2))) continue;
217+
if(!shouldSkip(weight1, skipValue, skipNan)) {
218+
vector1Sum += weight1;
219+
vector1Count++;
220+
}
218221

219-
vector1Sum += weight1;
220-
vector2Sum += weight2;
221-
count++;
222+
if (!shouldSkip(weight2, skipValue, skipNan)) {
223+
vector2Sum += weight2;
224+
vector2Count++;
225+
}
222226
}
223227

224-
double vector1Mean = vector1Sum / count;
225-
double vector2Mean = vector2Sum / count;
228+
double vector1Mean = vector1Sum / vector1Count;
229+
double vector2Mean = vector2Sum / vector2Count;
226230

227231
double dotProductMinusMean = 0d;
228232
double xLength = 0d;
229233
double yLength = 0d;
230234
for (int i = 0; i < len; i++) {
231235
double weight1 = vector1[i];
232-
if (weight1 == skipValue || (skipNan && Double.isNaN(weight1))) continue;
236+
if (shouldSkip(weight1, skipValue, skipNan)) continue;
233237

234238
double weight2 = vector2[i];
235-
if (weight2 == skipValue || (skipNan && Double.isNaN(weight2))) continue;
239+
if (shouldSkip(weight2, skipValue, skipNan)) continue;
236240

237241
double vector1Delta = weight1 - vector1Mean;
238242
double vector2Delta = weight2 - vector2Mean;
@@ -245,6 +249,10 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
245249
return dotProductMinusMean / Math.sqrt(xLength * yLength);
246250
}
247251

252+
private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {
253+
return weight == skipValue || (skipNan && Double.isNaN(weight));
254+
}
255+
248256
public static double cosine(double[] vector1, double[] vector2, int len) {
249257
double dotProduct = 0d;
250258
double xLength = 0d;

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/PearsonSimilarityTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,4 @@ public void oppositeVectors() {
4848
double similarity = similarities.pearsonSimilarity(user1Ratings, user2Ratings);
4949
assertEquals(-1.0, similarity, 0.01);
5050
}
51-
5251
}

0 commit comments

Comments
 (0)