Skip to content

Commit 2c4a62f

Browse files
committed
first changes to base_client
1 parent 9e8e2ff commit 2c4a62f

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

engine/base_client/search.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import random
23
import time
34
from multiprocessing import Process, Queue
45
from typing import Iterable, List, Optional, Tuple
@@ -41,6 +42,12 @@ def search_one(
4142
) -> List[Tuple[int, float]]:
4243
raise NotImplementedError()
4344

45+
@classmethod
46+
def insert_one(
47+
cls, vector: List[float], meta_conditions, top: Optional[int]
48+
) -> List[Tuple[int, float]]:
49+
raise NotImplementedError()
50+
4451
@classmethod
4552
def _search_one(cls, query, top: Optional[int] = None):
4653
if top is None:
@@ -228,7 +235,7 @@ def chunked_iterable(iterable, size):
228235
yield chunk
229236

230237
# Function to be executed by each worker process
231-
def worker_function(self, distance, search_one, chunk, result_queue):
238+
def worker_function(self, distance, search_one, insert_one, chunk, result_queue, insert_fraction=0.0):
232239
self.init_client(
233240
self.host,
234241
distance,
@@ -238,10 +245,17 @@ def worker_function(self, distance, search_one, chunk, result_queue):
238245
self.setup_search()
239246

240247
start_time = time.perf_counter()
241-
results = process_chunk(chunk, search_one)
248+
results = process_chunk(chunk, search_one, insert_one, insert_fraction)
242249
result_queue.put((start_time, results))
243250

244-
def process_chunk(chunk, search_one):
245-
"""Process a chunk of queries using the search_one function."""
246-
# No progress bar in worker processes to avoid cluttering the output
247-
return [search_one(query) for query in chunk]
251+
252+
def process_chunk(chunk, search_one, insert_one, insert_fraction):
253+
results = []
254+
for i, query in enumerate(chunk):
255+
if random.random() < insert_fraction:
256+
result = insert_one(query)
257+
else:
258+
# Search
259+
result = search_one(query)
260+
results.append(result)
261+
return results

0 commit comments

Comments
 (0)