Skip to content

Commit d777138

Browse files
authored
Extension of embeddings draft implementation to support local models (#3463)
1 parent 9ffddf8 commit d777138

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

pydantic_ai_slim/pydantic_ai/embeddings/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def infer_model(
7474
from .cohere import CohereEmbeddingModel
7575

7676
return CohereEmbeddingModel(model_name, provider=provider)
77+
elif model_kind == 'sentence-transformers':
78+
from .sentence_transformers import SentenceTransformerEmbeddingModel
79+
80+
return SentenceTransformerEmbeddingModel(model_name)
7781
else:
7882
raise UserError(f'Unknown embeddings model: {model}') # pragma: no cover
7983

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, overload

from pydantic_ai.embeddings.base import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings

# Optional dependency: sentence-transformers.
# Fail fast at import time with an actionable install hint. (The previous
# `SentenceTransformer = None` placeholder binding was dead code: the `raise`
# immediately after it always fired, so the name was never observed as None.)
try:  # pragma: no cover - depends on optional install
    from sentence_transformers import SentenceTransformer
except ImportError as _import_error:  # pragma: no cover - depends on optional install
    raise ImportError(
        'Please install `sentence-transformers` to use the Sentence-Transformers embeddings model, '
        'you can use the `sentence-transformers` optional group — '
        'pip install "pydantic-ai-slim[sentence-transformers]"'
    ) from _import_error
21+
22+
class SentenceTransformersEmbeddingSettings(EmbeddingSettings, total=False):
    """Per-request settings for a Sentence-Transformers embedding model.

    Every key carries the `sentence_transformers_` prefix so that settings
    dictionaries can be merged across providers without collisions.
    """

    # Encoding batch size forwarded to `SentenceTransformer.encode`.
    sentence_transformers_batch_size: int

    # Target inference device, e.g. "cpu", "cuda", "cuda:0", "mps".
    sentence_transformers_device: str

    # If true, L2-normalize the returned vectors
    # (maps to `normalize_embeddings` in `SentenceTransformer.encode`).
    sentence_transformers_normalize_embeddings: bool
38+
@dataclass(init=False)
class SentenceTransformerEmbeddingModel(EmbeddingModel):
    """Local embeddings using `sentence-transformers` models.

    Example models include "all-MiniLM-L6-v2" and many others hosted on Hugging Face.
    """

    # Model identifier as passed by the caller (name or local path).
    _model_name: str = field(repr=False)
    # The loaded `SentenceTransformer` instance.
    _model: Any = field(repr=False)

    def __init__(self, model_name: str, *, settings: EmbeddingSettings | None = None) -> None:
        """Initialize a Sentence-Transformers embedding model.

        Args:
            model_name: The model name or local path to load with `SentenceTransformer`.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name
        # The module-level import guard already raised an actionable ImportError when
        # `sentence-transformers` is missing, so `SentenceTransformer` is guaranteed
        # to be bound here — no re-check needed.
        # Device selection is deferred to encode() where settings can override it.
        self._model = SentenceTransformer(model_name)

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str | None:
        """No base URL — runs locally."""
        return None

    @property
    def model_name(self) -> str:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider/system identifier."""
        return 'sentence-transformers'

    @overload
    async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]: ...

    @overload
    async def embed(
        self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[list[float]]: ...

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one document (returning a single vector) or many (one vector per document)."""
        docs, is_single_document, settings = self.prepare_embed(documents, settings)
        embeddings = await self._embed(docs, settings)
        return embeddings[0] if is_single_document else embeddings

    async def _embed(
        self, documents: Sequence[str], settings: SentenceTransformersEmbeddingSettings
    ) -> list[list[float]]:
        """Encode `documents` into embedding vectors, honoring per-request settings.

        `SentenceTransformer.encode` is synchronous and potentially CPU/GPU heavy, so
        it is offloaded to a worker thread to avoid blocking the event loop.
        """
        import asyncio  # local import: only needed to offload the blocking encode call

        device = settings.get('sentence_transformers_device', None)
        normalize = settings.get('sentence_transformers_normalize_embeddings', False)
        batch_size = settings.get('sentence_transformers_batch_size')

        encode_kwargs: dict[str, Any] = {
            'show_progress_bar': False,
            'convert_to_numpy': True,
            'convert_to_tensor': False,
            'device': device,
            'normalize_embeddings': normalize,
        }
        # Only forward batch_size when explicitly set so the library default applies otherwise.
        if batch_size is not None:
            encode_kwargs['batch_size'] = batch_size

        np_embeddings = await asyncio.to_thread(self._model.encode, list(documents), **encode_kwargs)
        return np_embeddings.tolist()

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ groq = ["groq>=0.25.0"]
7676
mistral = ["mistralai>=1.9.10"]
7777
bedrock = ["boto3>=1.40.14"]
7878
huggingface = ["huggingface-hub[inference]>=0.33.5"]
79+
sentence-transformers = ["sentence-transformers"]
7980
outlines-transformers = ["outlines[transformers]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')", "transformers>=4.0.0", "pillow", "torch; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
8081
outlines-llamacpp = ["outlines[llamacpp]>=1.0.0, <1.3.0"]
8182
outlines-mlxlm = ["outlines[mlxlm]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]

0 commit comments

Comments
 (0)