Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit bb2462d

Browse files
committed
Pearson NaN -> 0 was happening too late
1 parent c7b881a commit bb2462d

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public WeightedInput(long id, double[] weights) {
2525

2626
private static int calculateCount(double[] weights, double skipValue) {
2727
boolean skipNan = Double.isNaN(skipValue);
28-
int count =0;
28+
int count = 0;
2929
for (double weight : weights) {
3030
if (!(weight == skipValue || (skipNan && Double.isNaN(weight)))) count++;
3131
}
@@ -51,7 +51,7 @@ public int compareTo(WeightedInput o) {
5151
public SimilarityResult sumSquareDeltaSkip(RleDecoder decoder, double similarityCutoff, WeightedInput other, double skipValue) {
5252
double[] thisWeights = weights;
5353
double[] otherWeights = other.weights;
54-
if(decoder != null) {
54+
if (decoder != null) {
5555
decoder.reset(weights, other.weights);
5656
thisWeights = decoder.item1();
5757
otherWeights = decoder.item2();
@@ -68,7 +68,7 @@ public SimilarityResult sumSquareDeltaSkip(RleDecoder decoder, double similarity
6868
public SimilarityResult sumSquareDelta(RleDecoder decoder, double similarityCutoff, WeightedInput other) {
6969
double[] thisWeights = weights;
7070
double[] otherWeights = other.weights;
71-
if(decoder != null) {
71+
if (decoder != null) {
7272
decoder.reset(weights, other.weights);
7373
thisWeights = decoder.item1();
7474
otherWeights = decoder.item2();
@@ -85,7 +85,7 @@ public SimilarityResult sumSquareDelta(RleDecoder decoder, double similarityCuto
8585
public SimilarityResult cosineSquaresSkip(RleDecoder decoder, double similarityCutoff, WeightedInput other, double skipValue) {
8686
double[] thisWeights = weights;
8787
double[] otherWeights = other.weights;
88-
if(decoder != null) {
88+
if (decoder != null) {
8989
decoder.reset(weights, other.weights);
9090
thisWeights = decoder.item1();
9191
otherWeights = decoder.item2();
@@ -102,7 +102,7 @@ public SimilarityResult cosineSquaresSkip(RleDecoder decoder, double similarityC
102102
public SimilarityResult cosineSquares(RleDecoder decoder, double similarityCutoff, WeightedInput other) {
103103
double[] thisWeights = weights;
104104
double[] otherWeights = other.weights;
105-
if(decoder != null) {
105+
if (decoder != null) {
106106
decoder.reset(weights, other.weights);
107107
thisWeights = decoder.item1();
108108
otherWeights = decoder.item2();
@@ -127,14 +127,11 @@ public SimilarityResult pearson(RleDecoder decoder, double similarityCutoff, Wei
127127

128128
int len = Math.min(thisWeights.length, otherWeights.length);
129129
double pearson = Intersections.pearson(thisWeights, otherWeights, len);
130+
pearson = Double.isNaN(pearson) ? 0 : pearson;
130131

131132
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
132133

133-
if (Double.isNaN(pearson)) {
134-
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, 0);
135-
} else {
136-
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, pearson);
137-
}
134+
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, pearson);
138135
}
139136

140137
public SimilarityResult pearsonSkip(RleDecoder decoder, double similarityCutoff, WeightedInput other, Double skipValue) {
@@ -148,13 +145,10 @@ public SimilarityResult pearsonSkip(RleDecoder decoder, double similarityCutoff,
148145

149146
int len = Math.min(thisWeights.length, otherWeights.length);
150147
double pearson = Intersections.pearsonSkip(thisWeights, otherWeights, len, skipValue);
148+
pearson = Double.isNaN(pearson) ? 0 : pearson;
151149

152150
if (similarityCutoff >= 0d && (pearson == 0 || pearson < similarityCutoff)) return null;
153151

154-
if (Double.isNaN(pearson)) {
155-
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, 0);
156-
} else {
157-
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, pearson);
158-
}
152+
return new SimilarityResult(id, other.id, itemCount, other.itemCount, 0, pearson);
159153
}
160154
}

algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import org.junit.Test;
44

55
import static junit.framework.TestCase.assertEquals;
6+
import static junit.framework.TestCase.assertNull;
67

78
public class WeightedInputTest {
89

@@ -62,6 +63,30 @@ public void pearsonSkipCompression() {
6263
assertEquals(1.0, similarityResult.similarity, 0.01);
6364
}
6465

66+
@Test
67+
public void pearsonNaNReturns0() {
68+
double[] weights1 = new double[]{};
69+
double[] weights2 = new double[]{};
70+
71+
WeightedInput input1 = new WeightedInput(1, weights1);
72+
WeightedInput input2 = new WeightedInput(2, weights2);
73+
74+
assertEquals(0.0, input1.pearsonSkip(null, -1.0, input2, 0.0).similarity, 0.01);
75+
assertEquals(0.0, input1.pearson(null, -1.0, input2).similarity, 0.01);
76+
}
77+
78+
@Test
79+
public void pearsonNaNRespectsSimilarityCutOff() {
80+
double[] weights1 = new double[]{};
81+
double[] weights2 = new double[]{};
82+
83+
WeightedInput input1 = new WeightedInput(1, weights1);
84+
WeightedInput input2 = new WeightedInput(2, weights2);
85+
86+
assertNull(input1.pearsonSkip(null, 0.1, input2, 0.0));
87+
assertNull(input1.pearson(null, 0.1, input2));
88+
}
89+
6590
@Test
6691
public void cosineNoCompression() {
6792
double[] weights1 = new double[]{1, 2, 3, 4, 4, 4, 4, 5, 6};

0 commit comments

Comments
 (0)