Skip to content

Commit 6e34147

Browse files
Fixing sorted indices for GPU built indices (#138138)
- For flush, vectors are now reordered according to sortMap before building the GPU index, ensuring that HNSW graph node ordinals match the sorted document order. - Merge, on the other hand, doesn't require explicit sortMap handling since Lucene's MergedVectorValues utilities apply docMaps internally. - Enhanced tests with both approximate and exact KNN searches to validate sorting correctness.
1 parent ea040c3 commit 6e34147

File tree

3 files changed

+138
-43
lines changed

3 files changed

+138
-43
lines changed

docs/changelog/138138.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138138
2+
summary: Fixing sorted indices for GPU built indices
3+
area: Vector Search
4+
type: bug
5+
issues: []

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
import org.elasticsearch.common.settings.Settings;
1414
import org.elasticsearch.plugins.Plugin;
1515
import org.elasticsearch.search.SearchHit;
16+
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
1617
import org.elasticsearch.search.vectors.KnnSearchBuilder;
18+
import org.elasticsearch.search.vectors.VectorData;
1719
import org.elasticsearch.test.ESIntegTestCase;
1820
import org.elasticsearch.xpack.gpu.GPUPlugin;
1921
import org.elasticsearch.xpack.gpu.GPUSupport;
2022
import org.junit.Assert;
2123
import org.junit.BeforeClass;
2224

2325
import java.util.Collection;
26+
import java.util.HashSet;
2427
import java.util.List;
2528
import java.util.Locale;
29+
import java.util.Set;
2630

2731
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
2832
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
@@ -56,20 +60,19 @@ public void testBasic() {
5660
assertSearch(indexName, randomFloatVector(dims), totalDocs);
5761
}
5862

59-
@AwaitsFix(bugUrl = "Fix sorted index")
6063
public void testSortedIndexReturnsSameResultsAsUnsorted() {
6164
String indexName1 = "index_unsorted";
6265
String indexName2 = "index_sorted";
6366
final int dims = randomIntBetween(4, 128);
6467
createIndex(indexName1, dims, false);
6568
createIndex(indexName2, dims, true);
6669

67-
final int[] numDocs = new int[] { randomIntBetween(50, 100), randomIntBetween(50, 100) };
70+
final int[] numDocs = new int[] { randomIntBetween(300, 999), randomIntBetween(300, 999) };
6871
for (int i = 0; i < numDocs.length; i++) {
6972
BulkRequestBuilder bulkRequest1 = client().prepareBulk();
7073
BulkRequestBuilder bulkRequest2 = client().prepareBulk();
7174
for (int j = 0; j < numDocs[i]; j++) {
72-
String id = String.valueOf(i * 100 + j);
75+
String id = String.valueOf(i * 1000 + j);
7376
String keywordValue = String.valueOf(numDocs[i] - j);
7477
float[] vector = randomFloatVector(dims);
7578
bulkRequest1.add(prepareIndex(indexName1).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue));
@@ -84,8 +87,9 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
8487

8588
float[] queryVector = randomFloatVector(dims);
8689
int k = 10;
87-
int numCandidates = k * 10;
90+
int numCandidates = k * 5;
8891

92+
// Test 1: Approximate KNN search - expect at least k-3 out of k matches
8993
var searchResponse1 = prepareSearch(indexName1).setSize(k)
9094
.setFetchSource(false)
9195
.addFetchField("my_keyword")
@@ -101,22 +105,40 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
101105
try {
102106
SearchHit[] hits1 = searchResponse1.getHits().getHits();
103107
SearchHit[] hits2 = searchResponse2.getHits().getHits();
104-
Assert.assertEquals(hits1.length, hits2.length);
105-
for (int i = 0; i < hits1.length; i++) {
106-
Assert.assertEquals(hits1[i].getId(), hits2[i].getId());
107-
Assert.assertEquals(hits1[i].field("my_keyword").getValue(), (String) hits2[i].field("my_keyword").getValue());
108-
Assert.assertEquals(hits1[i].getScore(), hits2[i].getScore(), 0.001f);
109-
}
108+
assertAtLeastNOutOfKMatches(hits1, hits2, k - 3, k);
110109
} finally {
111110
searchResponse1.decRef();
112111
searchResponse2.decRef();
113112
}
114113

114+
// Test 2: Exact KNN search (brute-force) - expect perfect k out of k matches
115+
var exactSearchResponse1 = prepareSearch(indexName1).setSize(k)
116+
.setFetchSource(false)
117+
.addFetchField("my_keyword")
118+
.setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
119+
.get();
120+
121+
var exactSearchResponse2 = prepareSearch(indexName2).setSize(k)
122+
.setFetchSource(false)
123+
.addFetchField("my_keyword")
124+
.setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
125+
.get();
126+
127+
try {
128+
SearchHit[] exactHits1 = exactSearchResponse1.getHits().getHits();
129+
SearchHit[] exactHits2 = exactSearchResponse2.getHits().getHits();
130+
assertExactMatches(exactHits1, exactHits2, k);
131+
} finally {
132+
exactSearchResponse1.decRef();
133+
exactSearchResponse2.decRef();
134+
}
135+
115136
// Force merge and search again
116137
assertNoFailures(indicesAdmin().prepareForceMerge(indexName1).get());
117138
assertNoFailures(indicesAdmin().prepareForceMerge(indexName2).get());
118139
ensureGreen();
119140

141+
// Test 3: Approximate KNN search - expect at least k-3 out of k matches
120142
var searchResponse3 = prepareSearch(indexName1).setSize(k)
121143
.setFetchSource(false)
122144
.addFetchField("my_keyword")
@@ -132,16 +154,33 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
132154
try {
133155
SearchHit[] hits3 = searchResponse3.getHits().getHits();
134156
SearchHit[] hits4 = searchResponse4.getHits().getHits();
135-
Assert.assertEquals(hits3.length, hits4.length);
136-
for (int i = 0; i < hits3.length; i++) {
137-
Assert.assertEquals(hits3[i].getId(), hits4[i].getId());
138-
Assert.assertEquals(hits3[i].field("my_keyword").getValue(), (String) hits4[i].field("my_keyword").getValue());
139-
Assert.assertEquals(hits3[i].getScore(), hits4[i].getScore(), 0.01f);
140-
}
157+
assertAtLeastNOutOfKMatches(hits3, hits4, k - 3, k);
141158
} finally {
142159
searchResponse3.decRef();
143160
searchResponse4.decRef();
144161
}
162+
163+
// Test 4: Exact KNN search after merge - expect perfect k out of k matches
164+
var exactSearchResponse3 = prepareSearch(indexName1).setSize(k)
165+
.setFetchSource(false)
166+
.addFetchField("my_keyword")
167+
.setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
168+
.get();
169+
170+
var exactSearchResponse4 = prepareSearch(indexName2).setSize(k)
171+
.setFetchSource(false)
172+
.addFetchField("my_keyword")
173+
.setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
174+
.get();
175+
176+
try {
177+
SearchHit[] exactHits3 = exactSearchResponse3.getHits().getHits();
178+
SearchHit[] exactHits4 = exactSearchResponse4.getHits().getHits();
179+
assertExactMatches(exactHits3, exactHits4, k);
180+
} finally {
181+
exactSearchResponse3.decRef();
182+
exactSearchResponse4.decRef();
183+
}
145184
}
146185

147186
public void testSearchWithoutGPU() {
@@ -261,4 +300,56 @@ private static float[] randomFloatVector(int dims) {
261300
}
262301
return vector;
263302
}
303+
304+
/**
305+
* Asserts that at least N out of K hits have matching IDs between two result sets.
306+
*/
307+
private static void assertAtLeastNOutOfKMatches(SearchHit[] hits1, SearchHit[] hits2, int minMatches, int k) {
308+
Assert.assertEquals("Both result sets should have k hits", k, hits1.length);
309+
Assert.assertEquals("Both result sets should have k hits", k, hits2.length);
310+
Set<String> ids1 = new HashSet<>();
311+
Set<String> ids2 = new HashSet<>();
312+
313+
for (SearchHit hit : hits1) {
314+
ids1.add(hit.getId());
315+
}
316+
for (SearchHit hit : hits2) {
317+
ids2.add(hit.getId());
318+
}
319+
320+
Set<String> intersection = new HashSet<>(ids1);
321+
intersection.retainAll(ids2);
322+
Assert.assertTrue(
323+
String.format(
324+
Locale.ROOT,
325+
"Expected at least %d matching IDs out of %d, but found %d. IDs1: %s, IDs2: %s",
326+
minMatches,
327+
k,
328+
intersection.size(),
329+
ids1,
330+
ids2
331+
),
332+
intersection.size() >= minMatches
333+
);
334+
}
335+
336+
/**
337+
* Asserts that two result sets have exactly the same document IDs in the same order with the same scores.
338+
* Used for exact (brute-force) KNN search which should be deterministic.
339+
* Expects k out of k matches.
340+
*/
341+
private static void assertExactMatches(SearchHit[] hits1, SearchHit[] hits2, int k) {
342+
Assert.assertEquals("Both result sets should have k hits", k, hits1.length);
343+
Assert.assertEquals("Both result sets should have k hits", k, hits2.length);
344+
345+
for (int i = 0; i < k; i++) {
346+
Assert.assertEquals(String.format(Locale.ROOT, "Document ID mismatch at position %d", i), hits1[i].getId(), hits2[i].getId());
347+
Assert.assertEquals(
348+
String.format(Locale.ROOT, "Score mismatch for document ID %s at position %d", hits1[i].getId(), i),
349+
hits1[i].getScore(),
350+
hits2[i].getScore(),
351+
0.0001f
352+
);
353+
}
354+
}
264355
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1919
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
2020
import org.apache.lucene.index.ByteVectorValues;
21+
import org.apache.lucene.index.DocsWithFieldSet;
2122
import org.apache.lucene.index.FieldInfo;
2223
import org.apache.lucene.index.FloatVectorValues;
2324
import org.apache.lucene.index.IndexFileNames;
@@ -163,7 +164,6 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
163164
* </p>
164165
*/
165166
@Override
166-
// TODO: fix sorted index case
167167
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
168168
var started = System.nanoTime();
169169
flatVectorWriter.flush(maxDoc, sortMap);
@@ -182,7 +182,11 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
182182
var started = System.nanoTime();
183183
var fieldInfo = field.fieldInfo;
184184

185-
var numVectors = field.flatFieldVectorsWriter.getVectors().size();
185+
var originalVectors = field.flatFieldVectorsWriter.getVectors();
186+
final List<float[]> vectorsInSortedOrder = sortMap == null
187+
? originalVectors
188+
: getVectorsInSortedOrder(field, sortMap, originalVectors);
189+
int numVectors = vectorsInSortedOrder.size();
186190
CagraIndexParams cagraIndexParams = createCagraIndexParams(
187191
fieldInfo.getVectorSimilarityFunction(),
188192
numVectors,
@@ -192,7 +196,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
192196
if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
193197
logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", numVectors, MIN_NUM_VECTORS_FOR_GPU_BUILD);
194198
// Will not be indexed on the GPU
195-
flushFieldWithMockGraph(fieldInfo, numVectors, sortMap);
199+
generateMockGraphAndWriteMeta(fieldInfo, numVectors);
196200
} else {
197201
try (
198202
var resourcesHolder = new ResourcesHolder(
@@ -206,11 +210,11 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
206210
fieldInfo.getVectorDimension(),
207211
CuVSMatrix.DataType.FLOAT
208212
);
209-
for (var vector : field.flatFieldVectorsWriter.getVectors()) {
213+
for (var vector : vectorsInSortedOrder) {
210214
builder.addVector(vector);
211215
}
212216
try (var dataset = builder.build()) {
213-
flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap, cagraIndexParams);
217+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
214218
}
215219
}
216220
}
@@ -219,28 +223,17 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
219223
}
220224
}
221225

222-
private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter.DocMap sortMap) throws IOException {
223-
if (sortMap == null) {
224-
generateMockGraphAndWriteMeta(fieldInfo, numVectors);
225-
} else {
226-
// TODO: use sortMap
227-
generateMockGraphAndWriteMeta(fieldInfo, numVectors);
228-
}
229-
}
230-
231-
private void flushFieldWithGpuGraph(
232-
ResourcesHolder resourcesHolder,
233-
FieldInfo fieldInfo,
234-
CuVSMatrix dataset,
235-
Sorter.DocMap sortMap,
236-
CagraIndexParams cagraIndexParams
237-
) throws IOException {
238-
if (sortMap == null) {
239-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
240-
} else {
241-
// TODO: use sortMap
242-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
226+
private List<float[]> getVectorsInSortedOrder(FieldWriter field, Sorter.DocMap sortMap, List<float[]> originalVectors)
227+
throws IOException {
228+
DocsWithFieldSet docsWithField = field.getDocsWithFieldSet();
229+
int[] ordMap = new int[docsWithField.cardinality()];
230+
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
231+
KnnVectorsWriter.mapOldOrdToNewOrd(docsWithField, sortMap, null, ordMap, newDocsWithField);
232+
List<float[]> vectorsInSortedOrder = new ArrayList<>(ordMap.length);
233+
for (int oldOrd : ordMap) {
234+
vectorsInSortedOrder.add(originalVectors.get(oldOrd));
243235
}
236+
return vectorsInSortedOrder;
244237
}
245238

246239
@Override
@@ -512,8 +505,10 @@ public NodesIterator getNodesOnLevel(int level) {
512505

513506
// TODO check with deleted documents
514507
@Override
515-
// fix sorted index case
516508
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
509+
// Note: Merged raw vectors are already in sorted order. The flatVectorWriter and MergedVectorValues utilities
510+
// apply mergeState.docMaps internally, so vectors are returned in the final sorted document order.
511+
// Unlike flush(), we don't need to explicitly handle sorting here.
517512
try (var scorerSupplier = flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState)) {
518513
var started = System.nanoTime();
519514
int numVectors = scorerSupplier.totalVectorCount();
@@ -844,5 +839,9 @@ public float[] copyValue(float[] vectorValue) {
844839
public long ramBytesUsed() {
845840
return SHALLOW_SIZE + flatFieldVectorsWriter.ramBytesUsed();
846841
}
842+
843+
public DocsWithFieldSet getDocsWithFieldSet() {
844+
return flatFieldVectorsWriter.getDocsWithFieldSet();
845+
}
847846
}
848847
}

0 commit comments

Comments
 (0)