1313import org .elasticsearch .common .settings .Settings ;
1414import org .elasticsearch .plugins .Plugin ;
1515import org .elasticsearch .search .SearchHit ;
16+ import org .elasticsearch .search .vectors .ExactKnnQueryBuilder ;
1617import org .elasticsearch .search .vectors .KnnSearchBuilder ;
18+ import org .elasticsearch .search .vectors .VectorData ;
1719import org .elasticsearch .test .ESIntegTestCase ;
1820import org .elasticsearch .xpack .gpu .GPUPlugin ;
1921import org .elasticsearch .xpack .gpu .GPUSupport ;
2022import org .junit .Assert ;
2123import org .junit .BeforeClass ;
2224
2325import java .util .Collection ;
26+ import java .util .HashSet ;
2427import java .util .List ;
2528import java .util .Locale ;
29+ import java .util .Set ;
2630
2731import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .assertAcked ;
2832import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .assertNoFailures ;
@@ -56,20 +60,19 @@ public void testBasic() {
5660 assertSearch (indexName , randomFloatVector (dims ), totalDocs );
5761 }
5862
59- @ AwaitsFix (bugUrl = "Fix sorted index" )
6063 public void testSortedIndexReturnsSameResultsAsUnsorted () {
6164 String indexName1 = "index_unsorted" ;
6265 String indexName2 = "index_sorted" ;
6366 final int dims = randomIntBetween (4 , 128 );
6467 createIndex (indexName1 , dims , false );
6568 createIndex (indexName2 , dims , true );
6669
67- final int [] numDocs = new int [] { randomIntBetween (50 , 100 ), randomIntBetween (50 , 100 ) };
70+ final int [] numDocs = new int [] { randomIntBetween (300 , 999 ), randomIntBetween (300 , 999 ) };
6871 for (int i = 0 ; i < numDocs .length ; i ++) {
6972 BulkRequestBuilder bulkRequest1 = client ().prepareBulk ();
7073 BulkRequestBuilder bulkRequest2 = client ().prepareBulk ();
7174 for (int j = 0 ; j < numDocs [i ]; j ++) {
72- String id = String .valueOf (i * 100 + j );
75+ String id = String .valueOf (i * 1000 + j );
7376 String keywordValue = String .valueOf (numDocs [i ] - j );
7477 float [] vector = randomFloatVector (dims );
7578 bulkRequest1 .add (prepareIndex (indexName1 ).setId (id ).setSource ("my_vector" , vector , "my_keyword" , keywordValue ));
@@ -84,8 +87,9 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
8487
8588 float [] queryVector = randomFloatVector (dims );
8689 int k = 10 ;
87- int numCandidates = k * 10 ;
90+ int numCandidates = k * 5 ;
8891
92+ // Test 1: Approximate KNN search - expect at least k-3 out of k matches
8993 var searchResponse1 = prepareSearch (indexName1 ).setSize (k )
9094 .setFetchSource (false )
9195 .addFetchField ("my_keyword" )
@@ -101,22 +105,40 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
101105 try {
102106 SearchHit [] hits1 = searchResponse1 .getHits ().getHits ();
103107 SearchHit [] hits2 = searchResponse2 .getHits ().getHits ();
104- Assert .assertEquals (hits1 .length , hits2 .length );
105- for (int i = 0 ; i < hits1 .length ; i ++) {
106- Assert .assertEquals (hits1 [i ].getId (), hits2 [i ].getId ());
107- Assert .assertEquals (hits1 [i ].field ("my_keyword" ).getValue (), (String ) hits2 [i ].field ("my_keyword" ).getValue ());
108- Assert .assertEquals (hits1 [i ].getScore (), hits2 [i ].getScore (), 0.001f );
109- }
108+ assertAtLeastNOutOfKMatches (hits1 , hits2 , k - 3 , k );
110109 } finally {
111110 searchResponse1 .decRef ();
112111 searchResponse2 .decRef ();
113112 }
114113
114+ // Test 2: Exact KNN search (brute-force) - expect perfect k out of k matches
115+ var exactSearchResponse1 = prepareSearch (indexName1 ).setSize (k )
116+ .setFetchSource (false )
117+ .addFetchField ("my_keyword" )
118+ .setQuery (new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), "my_vector" , null ))
119+ .get ();
120+
121+ var exactSearchResponse2 = prepareSearch (indexName2 ).setSize (k )
122+ .setFetchSource (false )
123+ .addFetchField ("my_keyword" )
124+ .setQuery (new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), "my_vector" , null ))
125+ .get ();
126+
127+ try {
128+ SearchHit [] exactHits1 = exactSearchResponse1 .getHits ().getHits ();
129+ SearchHit [] exactHits2 = exactSearchResponse2 .getHits ().getHits ();
130+ assertExactMatches (exactHits1 , exactHits2 , k );
131+ } finally {
132+ exactSearchResponse1 .decRef ();
133+ exactSearchResponse2 .decRef ();
134+ }
135+
115136 // Force merge and search again
116137 assertNoFailures (indicesAdmin ().prepareForceMerge (indexName1 ).get ());
117138 assertNoFailures (indicesAdmin ().prepareForceMerge (indexName2 ).get ());
118139 ensureGreen ();
119140
141+ // Test 3: Approximate KNN search - expect at least k-3 out of k matches
120142 var searchResponse3 = prepareSearch (indexName1 ).setSize (k )
121143 .setFetchSource (false )
122144 .addFetchField ("my_keyword" )
@@ -132,16 +154,33 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
132154 try {
133155 SearchHit [] hits3 = searchResponse3 .getHits ().getHits ();
134156 SearchHit [] hits4 = searchResponse4 .getHits ().getHits ();
135- Assert .assertEquals (hits3 .length , hits4 .length );
136- for (int i = 0 ; i < hits3 .length ; i ++) {
137- Assert .assertEquals (hits3 [i ].getId (), hits4 [i ].getId ());
138- Assert .assertEquals (hits3 [i ].field ("my_keyword" ).getValue (), (String ) hits4 [i ].field ("my_keyword" ).getValue ());
139- Assert .assertEquals (hits3 [i ].getScore (), hits4 [i ].getScore (), 0.01f );
140- }
157+ assertAtLeastNOutOfKMatches (hits3 , hits4 , k - 3 , k );
141158 } finally {
142159 searchResponse3 .decRef ();
143160 searchResponse4 .decRef ();
144161 }
162+
163+ // Test 4: Exact KNN search after merge - expect perfect k out of k matches
164+ var exactSearchResponse3 = prepareSearch (indexName1 ).setSize (k )
165+ .setFetchSource (false )
166+ .addFetchField ("my_keyword" )
167+ .setQuery (new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), "my_vector" , null ))
168+ .get ();
169+
170+ var exactSearchResponse4 = prepareSearch (indexName2 ).setSize (k )
171+ .setFetchSource (false )
172+ .addFetchField ("my_keyword" )
173+ .setQuery (new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), "my_vector" , null ))
174+ .get ();
175+
176+ try {
177+ SearchHit [] exactHits3 = exactSearchResponse3 .getHits ().getHits ();
178+ SearchHit [] exactHits4 = exactSearchResponse4 .getHits ().getHits ();
179+ assertExactMatches (exactHits3 , exactHits4 , k );
180+ } finally {
181+ exactSearchResponse3 .decRef ();
182+ exactSearchResponse4 .decRef ();
183+ }
145184 }
146185
147186 public void testSearchWithoutGPU () {
@@ -261,4 +300,56 @@ private static float[] randomFloatVector(int dims) {
261300 }
262301 return vector ;
263302 }
303+
304+ /**
305+ * Asserts that at least N out of K hits have matching IDs between two result sets.
306+ */
307+ private static void assertAtLeastNOutOfKMatches (SearchHit [] hits1 , SearchHit [] hits2 , int minMatches , int k ) {
308+ Assert .assertEquals ("Both result sets should have k hits" , k , hits1 .length );
309+ Assert .assertEquals ("Both result sets should have k hits" , k , hits2 .length );
310+ Set <String > ids1 = new HashSet <>();
311+ Set <String > ids2 = new HashSet <>();
312+
313+ for (SearchHit hit : hits1 ) {
314+ ids1 .add (hit .getId ());
315+ }
316+ for (SearchHit hit : hits2 ) {
317+ ids2 .add (hit .getId ());
318+ }
319+
320+ Set <String > intersection = new HashSet <>(ids1 );
321+ intersection .retainAll (ids2 );
322+ Assert .assertTrue (
323+ String .format (
324+ Locale .ROOT ,
325+ "Expected at least %d matching IDs out of %d, but found %d. IDs1: %s, IDs2: %s" ,
326+ minMatches ,
327+ k ,
328+ intersection .size (),
329+ ids1 ,
330+ ids2
331+ ),
332+ intersection .size () >= minMatches
333+ );
334+ }
335+
336+ /**
337+ * Asserts that two result sets have exactly the same document IDs in the same order with the same scores.
338+ * Used for exact (brute-force) KNN search which should be deterministic.
339+ * Expects k out of k matches.
340+ */
341+ private static void assertExactMatches (SearchHit [] hits1 , SearchHit [] hits2 , int k ) {
342+ Assert .assertEquals ("Both result sets should have k hits" , k , hits1 .length );
343+ Assert .assertEquals ("Both result sets should have k hits" , k , hits2 .length );
344+
345+ for (int i = 0 ; i < k ; i ++) {
346+ Assert .assertEquals (String .format (Locale .ROOT , "Document ID mismatch at position %d" , i ), hits1 [i ].getId (), hits2 [i ].getId ());
347+ Assert .assertEquals (
348+ String .format (Locale .ROOT , "Score mismatch for document ID %s at position %d" , hits1 [i ].getId (), i ),
349+ hits1 [i ].getScore (),
350+ hits2 [i ].getScore (),
351+ 0.0001f
352+ );
353+ }
354+ }
264355}
0 commit comments