1414import os
1515import warnings
1616
17+ import numpy as np
18+
1719from tempfile import TemporaryDirectory
1820
1921import pysvs
@@ -79,12 +81,50 @@ def _setup(self, loader: pysvs.VectorDataLoader):
7981 }),
8082 ]
8183
84+ # Ensure that passing 1-dimensional queries works and produces the same results as
85+ # query batches.
86+ def _test_single_query (
87+ self ,
88+ vamana : pysvs .Vamana ,
89+ queries
90+ ):
91+
92+ I_full , D_full = vamana .search (queries , 10 );
93+
94+ I_single = []
95+ D_single = []
96+ for i in range (queries .shape [0 ]):
97+ query = queries [i , :]
98+ self .assertTrue (query .ndim == 1 )
99+ I , D = vamana .search (query , 10 )
100+
101+ self .assertTrue (I .ndim == 2 )
102+ self .assertTrue (D .ndim == 2 )
103+ self .assertTrue (I .shape == (1 , 10 ))
104+ self .assertTrue (D .shape == (1 , 10 ))
105+
106+ I_single .append (I )
107+ D_single .append (D )
108+
109+ I_single_concat = np .concatenate (I_single , axis = 0 )
110+ D_single_concat = np .concatenate (D_single , axis = 0 )
111+ self .assertTrue (np .array_equal (I_full , I_single_concat ))
112+ self .assertTrue (np .array_equal (D_full , D_single_concat ))
113+
114+ # Throw an error on 3-dimensional inputs.
115+ queries_3d = queries [:, :, np .newaxis ]
116+ with self .assertRaises (Exception ) as context :
117+ vamana .search (queries_3d , 10 )
118+
119+ self .assertTrue ("only accept numpy vectors or matrices" in str (context .exception ))
120+
82121 def _test_basic_inner (
83122 self ,
84123 vamana : pysvs .Vamana ,
85124 recall_dict ,
86125 num_threads : int ,
87126 skip_thread_test : bool = False ,
127+ test_single_query : bool = False ,
88128 ):
89129 # Make sure that the number of threads is propagated correctly.
90130 self .assertEqual (vamana .num_threads , num_threads )
@@ -129,6 +169,9 @@ def _test_basic_inner(
129169 if not DEBUG :
130170 self .assertTrue (isapprox (recall , expected_recall , epsilon = 0.0005 ))
131171
172+ if test_single_query :
173+ self ._test_single_query (vamana , queries )
174+
132175 # Disable visited set.
133176 self .visited_set_enabled = False
134177
@@ -158,6 +201,7 @@ def _test_basic(self, loader, recall_dict):
158201 self ._test_basic_inner (vamana , recall_dict , num_threads )
159202
160203 # Test saving and reloading.
204+ is_first = True
161205 with TemporaryDirectory () as tempdir :
162206 configdir = os .path .join (tempdir , "config" )
163207 graphdir = os .path .join (tempdir , "graph" )
@@ -179,8 +223,13 @@ def _test_basic(self, loader, recall_dict):
179223
180224 reloaded .num_threads = num_threads
181225 self ._test_basic_inner (
182- reloaded , recall_dict , num_threads , skip_thread_test = True
226+ reloaded ,
227+ recall_dict ,
228+ num_threads ,
229+ skip_thread_test = True ,
230+ test_single_query = is_first ,
183231 )
232+ is_first = False
184233
185234 def test_basic (self ):
186235 # Load the index from files.
0 commit comments