Skip to content

Commit 2c08720

Browse files
authored
Adding more bit support to jmh osq benchmark (#138049)
This adds more bits to our OSQ scorer jmh benchmark.
1 parent 71579e0 commit 2c08720

File tree

1 file changed

+44
-10
lines changed

1 file changed

+44
-10
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
1919
import org.elasticsearch.common.logging.LogConfigurator;
2020
import org.elasticsearch.core.IOUtils;
21+
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
2122
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
23+
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
2224
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
2325
import org.openjdk.jmh.annotations.Benchmark;
2426
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -57,6 +59,9 @@ public class OSQScorerBenchmark {
5759
@Param({ "384", "782", "1024" })
5860
int dims;
5961

62+
@Param({ "1", "2", "4" })
63+
int bits;
64+
6065
int length;
6166

6267
int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
@@ -68,8 +73,8 @@ public class OSQScorerBenchmark {
6873
float centroidDp;
6974

7075
byte[] scratch;
71-
ES91OSQVectorsScorer scorerMmap;
72-
ES91OSQVectorsScorer scorerNfios;
76+
ESNextOSQVectorsScorer scorerMmap;
77+
ESNextOSQVectorsScorer scorerNfios;
7378

7479
Directory dirMmap;
7580
IndexInput inMmap;
@@ -84,7 +89,12 @@ public class OSQScorerBenchmark {
8489
public void setup() throws IOException {
8590
Random random = new Random(123);
8691

87-
this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
92+
this.length = switch (bits) {
93+
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getDocPackedLength(dims);
94+
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getDocPackedLength(dims);
95+
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getDocPackedLength(dims);
96+
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
97+
};
8898

8999
binaryVectors = new byte[numVectors][length];
90100
for (byte[] binaryVector : binaryVectors) {
@@ -109,8 +119,14 @@ public void setup() throws IOException {
109119
outNfios.close();
110120
inMmap = dirMmap.openInput("vectors", IOContext.DEFAULT);
111121
inNiofs = dirNiofs.openInput("vectors", IOContext.DEFAULT);
112-
113-
binaryQueries = new byte[numVectors][4 * length];
122+
int binaryQueryLength = switch (bits) {
123+
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getQueryPackedLength(dims);
124+
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getQueryPackedLength(dims);
125+
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getQueryPackedLength(dims);
126+
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
127+
};
128+
129+
binaryQueries = new byte[numVectors][binaryQueryLength];
114130
for (byte[] binaryVector : binaryVectors) {
115131
random.nextBytes(binaryVector);
116132
}
@@ -123,8 +139,26 @@ public void setup() throws IOException {
123139
centroidDp = random.nextFloat();
124140

125141
scratch = new byte[length];
126-
scorerMmap = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inMmap, dims);
127-
scorerNfios = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inNiofs, dims);
142+
final int docBits;
143+
final int queryBits = switch (bits) {
144+
case 1 -> {
145+
docBits = 1;
146+
yield 4;
147+
}
148+
case 2 -> {
149+
docBits = 2;
150+
yield 4;
151+
}
152+
case 4 -> {
153+
docBits = 4;
154+
yield 4;
155+
}
156+
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
157+
};
158+
scorerMmap = ESVectorizationProvider.getInstance()
159+
.newESNextOSQVectorsScorer(inMmap, (byte) queryBits, (byte) docBits, dims, length);
160+
scorerNfios = ESVectorizationProvider.getInstance()
161+
.newESNextOSQVectorsScorer(inNiofs, (byte) queryBits, (byte) docBits, dims, length);
128162
scratchScores = new float[16];
129163
corrections = new float[3];
130164
}
@@ -156,7 +190,7 @@ public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOExc
156190
scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
157191
}
158192

159-
private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
193+
private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ESNextOSQVectorsScorer scorer) throws IOException {
160194
for (int j = 0; j < numQueries; j++) {
161195
in.seek(0);
162196
for (int i = 0; i < numVectors; i++) {
@@ -203,7 +237,7 @@ public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws I
203237
scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
204238
}
205239

206-
private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
240+
private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, ESNextOSQVectorsScorer scorer) throws IOException {
207241
for (int j = 0; j < numQueries; j++) {
208242
in.seek(0);
209243
for (int i = 0; i < numVectors; i += 16) {
@@ -252,7 +286,7 @@ public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOExcept
252286
scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
253287
}
254288

255-
private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
289+
private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ESNextOSQVectorsScorer scorer) throws IOException {
256290
for (int j = 0; j < numQueries; j++) {
257291
in.seek(0);
258292
for (int i = 0; i < numVectors; i += 16) {

0 commit comments

Comments
 (0)