Skip to content

Commit c3e036b

Browse files
Upgrade to use Pydantic 2.x.x (#101)
With the shift in the ecosystem to Pydantic 2.x.x, we need to get `redisvl` to the same level. Referenced here: https://github.com/RedisVentures/redisvl/issues/86 Fortunately, there is not a ton of complex usage of Pydantic, which makes this easy! Changes include: - Shifting the storage classes to use Pydantic to avoid manual (and ugly) field validation. Now, the subclasses inherit from the base and params are all typed as expected. - Shifting the vectorizer classes to use Pydantic for the same reasons. - Bump and pin pydantic version in the package requirements to a reasonable range `pydantic>=2.0.0,<3`. - Use `v1` shim for now for safety.
1 parent 24955a2 commit c3e036b

File tree

15 files changed

+132
-90
lines changed

15 files changed

+132
-90
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ scratch
66
.DS_Store
77
*.csv
88
wiki_schema.yaml
9-
docs/_build/
9+
docs/_build/
10+
.venv

CONTRIBUTING.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,38 @@ Here's how to get started with your code contribution:
3030
### Dev Environment
3131
There are provided `requirements.txt` and `requirements-dev.txt` files you can use to install required libraries with `pip` into your virtual environment.
3232

33+
Or use the local package editable install method:
34+
```bash
35+
python -m venv .venv
36+
source .venv/bin/activate
37+
pip install -e .[all,dev]
38+
```
39+
40+
Then to deactivate the env:
41+
```
42+
deactivate
43+
```
44+
45+
### Linting and Tests
46+
47+
Check formatting, linting, and typing:
48+
```bash
49+
make check
50+
```
51+
52+
Tests (with vectorizers):
53+
```bash
54+
make test-cov
55+
```
56+
57+
Tests without vectorizers:
58+
```bash
59+
SKIP_VECTORIZERS=true make test-cov
60+
```
61+
62+
> Dev requirements are needed here to be able to run tests and linting.
63+
> See other commands in the [Makefile](Makefile)
64+
3365
### Docker Tips
3466

3567
Make sure to have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/):
@@ -38,7 +70,7 @@ Make sure to have [Redis](https://redis.io) accessible with Search & Query featu
3870
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
3971
```
4072

41-
This will also spin up the [Redis Insight GUI](https://redis.com/redis-enterprise/redis-insight/) at `http://localhost:8001`.
73+
This will also spin up the [FREE RedisInsight GUI](https://redis.com/redis-enterprise/redis-insight/) at `http://localhost:8001`.
4274

4375
## How to Report a Bug
4476

redisvl/index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __init__(
196196
self.schema = schema
197197

198198
self._storage = self._STORAGE_MAP[self.schema.storage_type](
199-
self.schema.prefix, self.schema.key_separator
199+
prefix=self.schema.prefix, key_separator=self.schema.key_separator
200200
)
201201

202202
@property

redisvl/llmcache/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
distance_threshold: float = 0.1,
6868
ttl: Optional[int] = None,
6969
vectorizer: BaseVectorizer = HFTextVectorizer(
70-
"sentence-transformers/all-mpnet-base-v2"
70+
model="sentence-transformers/all-mpnet-base-v2"
7171
),
7272
redis_url: str = "redis://localhost:6379",
7373
connection_args: Dict[str, Any] = {},

redisvl/schema/fields.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, Optional, Union
22

3-
from pydantic import BaseModel, Field, validator
3+
from pydantic.v1 import BaseModel, Field, validator
44
from redis.commands.search.field import Field as RedisField
55
from redis.commands.search.field import GeoField as RedisGeoField
66
from redis.commands.search.field import NumericField as RedisNumericField
@@ -69,6 +69,7 @@ class BaseVectorField(BaseModel):
6969
as_name: Optional[str] = None
7070

7171
@validator("algorithm", "datatype", "distance_metric", pre=True)
72+
@classmethod
7273
def uppercase_strings(cls, v):
7374
return v.upper()
7475

redisvl/schema/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Dict, List, Union
55

66
import yaml
7-
from pydantic import BaseModel, validator
7+
from pydantic.v1 import BaseModel, validator
88
from redis.commands.search.field import Field as RedisField
99

1010
from redisvl.schema.fields import BaseField, BaseVectorField, FieldFactory
@@ -66,6 +66,7 @@ class IndexSchema(BaseModel):
6666
fields: Dict[str, List[Union[BaseField, BaseVectorField]]] = {}
6767

6868
@validator("fields", pre=True)
69+
@classmethod
6970
def check_unique_field_names(cls, fields):
7071
"""Validate that field names are all unique."""
7172
all_names = cls._get_field_names(fields)

redisvl/storage.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,27 @@
22
import uuid
33
from typing import Any, Callable, Dict, Iterable, List, Optional
44

5+
from pydantic.v1 import BaseModel
56
from redis import Redis
67
from redis.asyncio import Redis as AsyncRedis
78
from redis.commands.search.indexDefinition import IndexType
89

910
from redisvl.utils.utils import convert_bytes
1011

1112

12-
class BaseStorage:
13-
type: IndexType
14-
DEFAULT_BATCH_SIZE: int = 200
15-
DEFAULT_WRITE_CONCURRENCY: int = 20
13+
class BaseStorage(BaseModel):
14+
"""
15+
Base class for internal storage handling in Redis.
1616
17-
def __init__(self, prefix: str, key_separator: str):
18-
"""Initialize the BaseStorage with a specific prefix and key separator
19-
for Redis keys.
17+
Provides foundational methods for key management, data preprocessing,
18+
validation, and basic read/write operations (both sync and async).
19+
"""
2020

21-
Args:
22-
prefix (str): The prefix to prepend to each Redis key.
23-
key_separator (str): The separator to use between the prefix and
24-
the key value.
25-
"""
26-
self._prefix = prefix
27-
self._key_separator = key_separator
21+
type: IndexType # Type of index used in storage
22+
prefix: str # Prefix for Redis keys
23+
key_separator: str # Separator between prefix and key value
24+
default_batch_size: int = 200 # Default size for batch operations
25+
default_write_concurrency: int = 20 # Default concurrency for async ops
2826

2927
@staticmethod
3028
def _key(key_value: str, prefix: str, key_separator: str) -> str:
@@ -69,7 +67,7 @@ def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> s
6967
raise ValueError(f"Key field {key_field} not found in record {obj}")
7068

7169
return self._key(
72-
key_value, prefix=self._prefix, key_separator=self._key_separator
70+
key_value, prefix=self.prefix, key_separator=self.key_separator
7371
)
7472

7573
@staticmethod
@@ -202,7 +200,7 @@ def write(
202200

203201
if batch_size is None:
204202
# Use default or calculate based on the input data
205-
batch_size = self.DEFAULT_BATCH_SIZE
203+
batch_size = self.default_batch_size
206204

207205
keys_iterator = iter(keys) if keys else None
208206
added_keys: List[str] = []
@@ -272,7 +270,7 @@ async def awrite(
272270
raise ValueError("Length of keys does not match the length of objects")
273271

274272
if not concurrency:
275-
concurrency = self.DEFAULT_WRITE_CONCURRENCY
273+
concurrency = self.default_write_concurrency
276274

277275
semaphore = asyncio.Semaphore(concurrency)
278276
keys_iterator = iter(keys) if keys else None
@@ -322,7 +320,7 @@ def get(
322320

323321
if batch_size is None:
324322
batch_size = (
325-
self.DEFAULT_BATCH_SIZE
323+
self.default_batch_size
326324
) # Use default or calculate based on the input data
327325

328326
# Use a pipeline to batch the retrieval
@@ -363,7 +361,7 @@ async def aget(
363361
return []
364362

365363
if not concurrency:
366-
concurrency = self.DEFAULT_WRITE_CONCURRENCY
364+
concurrency = self.default_write_concurrency
367365

368366
semaphore = asyncio.Semaphore(concurrency)
369367

@@ -378,6 +376,13 @@ async def _get(key: str) -> Dict[str, Any]:
378376

379377

380378
class HashStorage(BaseStorage):
379+
"""
380+
Internal subclass of BaseStorage for the Redis hash data type.
381+
382+
Implements hash-specific logic for validation and read/write operations
383+
(both sync and async) in Redis.
384+
"""
385+
381386
type: IndexType = IndexType.HASH
382387

383388
def _validate(self, obj: Dict[str, Any]):
@@ -443,6 +448,13 @@ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
443448

444449

445450
class JsonStorage(BaseStorage):
451+
"""
452+
Internal subclass of BaseStorage for the Redis JSON data type.
453+
454+
Implements json-specific logic for validation and read/write operations
455+
(both sync and async) in Redis.
456+
"""
457+
446458
type: IndexType = IndexType.JSON
447459

448460
def _validate(self, obj: Dict[str, Any]):

redisvl/vectorize/base.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
1-
from typing import Callable, List, Optional
2-
3-
from redisvl.utils.utils import array_to_buffer
1+
from typing import Any, Callable, List, Optional
42

3+
from pydantic.v1 import BaseModel, validator
54

6-
class BaseVectorizer:
7-
_dims = None
8-
9-
def __init__(self, model: str):
10-
self._model = model
5+
from redisvl.utils.utils import array_to_buffer
116

12-
@property
13-
def model(self) -> str:
14-
return self._model
157

16-
@property
17-
def dims(self) -> Optional[int]:
18-
return self._dims
8+
class BaseVectorizer(BaseModel):
9+
model: str
10+
dims: int
11+
client: Any
1912

20-
def set_model(self, model: str, dims: Optional[int] = None) -> None:
21-
self._model = model
22-
if dims is not None:
23-
self._dims = dims
13+
@validator("dims", pre=True)
14+
@classmethod
15+
def check_dims(cls, v):
16+
if v <= 0:
17+
raise ValueError("Dimension must be a positive integer")
18+
return v
2419

2520
def embed_many(
2621
self,

redisvl/vectorize/text/cohere.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(
6060
ValueError: If the API key is not provided.
6161
6262
"""
63-
super().__init__(model)
6463
# Dynamic import of the cohere module
6564
try:
6665
import cohere
@@ -80,15 +79,16 @@ def __init__(
8079
"Provide it in api_config or set the COHERE_API_KEY environment variable."
8180
)
8281

83-
self._model = model
84-
self._model_client = cohere.Client(api_key)
85-
self._dims = self._set_model_dims()
82+
client = cohere.Client(api_key)
83+
dims = self._set_model_dims(client, model)
84+
super().__init__(model=model, dims=dims, client=client)
8685

87-
def _set_model_dims(self) -> int:
86+
@staticmethod
87+
def _set_model_dims(client, model) -> int:
8888
try:
89-
embedding = self._model_client.embed(
89+
embedding = client.embed(
9090
texts=["dimension test"],
91-
model=self._model,
91+
model=model,
9292
input_type="search_document",
9393
).embeddings[0]
9494
except (KeyError, IndexError) as ke:
@@ -150,8 +150,8 @@ def embed(
150150
)
151151
if preprocess:
152152
text = preprocess(text)
153-
embedding = self._model_client.embed(
154-
texts=[text], model=self._model, input_type=input_type
153+
embedding = self.client.embed(
154+
texts=[text], model=self.model, input_type=input_type
155155
).embeddings[0]
156156
return self._process_embedding(embedding, as_buffer)
157157

@@ -219,8 +219,8 @@ def embed_many(
219219

220220
embeddings: List = []
221221
for batch in self.batchify(texts, batch_size, preprocess):
222-
response = self._model_client.embed(
223-
texts=batch, model=self._model, input_type=input_type
222+
response = self.client.embed(
223+
texts=batch, model=self.model, input_type=input_type
224224
)
225225
embeddings += [
226226
self._process_embedding(embedding, as_buffer)

redisvl/vectorize/text/huggingface.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ def __init__(
4242
ImportError: If the sentence-transformers library is not installed.
4343
ValueError: If there is an error setting the embedding model dimensions.
4444
"""
45-
super().__init__(model)
46-
4745
# Load the SentenceTransformer model
4846
try:
4947
from sentence_transformers import SentenceTransformer
@@ -53,16 +51,19 @@ def __init__(
5351
"Please install with `pip install sentence-transformers`"
5452
)
5553

56-
self._model_client = SentenceTransformer(model)
54+
client = SentenceTransformer(model)
55+
dims = self._set_model_dims(client)
56+
super().__init__(model=model, dims=dims, client=client)
5757

58-
# Initialize model dimensions
58+
@staticmethod
59+
def _set_model_dims(client):
5960
try:
60-
self._dims = self._set_model_dims()
61-
except Exception as e:
62-
raise ValueError(f"Error setting embedding model dimensions: {e}")
63-
64-
def _set_model_dims(self):
65-
embedding = self._model_client.encode(["dimension check"])[0]
61+
embedding = client.encode(["dimension check"])[0]
62+
except (KeyError, IndexError) as ke:
63+
raise ValueError(f"Empty response from the embedding model: {str(ke)}")
64+
except Exception as e: # pylint: disable=broad-except
65+
# fall back (TODO get more specific)
66+
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
6667
return len(embedding)
6768

6869
def embed(
@@ -92,7 +93,7 @@ def embed(
9293

9394
if preprocess:
9495
text = preprocess(text)
95-
embedding = self._model_client.encode([text])[0]
96+
embedding = self.client.encode([text])[0]
9697
return self._process_embedding(embedding.tolist(), as_buffer)
9798

9899
def embed_many(
@@ -128,7 +129,7 @@ def embed_many(
128129

129130
embeddings: List = []
130131
for batch in self.batchify(texts, batch_size, preprocess):
131-
batch_embeddings = self._model_client.encode(batch)
132+
batch_embeddings = self.client.encode(batch)
132133
embeddings.extend(
133134
[
134135
self._process_embedding(embedding.tolist(), as_buffer)

0 commit comments

Comments
 (0)