|
from __future__ import annotations

import asyncio
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, overload

from pydantic_ai.embeddings.base import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings
| 9 | + |
# Optional dependency: sentence-transformers.
# We raise eagerly at import time so a missing install fails fast with an actionable
# message; the dead `SentenceTransformer = None` fallback was removed because the
# `raise` on the next statement made it unreachable for any importer.
try:  # pragma: no cover - depends on optional install
    from sentence_transformers import SentenceTransformer
except ImportError as _import_error:  # pragma: no cover - depends on optional install
    raise ImportError(
        'Please install `sentence-transformers` to use the Sentence-Transformers embeddings model, '
        'you can use the `sentence-transformers` optional group — '
        'pip install "pydantic-ai-slim[sentence-transformers]"'
    ) from _import_error
| 20 | + |
| 21 | + |
class SentenceTransformersEmbeddingSettings(EmbeddingSettings, total=False):
    """Settings used for a Sentence-Transformers embedding model request.

    All fields are `sentence_transformers_`-prefixed so settings can be merged across providers safely.
    """

    # Target inference device, e.g. "cpu", "cuda", "cuda:0", "mps"; forwarded to
    # `SentenceTransformer.encode(device=...)`.
    sentence_transformers_device: str

    # When true, L2-normalize the returned vectors (maps onto the
    # `normalize_embeddings` flag of `SentenceTransformer.encode`).
    sentence_transformers_normalize_embeddings: bool

    # Number of documents encoded per forward pass.
    sentence_transformers_batch_size: int
| 36 | + |
| 37 | + |
@dataclass(init=False)
class SentenceTransformerEmbeddingModel(EmbeddingModel):
    """Local embeddings using `sentence-transformers` models.

    Example models include "all-MiniLM-L6-v2" and many others hosted on Hugging Face.
    Encoding runs locally (no network calls at request time), so `base_url` is `None`.
    """

    # Model name or local path, kept for the `model_name` property.
    _model_name: str = field(repr=False)
    # The loaded `SentenceTransformer` instance (typed `Any` to avoid a hard type dependency).
    _model: Any = field(repr=False)

    def __init__(self, model_name: str, *, settings: EmbeddingSettings | None = None) -> None:
        """Initialize a Sentence-Transformers embedding model.

        Args:
            model_name: The model name or local path to load with `SentenceTransformer`.
            settings: Model-specific settings that will be used as defaults for this model.

        Raises:
            Whatever `SentenceTransformer` raises for an unknown model name or path.
        """
        self._model_name = model_name
        # No `SentenceTransformer is None` guard is needed: a missing install already
        # raised ImportError at module import time, so this line is only reachable when
        # the dependency is present.
        # Device selection is deferred to encode(), where settings can override it.
        self._model = SentenceTransformer(model_name)

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str | None:
        """No base URL — runs locally."""
        return None

    @property
    def model_name(self) -> str:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider/system identifier."""
        return 'sentence-transformers'

    @overload
    async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]: ...

    @overload
    async def embed(
        self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[list[float]]: ...

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one document (returns a vector) or many (returns one vector per document)."""
        docs, is_single_document, settings = self.prepare_embed(documents, settings)
        embeddings = await self._embed(docs, settings)
        return embeddings[0] if is_single_document else embeddings

    async def _embed(
        self, documents: Sequence[str], settings: SentenceTransformersEmbeddingSettings
    ) -> list[list[float]]:
        """Encode `documents`, honoring device/normalization/batch-size settings.

        `SentenceTransformer.encode` is a blocking, compute-bound call; it is run in a
        worker thread via `asyncio.to_thread` so the event loop stays responsive while
        encoding.
        """
        encode_kwargs: dict[str, Any] = {
            'show_progress_bar': False,
            # Request a numpy array so `.tolist()` below yields plain Python floats.
            'convert_to_numpy': True,
            'convert_to_tensor': False,
            # `device=None` lets sentence-transformers keep its default placement.
            'device': settings.get('sentence_transformers_device'),
            'normalize_embeddings': settings.get('sentence_transformers_normalize_embeddings', False),
        }
        # Only forward batch_size when explicitly configured, keeping the library default otherwise.
        batch_size = settings.get('sentence_transformers_batch_size')
        if batch_size is not None:
            encode_kwargs['batch_size'] = batch_size

        # Offload the blocking encode call; `asyncio.to_thread` forwards kwargs to the callable.
        np_embeddings = await asyncio.to_thread(self._model.encode, list(documents), **encode_kwargs)
        return np_embeddings.tolist()
0 commit comments