1717import logging
1818from typing import Any , Dict , List , Optional , Union
1919
20- from annoy import AnnoyIndex
20+ from annoy import AnnoyIndex # type: ignore
2121
2222from nemoguardrails .embeddings .cache import cache_embeddings
2323from nemoguardrails .embeddings .index import EmbeddingsIndex , IndexItem
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4545 max_batch_hold: The maximum time a batch is held before being processed
4646 """
4747
48- embedding_model : str
49- embedding_engine : str
50- embedding_params : Dict [str , Any ]
51- index : AnnoyIndex
52- embedding_size : int
53- cache_config : EmbeddingsCacheConfig
54- embeddings : List [List [float ]]
55- search_threshold : float
56- use_batching : bool
57- max_batch_size : int
58- max_batch_hold : float
48+ # Instance attributes are defined in __init__ and accessed via properties
5949
6050 def __init__ (
6151 self ,
62- embedding_model = None ,
63- embedding_engine = None ,
64- embedding_params = None ,
65- index = None ,
66- cache_config : Union [EmbeddingsCacheConfig , Dict [str , Any ]] = None ,
67- search_threshold : float = None ,
52+ embedding_model : Optional [ str ] = None ,
53+ embedding_engine : Optional [ str ] = None ,
54+ embedding_params : Optional [ Dict [ str , Any ]] = None ,
55+ index : Optional [ AnnoyIndex ] = None ,
56+ cache_config : Optional [ Union [EmbeddingsCacheConfig , Dict [str , Any ] ]] = None ,
57+ search_threshold : Optional [ float ] = None ,
6858 use_batching : bool = False ,
6959 max_batch_size : int = 10 ,
7060 max_batch_hold : float = 0.01 ,
@@ -81,10 +71,10 @@ def __init__(
8171 max_batch_hold: The maximum time a batch is held before being processed
8272 """
8373 self ._model : Optional [EmbeddingModel ] = None
84- self ._items = []
85- self ._embeddings = []
86- self .embedding_model = embedding_model
87- self .embedding_engine = embedding_engine
74+ self ._items : List [ IndexItem ] = []
75+ self ._embeddings : List [ List [ float ]] = []
76+ self .embedding_model : Optional [ str ] = embedding_model
77+ self .embedding_engine : Optional [ str ] = embedding_engine
8878 self .embedding_params = embedding_params or {}
8979 self ._embedding_size = 0
9080 self .search_threshold = search_threshold or float ("inf" )
@@ -95,12 +85,12 @@ def __init__(
9585 self ._index = index
9686
9787 # Data structures for batching embedding requests
98- self ._req_queue = {}
99- self ._req_results = {}
100- self ._req_idx = 0
101- self ._current_batch_finished_event = None
102- self ._current_batch_full_event = None
103- self ._current_batch_submitted = asyncio .Event ()
88+ self ._req_queue : Dict [ int , str ] = {}
89+ self ._req_results : Dict [ int , List [ float ]] = {}
90+ self ._req_idx : int = 0
91+ self ._current_batch_finished_event : Optional [ asyncio . Event ] = None
92+ self ._current_batch_full_event : Optional [ asyncio . Event ] = None
93+ self ._current_batch_submitted : asyncio . Event = asyncio .Event ()
10494
10595 # Initialize the batching configuration
10696 self .use_batching = use_batching
@@ -112,6 +102,11 @@ def embeddings_index(self):
112102 """Get the current embedding index"""
113103 return self ._index
114104
@embeddings_index.setter
def embeddings_index(self, index: Optional[AnnoyIndex]) -> None:
    """Replace the index object held by this instance.

    Args:
        index: The new index to store. ``None`` is accepted, matching the
            ``index`` parameter of ``__init__``.
    """
    # The value is stored as-is on the private attribute that the
    # ``embeddings_index`` getter returns; nothing is rebuilt or validated.
    self._index = index
109+
115110 @property
116111 def cache_config (self ):
117112 """Get the cache configuration."""
@@ -127,19 +122,23 @@ def embeddings(self):
127122 """Get the computed embeddings."""
128123 return self ._embeddings
129124
130- @embeddings_index .setter
131- def embeddings_index (self , index ):
132- """Setter to allow replacing the index dynamically."""
133- self ._index = index
134-
def _init_model(self):
    """Create the embedding model and store it on ``self._model``.

    When ``self.embedding_model`` or ``self.embedding_engine`` is unset
    (falsy), the ``sentence-transformers/all-MiniLM-L6-v2`` model and the
    ``SentenceTransformers`` engine are used instead.

    Raises:
        ValueError: If ``init_embedding_model`` returns a falsy value.
    """
    # Resolve the effective model/engine names, falling back to defaults.
    model = self.embedding_model
    if not model:
        model = "sentence-transformers/all-MiniLM-L6-v2"

    engine = self.embedding_engine
    if not engine:
        engine = "SentenceTransformers"

    self._model = init_embedding_model(
        embedding_model=model,
        embedding_engine=engine,
        embedding_params=self.embedding_params,
    )

    # Success path: a usable model was created.
    if self._model:
        return

    raise ValueError(
        f"Couldn't create embedding model with model {model} and engine {engine}"
    )
141+
143142 @cache_embeddings
144143 async def _get_embeddings (self , texts : List [str ]) -> List [List [float ]]:
145144 """Compute embeddings for a list of texts.
@@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
153152 if self ._model is None :
154153 self ._init_model ()
155154
155+ if not self ._model :
156+ raise Exception ("Couldn't initialize embedding model" )
156157 embeddings = await self ._model .encode_async (texts )
157158 return embeddings
158159
@@ -199,6 +200,10 @@ async def _run_batch(self):
199200 """Runs the current batch of embeddings."""
200201
201202 # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
203+ if not self ._current_batch_full_event :
204+ raise Exception ("self._current_batch_full_event not initialized" )
205+
206+ assert self ._current_batch_full_event is not None
202207 done , pending = await asyncio .wait (
203208 [
204209 asyncio .create_task (asyncio .sleep (self .max_batch_hold )),
@@ -210,6 +215,10 @@ async def _run_batch(self):
210215 task .cancel ()
211216
212217 # Reset the batch event
218+ if not self ._current_batch_finished_event :
219+ raise Exception ("self._current_batch_finished_event not initialized" )
220+
221+ assert self ._current_batch_finished_event is not None
213222 batch_event : asyncio .Event = self ._current_batch_finished_event
214223 self ._current_batch_finished_event = None
215224
@@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
252261
253262 # We check if we reached the max batch size
254263 if len (self ._req_queue ) >= self .max_batch_size :
264+ if not self ._current_batch_full_event :
265+ raise Exception ("self._current_batch_full_event not initialized" )
255266 self ._current_batch_full_event .set ()
256267
257- # Wait for the batch to finish
268+ # Wait for the batch to finish
269+ if not self ._current_batch_finished_event :
270+ raise Exception ("self._current_batch_finished_event not initialized" )
258271 await self ._current_batch_finished_event .wait ()
259272
260273 # Remove the result and return it
0 commit comments