Skip to content

Commit 1b6eaab

Browse files
committed
Cleaned embeddings/
1 parent d660d66 commit 1b6eaab

File tree

5 files changed

+93
-61
lines changed

5 files changed

+93
-61
lines changed

nemoguardrails/embeddings/basic.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
from typing import Any, Dict, List, Optional, Union
1919

20-
from annoy import AnnoyIndex
20+
from annoy import AnnoyIndex # type: ignore
2121

2222
from nemoguardrails.embeddings.cache import cache_embeddings
2323
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4545
max_batch_hold: The maximum time a batch is held before being processed
4646
"""
4747

48-
embedding_model: str
49-
embedding_engine: str
50-
embedding_params: Dict[str, Any]
51-
index: AnnoyIndex
52-
embedding_size: int
53-
cache_config: EmbeddingsCacheConfig
54-
embeddings: List[List[float]]
55-
search_threshold: float
56-
use_batching: bool
57-
max_batch_size: int
58-
max_batch_hold: float
48+
# Instance attributes are defined in __init__ and accessed via properties
5949

6050
def __init__(
6151
self,
62-
embedding_model=None,
63-
embedding_engine=None,
64-
embedding_params=None,
65-
index=None,
66-
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
67-
search_threshold: float = None,
52+
embedding_model: Optional[str] = None,
53+
embedding_engine: Optional[str] = None,
54+
embedding_params: Optional[Dict[str, Any]] = None,
55+
index: Optional[AnnoyIndex] = None,
56+
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
57+
search_threshold: Optional[float] = None,
6858
use_batching: bool = False,
6959
max_batch_size: int = 10,
7060
max_batch_hold: float = 0.01,
@@ -81,10 +71,10 @@ def __init__(
8171
max_batch_hold: The maximum time a batch is held before being processed
8272
"""
8373
self._model: Optional[EmbeddingModel] = None
84-
self._items = []
85-
self._embeddings = []
86-
self.embedding_model = embedding_model
87-
self.embedding_engine = embedding_engine
74+
self._items: List[IndexItem] = []
75+
self._embeddings: List[List[float]] = []
76+
self.embedding_model: Optional[str] = embedding_model
77+
self.embedding_engine: Optional[str] = embedding_engine
8878
self.embedding_params = embedding_params or {}
8979
self._embedding_size = 0
9080
self.search_threshold = search_threshold or float("inf")
@@ -95,12 +85,12 @@ def __init__(
9585
self._index = index
9686

9787
# Data structures for batching embedding requests
98-
self._req_queue = {}
99-
self._req_results = {}
100-
self._req_idx = 0
101-
self._current_batch_finished_event = None
102-
self._current_batch_full_event = None
103-
self._current_batch_submitted = asyncio.Event()
88+
self._req_queue: Dict[int, str] = {}
89+
self._req_results: Dict[int, List[float]] = {}
90+
self._req_idx: int = 0
91+
self._current_batch_finished_event: Optional[asyncio.Event] = None
92+
self._current_batch_full_event: Optional[asyncio.Event] = None
93+
self._current_batch_submitted: asyncio.Event = asyncio.Event()
10494

10595
# Initialize the batching configuration
10696
self.use_batching = use_batching
@@ -112,6 +102,11 @@ def embeddings_index(self):
112102
"""Get the current embedding index"""
113103
return self._index
114104

105+
@embeddings_index.setter
106+
def embeddings_index(self, index):
107+
"""Setter to allow replacing the index dynamically."""
108+
self._index = index
109+
115110
@property
116111
def cache_config(self):
117112
"""Get the cache configuration."""
@@ -127,19 +122,23 @@ def embeddings(self):
127122
"""Get the computed embeddings."""
128123
return self._embeddings
129124

130-
@embeddings_index.setter
131-
def embeddings_index(self, index):
132-
"""Setter to allow replacing the index dynamically."""
133-
self._index = index
134-
135125
def _init_model(self):
136126
"""Initialize the model used for computing the embeddings."""
127+
# Provide defaults if not specified
128+
model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
129+
engine = self.embedding_engine or "SentenceTransformers"
130+
137131
self._model = init_embedding_model(
138-
embedding_model=self.embedding_model,
139-
embedding_engine=self.embedding_engine,
132+
embedding_model=model,
133+
embedding_engine=engine,
140134
embedding_params=self.embedding_params,
141135
)
142136

137+
if not self._model:
138+
raise ValueError(
139+
f"Couldn't create embedding model with model {model} and engine {engine}"
140+
)
141+
143142
@cache_embeddings
144143
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
145144
"""Compute embeddings for a list of texts.
@@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
153152
if self._model is None:
154153
self._init_model()
155154

155+
if not self._model:
156+
raise Exception("Couldn't initialize embedding model")
156157
embeddings = await self._model.encode_async(texts)
157158
return embeddings
158159

@@ -199,6 +200,10 @@ async def _run_batch(self):
199200
"""Runs the current batch of embeddings."""
200201

201202
# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
203+
if not self._current_batch_full_event:
204+
raise Exception("self._current_batch_full_event not initialized")
205+
206+
assert self._current_batch_full_event is not None
202207
done, pending = await asyncio.wait(
203208
[
204209
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -210,6 +215,10 @@ async def _run_batch(self):
210215
task.cancel()
211216

212217
# Reset the batch event
218+
if not self._current_batch_finished_event:
219+
raise Exception("self._current_batch_finished_event not initialized")
220+
221+
assert self._current_batch_finished_event is not None
213222
batch_event: asyncio.Event = self._current_batch_finished_event
214223
self._current_batch_finished_event = None
215224

@@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
252261

253262
# We check if we reached the max batch size
254263
if len(self._req_queue) >= self.max_batch_size:
264+
if not self._current_batch_full_event:
265+
raise Exception("self._current_batch_full_event not initialized")
255266
self._current_batch_full_event.set()
256267

257-
# Wait for the batch to finish
268+
# Wait for the batch to finish
269+
if not self._current_batch_finished_event:
270+
raise Exception("self._current_batch_finished_event not initialized")
258271
await self._current_batch_finished_event.wait()
259272

260273
# Remove the result and return it

nemoguardrails/embeddings/cache.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from abc import ABC, abstractmethod
2121
from functools import singledispatchmethod
2222
from pathlib import Path
23-
from typing import Dict, List
23+
from typing import Dict, List, Optional
24+
25+
try:
26+
import redis # type: ignore
27+
except ImportError:
28+
redis = None # type: ignore
2429

2530
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
2631

@@ -30,18 +35,20 @@
3035
class KeyGenerator(ABC):
3136
"""Abstract class for key generators."""
3237

38+
name: str # Class attribute that should be defined in subclasses
39+
3340
@abstractmethod
3441
def generate_key(self, text: str) -> str:
3542
pass
3643

3744
@classmethod
3845
def from_name(cls, name):
3946
for subclass in cls.__subclasses__():
40-
if subclass.name == name:
47+
if hasattr(subclass, "name") and subclass.name == name:
4148
return subclass
4249
raise ValueError(
4350
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
44-
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
51+
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
4552
". Make sure to import the derived class before using it."
4653
)
4754

@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
7683
class CacheStore(ABC):
7784
"""Abstract class for cache stores."""
7885

86+
name: str # Class attribute that should be defined in subclasses
87+
7988
@abstractmethod
8089
def get(self, key):
8190
"""Get a value from the cache."""
@@ -94,11 +103,11 @@ def clear(self):
94103
@classmethod
95104
def from_name(cls, name):
96105
for subclass in cls.__subclasses__():
97-
if subclass.name == name:
106+
if hasattr(subclass, "name") and subclass.name == name:
98107
return subclass
99108
raise ValueError(
100109
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
101-
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
110+
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
102111
". Make sure to import the derived class before using it."
103112
)
104113

@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):
147156

148157
name = "filesystem"
149158

150-
def __init__(self, cache_dir: str = None):
159+
def __init__(self, cache_dir: Optional[str] = None):
151160
self._cache_dir = Path(cache_dir or ".cache/embeddings")
152161
self._cache_dir.mkdir(parents=True, exist_ok=True)
153162

@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
190199
name = "redis"
191200

192201
def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
193-
import redis
194-
202+
if redis is None:
203+
raise ImportError(
204+
"Could not import redis, please install it with `pip install redis`."
205+
)
195206
self._redis = redis.Redis(host=host, port=port, db=db)
196207

197208
def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
207218
class EmbeddingsCache:
208219
def __init__(
209220
self,
210-
key_generator: KeyGenerator = None,
211-
cache_store: CacheStore = None,
212-
store_config: dict = None,
221+
key_generator: Optional[KeyGenerator] = None,
222+
cache_store: Optional[CacheStore] = None,
223+
store_config: Optional[dict] = None,
213224
):
214225
self._key_generator = key_generator
215226
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
218229
@classmethod
219230
def from_dict(cls, d: Dict[str, str]):
220231
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
221-
store_config = d.get("store_config")
232+
store_config_raw = d.get("store_config")
233+
store_config: dict = (
234+
store_config_raw if isinstance(store_config_raw, dict) else {}
235+
)
222236
cache_store = CacheStore.from_name(d.get("store"))(**store_config)
223237

224238
return cls(key_generator=key_generator, cache_store=cache_store)
@@ -230,25 +244,27 @@ def from_config(cls, config: EmbeddingsCacheConfig):
230244

231245
def get_config(self):
232246
return EmbeddingsCacheConfig(
233-
key_generator=self._key_generator.name,
234-
store=self._cache_store.name,
247+
key_generator=self._key_generator.name if self._key_generator else "sha256",
248+
store=self._cache_store.name if self._cache_store else "filesystem",
235249
store_config=self._store_config,
236250
)
237251

238252
@singledispatchmethod
239253
def get(self, texts):
240254
raise NotImplementedError
241255

242-
@get.register
256+
@get.register(str)
243257
def _(self, text: str):
258+
if self._key_generator is None or self._cache_store is None:
259+
return None
244260
key = self._key_generator.generate_key(text)
245261
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
246262

247263
result = self._cache_store.get(key)
248264

249265
return result
250266

251-
@get.register
267+
@get.register(list)
252268
def _(self, texts: list):
253269
cached = {}
254270

@@ -266,19 +282,22 @@ def _(self, texts: list):
266282
def set(self, texts):
267283
raise NotImplementedError
268284

269-
@set.register
285+
@set.register(str)
270286
def _(self, text: str, value: List[float]):
287+
if self._key_generator is None or self._cache_store is None:
288+
return
271289
key = self._key_generator.generate_key(text)
272290
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
273291
self._cache_store.set(key, value)
274292

275-
@set.register
293+
@set.register(list)
276294
def _(self, texts: list, values: List[List[float]]):
277295
for text, value in zip(texts, values):
278296
self.set(text, value)
279297

280298
def clear(self):
281-
self._cache_store.clear()
299+
if self._cache_store is not None:
300+
self._cache_store.clear()
282301

283302

284303
def cache_embeddings(func):

nemoguardrails/embeddings/providers/fastembed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
4242
engine_name = "FastEmbed"
4343

4444
def __init__(self, embedding_model: str, **kwargs):
45-
from fastembed import TextEmbedding as Embedding
45+
from fastembed import TextEmbedding as Embedding # type: ignore
4646

4747
# Enabling a short form model name for all-MiniLM-L6-v2.
4848
if embedding_model == "all-MiniLM-L6-v2":

nemoguardrails/embeddings/providers/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def __init__(
4646
**kwargs,
4747
):
4848
try:
49-
import openai
50-
from openai import AsyncOpenAI, OpenAI
49+
import openai # type: ignore
50+
from openai import AsyncOpenAI, OpenAI # type: ignore
5151
except ImportError:
5252
raise ImportError(
5353
"Could not import openai, please install it with "
5454
"`pip install openai`."
5555
)
56-
if openai.__version__ < "1.0.0":
56+
if openai.__version__ < "1.0.0": # type: ignore
5757
raise RuntimeError(
5858
"`openai<1.0.0` is no longer supported. "
5959
"Please upgrade using `pip install openai>=1.0.0`."

nemoguardrails/embeddings/providers/sentence_transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
4343

4444
def __init__(self, embedding_model: str, **kwargs):
4545
try:
46-
from sentence_transformers import SentenceTransformer
46+
from sentence_transformers import SentenceTransformer # type: ignore
4747
except ImportError:
4848
raise ImportError(
4949
"Could not import sentence-transformers, please install it with "
5050
"`pip install sentence-transformers`."
5151
)
5252

5353
try:
54-
from torch import cuda
54+
from torch import cuda # type: ignore
5555
except ImportError:
5656
raise ImportError(
5757
"Could not import torch, please install it with `pip install torch`."

0 commit comments

Comments
 (0)