Skip to content

Commit 1c050ce

Browse files
authored
Create int generic vector formats (#138172)
Add bfloat16 and on_disk_rescore support to `int4_*`, `int8_*`, and `flat` dense_vector index types. Currently all hidden behind a feature flag
1 parent 6b65726 commit 1c050ce

File tree

22 files changed

+4695
-98
lines changed

22 files changed

+4695
-98
lines changed

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized_bfloat16.yml

Lines changed: 896 additions & 0 deletions
Large diffs are not rendered by default.

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized_bfloat16.yml

Lines changed: 933 additions & 0 deletions
Large diffs are not rendered by default.

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat_bfloat16.yml

Lines changed: 432 additions & 0 deletions
Large diffs are not rendered by default.

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat_bfloat16.yml

Lines changed: 619 additions & 0 deletions
Large diffs are not rendered by default.

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat_bfloat16.yml

Lines changed: 555 additions & 0 deletions
Large diffs are not rendered by default.

server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.Map;
3939
import java.util.OptionalLong;
4040
import java.util.stream.IntStream;
41+
import java.util.stream.Stream;
4142

4243
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
4344
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
@@ -76,7 +77,7 @@ protected boolean useDirectIO(String name, IOContext context, OptionalLong fileL
7677

7778
@ParametersFactory
7879
public static Iterable<Object[]> parameters() {
79-
return List.of(new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" });
80+
return Stream.of("int4_hnsw", "int8_hnsw", "bbq_hnsw", "bbq_disk").map(s -> new Object[] { s }).toList();
8081
}
8182

8283
public DirectIOIT(String type) {

server/src/main/java/module-info.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,11 @@
465465
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
466466
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
467467
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
468-
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
468+
org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorFormat,
469469
org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat,
470+
org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat,
471+
org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat,
472+
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
470473
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
471474

472475
provides org.apache.lucene.codecs.Codec
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.KnnVectorsFormat;
13+
import org.apache.lucene.codecs.KnnVectorsReader;
14+
import org.apache.lucene.codecs.KnnVectorsWriter;
15+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
16+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
17+
import org.apache.lucene.index.ByteVectorValues;
18+
import org.apache.lucene.index.FieldInfo;
19+
import org.apache.lucene.index.FloatVectorValues;
20+
import org.apache.lucene.index.SegmentReadState;
21+
import org.apache.lucene.index.SegmentWriteState;
22+
import org.apache.lucene.search.AcceptDocs;
23+
import org.apache.lucene.search.KnnCollector;
24+
import org.apache.lucene.util.Bits;
25+
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
26+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
27+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
28+
29+
import java.io.IOException;
30+
import java.util.Map;
31+
32+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
33+
34+
public class ES93FlatVectorFormat extends KnnVectorsFormat {
35+
36+
static final String NAME = "ES93FlatVectorFormat";
37+
38+
private final FlatVectorsFormat format;
39+
40+
/**
41+
* Sole constructor
42+
*/
43+
public ES93FlatVectorFormat() {
44+
super(NAME);
45+
format = new ES93GenericFlatVectorsFormat();
46+
}
47+
48+
public ES93FlatVectorFormat(DenseVectorFieldMapper.ElementType elementType) {
49+
super(NAME);
50+
assert elementType != DenseVectorFieldMapper.ElementType.BIT : "ES815BitFlatVectorFormat should be used for bits";
51+
format = new ES93GenericFlatVectorsFormat(elementType, false);
52+
}
53+
54+
@Override
55+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
56+
return format.fieldsWriter(state);
57+
}
58+
59+
@Override
60+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
61+
return new ES93FlatVectorReader(format.fieldsReader(state));
62+
}
63+
64+
@Override
65+
public int getMaxDimensions(String fieldName) {
66+
return MAX_DIMS_COUNT;
67+
}
68+
69+
static class ES93FlatVectorReader extends KnnVectorsReader {
70+
71+
private final FlatVectorsReader reader;
72+
73+
ES93FlatVectorReader(FlatVectorsReader reader) {
74+
this.reader = reader;
75+
}
76+
77+
@Override
78+
public void checkIntegrity() throws IOException {
79+
reader.checkIntegrity();
80+
}
81+
82+
@Override
83+
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
84+
return reader.getFloatVectorValues(field);
85+
}
86+
87+
@Override
88+
public ByteVectorValues getByteVectorValues(String field) throws IOException {
89+
return reader.getByteVectorValues(field);
90+
}
91+
92+
@Override
93+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
94+
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
95+
}
96+
97+
private void collectAllMatchingDocs(KnnCollector knnCollector, AcceptDocs acceptDocs, RandomVectorScorer scorer)
98+
throws IOException {
99+
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
100+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
101+
for (int i = 0; i < scorer.maxOrd(); i++) {
102+
if (acceptedOrds == null || acceptedOrds.get(i)) {
103+
collector.collect(i, scorer.score(i));
104+
collector.incVisitedCount(1);
105+
}
106+
}
107+
assert collector.earlyTerminated() == false;
108+
}
109+
110+
@Override
111+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
112+
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
113+
}
114+
115+
@Override
116+
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
117+
return reader.getOffHeapByteSize(fieldInfo);
118+
}
119+
120+
@Override
121+
public void close() throws IOException {
122+
reader.close();
123+
}
124+
}
125+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.KnnVectorsReader;
13+
import org.apache.lucene.codecs.KnnVectorsWriter;
14+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
15+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
16+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
17+
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
18+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
19+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
20+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader;
21+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
22+
import org.apache.lucene.index.SegmentReadState;
23+
import org.apache.lucene.index.SegmentWriteState;
24+
import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat;
25+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
26+
27+
import java.io.IOException;
28+
import java.util.concurrent.ExecutorService;
29+
30+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
31+
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
32+
33+
public class ES93HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat {
34+
35+
static final String NAME = "ES93HnswScalarQuantizedVectorsFormat";
36+
private static final int ALLOWED_BITS = (1 << 7) | (1 << 4);
37+
38+
/** The minimum confidence interval */
39+
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
40+
41+
/** The maximum confidence interval */
42+
private static final float MAXIMUM_CONFIDENCE_INTERVAL = 1f;
43+
44+
static final FlatVectorsScorer flatVectorScorer = new ES93ScalarQuantizedVectorsFormat.ESQuantizedFlatVectorsScorer(
45+
new ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
46+
);
47+
48+
private final FlatVectorsFormat rawVectorFormat;
49+
50+
/**
51+
* Controls the confidence interval used to scalar quantize the vectors the default value is
52+
* calculated as `1-1/(vector_dimensions + 1)`
53+
*/
54+
public final Float confidenceInterval;
55+
56+
private final byte bits;
57+
private final boolean compress;
58+
59+
public ES93HnswScalarQuantizedVectorsFormat() {
60+
super(NAME);
61+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false);
62+
this.confidenceInterval = null;
63+
this.bits = 7;
64+
this.compress = false;
65+
}
66+
67+
public ES93HnswScalarQuantizedVectorsFormat(
68+
int maxConn,
69+
int beamWidth,
70+
DenseVectorFieldMapper.ElementType elementType,
71+
Float confidenceInterval,
72+
int bits,
73+
boolean compress,
74+
boolean useDirectIO
75+
) {
76+
this(maxConn, beamWidth, elementType, confidenceInterval, bits, compress, useDirectIO, DEFAULT_NUM_MERGE_WORKER, null);
77+
}
78+
79+
public ES93HnswScalarQuantizedVectorsFormat(
80+
int maxConn,
81+
int beamWidth,
82+
DenseVectorFieldMapper.ElementType elementType,
83+
Float confidenceInterval,
84+
int bits,
85+
boolean compress,
86+
boolean useDirectIO,
87+
int numMergeWorkers,
88+
ExecutorService mergeExec
89+
) {
90+
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
91+
92+
if (confidenceInterval != null
93+
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
94+
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL || confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) {
95+
throw new IllegalArgumentException(
96+
"confidenceInterval must be between "
97+
+ MINIMUM_CONFIDENCE_INTERVAL
98+
+ " and "
99+
+ MAXIMUM_CONFIDENCE_INTERVAL
100+
+ "; confidenceInterval="
101+
+ confidenceInterval
102+
);
103+
}
104+
if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) {
105+
throw new IllegalArgumentException("bits must be one of: 4, 7; bits=" + bits);
106+
}
107+
assert elementType != DenseVectorFieldMapper.ElementType.BIT : "BIT should not be used with scalar quantization";
108+
109+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
110+
this.confidenceInterval = confidenceInterval;
111+
this.bits = (byte) bits;
112+
this.compress = compress;
113+
}
114+
115+
@Override
116+
protected FlatVectorsFormat flatVectorsFormat() {
117+
return rawVectorFormat;
118+
}
119+
120+
@Override
121+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
122+
return new Lucene99HnswVectorsWriter(
123+
state,
124+
maxConn,
125+
beamWidth,
126+
new Lucene99ScalarQuantizedVectorsWriter(
127+
state,
128+
confidenceInterval,
129+
bits,
130+
compress,
131+
rawVectorFormat.fieldsWriter(state),
132+
flatVectorScorer
133+
),
134+
numMergeWorkers,
135+
mergeExec
136+
);
137+
}
138+
139+
@Override
140+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
141+
return new Lucene99HnswVectorsReader(
142+
state,
143+
new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), flatVectorScorer)
144+
);
145+
}
146+
147+
@Override
148+
public String toString() {
149+
return NAME
150+
+ "(name="
151+
+ NAME
152+
+ ", maxConn="
153+
+ maxConn
154+
+ ", beamWidth="
155+
+ beamWidth
156+
+ ", confidenceInterval="
157+
+ confidenceInterval
158+
+ ", bits="
159+
+ bits
160+
+ ", compressed="
161+
+ compress
162+
+ ", flatVectorScorer="
163+
+ flatVectorScorer
164+
+ ", flatVectorFormat="
165+
+ rawVectorFormat
166+
+ ")";
167+
}
168+
}

0 commit comments

Comments
 (0)