From 4bd00fb3530a4696f4844d62a2f0e1bd0ece7286 Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Fri, 24 Apr 2026 12:59:16 -0700 Subject: [PATCH] Python: Add Mistral AI embedding client package Signed-off-by: Daria Korenieva --- .../workflows/python-integration-tests.yml | 5 +- .github/workflows/python-merge-tests.yml | 6 +- python/.env.example | 3 + python/PACKAGE_STATUS.md | 1 + python/packages/mistral/AGENTS.md | 25 ++ python/packages/mistral/LICENSE | 21 ++ python/packages/mistral/README.md | 42 +++ .../agent_framework_mistral/__init__.py | 17 ++ .../_embedding_client.py | 249 ++++++++++++++++ .../mistral/agent_framework_mistral/py.typed | 1 + python/packages/mistral/pyproject.toml | 105 +++++++ .../mistral/test_mistral_embedding_client.py | 267 ++++++++++++++++++ python/pyproject.toml | 5 + python/uv.lock | 56 +++- 14 files changed, 800 insertions(+), 3 deletions(-) create mode 100644 python/packages/mistral/AGENTS.md create mode 100644 python/packages/mistral/LICENSE create mode 100644 python/packages/mistral/README.md create mode 100644 python/packages/mistral/agent_framework_mistral/__init__.py create mode 100644 python/packages/mistral/agent_framework_mistral/_embedding_client.py create mode 100644 python/packages/mistral/agent_framework_mistral/py.typed create mode 100644 python/packages/mistral/pyproject.toml create mode 100644 python/packages/mistral/tests/mistral/test_mistral_embedding_client.py diff --git a/.github/workflows/python-integration-tests.yml b/.github/workflows/python-integration-tests.yml index f2fb5c6448..eadfecfb20 100644 --- a/.github/workflows/python-integration-tests.yml +++ b/.github/workflows/python-integration-tests.yml @@ -140,6 +140,8 @@ jobs: env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} ANTHROPIC_CHAT_MODEL: ${{ vars.ANTHROPIC_CHAT_MODEL_ID }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + MISTRAL_EMBEDDING_MODEL: ${{ vars.MISTRAL_EMBEDDING_MODEL_ID }} LOCAL_MCP_URL: ${{ vars.LOCAL_MCP__URL }} defaults: run: 
@@ -162,11 +164,12 @@ jobs: fallback_url: ${{ env.LOCAL_MCP_URL }} - name: Prefer local MCP URL when available run: echo "LOCAL_MCP_URL=${{ steps.local-mcp.outputs.effective_url }}" >> "$GITHUB_ENV" - - name: Test with pytest (Anthropic, Hyperlight, Ollama, MCP integration) + - name: Test with pytest (Anthropic, Hyperlight, Mistral, Ollama, MCP integration) run: > uv run pytest --import-mode=importlib packages/anthropic/tests packages/hyperlight/tests + packages/mistral/tests packages/ollama/tests packages/core/tests/core/test_mcp.py -m integration diff --git a/.github/workflows/python-merge-tests.yml b/.github/workflows/python-merge-tests.yml index dd48b268df..6a82ec88fb 100644 --- a/.github/workflows/python-merge-tests.yml +++ b/.github/workflows/python-merge-tests.yml @@ -66,6 +66,7 @@ jobs: misc: - 'python/packages/anthropic/**' - 'python/packages/hyperlight/**' + - 'python/packages/mistral/**' - 'python/packages/ollama/**' - 'python/packages/core/agent_framework/_mcp.py' - 'python/packages/core/tests/core/test_mcp.py' @@ -260,6 +261,8 @@ jobs: env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} ANTHROPIC_CHAT_MODEL: ${{ vars.ANTHROPIC_CHAT_MODEL_ID }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + MISTRAL_EMBEDDING_MODEL: ${{ vars.MISTRAL_EMBEDDING_MODEL_ID }} LOCAL_MCP_URL: ${{ vars.LOCAL_MCP__URL }} defaults: run: @@ -279,11 +282,12 @@ jobs: fallback_url: ${{ env.LOCAL_MCP_URL }} - name: Prefer local MCP URL when available run: echo "LOCAL_MCP_URL=${{ steps.local-mcp.outputs.effective_url }}" >> "$GITHUB_ENV" - - name: Test with pytest (Anthropic, Hyperlight, Ollama, MCP integration) + - name: Test with pytest (Anthropic, Hyperlight, Mistral, Ollama, MCP integration) run: > uv run pytest --import-mode=importlib packages/anthropic/tests packages/hyperlight/tests + packages/mistral/tests packages/ollama/tests packages/core/tests/core/test_mcp.py -m integration diff --git a/python/.env.example b/python/.env.example index bff78961aa..51918f6694 100644 
--- a/python/.env.example +++ b/python/.env.example @@ -44,6 +44,9 @@ GEMINI_MODEL="" # Ollama OLLAMA_ENDPOINT="" OLLAMA_MODEL="" +# Mistral AI +MISTRAL_API_KEY="" +MISTRAL_EMBEDDING_MODEL="" # Observability ENABLE_INSTRUMENTATION=true ENABLE_SENSITIVE_DATA=true diff --git a/python/PACKAGE_STATUS.md b/python/PACKAGE_STATUS.md index 661cebe53a..5093059695 100644 --- a/python/PACKAGE_STATUS.md +++ b/python/PACKAGE_STATUS.md @@ -36,6 +36,7 @@ Status is grouped into these buckets: | `agent-framework-hyperlight` | `python/packages/hyperlight` | `alpha` | | `agent-framework-lab` | `python/packages/lab` | `beta` | | `agent-framework-mem0` | `python/packages/mem0` | `beta` | +| `agent-framework-mistral` | `python/packages/mistral` | `alpha` | | `agent-framework-ollama` | `python/packages/ollama` | `beta` | | `agent-framework-openai` | `python/packages/openai` | `released` | | `agent-framework-orchestrations` | `python/packages/orchestrations` | `beta` | diff --git a/python/packages/mistral/AGENTS.md b/python/packages/mistral/AGENTS.md new file mode 100644 index 0000000000..868c981a80 --- /dev/null +++ b/python/packages/mistral/AGENTS.md @@ -0,0 +1,25 @@ +# Mistral Package (agent-framework-mistral) + +Integration with Mistral AI for embedding generation. 
+ +## Main Classes + +- **`MistralEmbeddingClient`** - Embedding client for Mistral AI models +- **`MistralEmbeddingOptions`** - Options TypedDict for Mistral-specific embedding parameters +- **`MistralEmbeddingSettings`** - TypedDict settings for Mistral configuration + +## Usage + +```python +from agent_framework_mistral import MistralEmbeddingClient + +client = MistralEmbeddingClient(model="mistral-embed") +result = await client.get_embeddings(["Hello, world!"]) +print(result[0].vector) +``` + +## Import Path + +```python +from agent_framework_mistral import MistralEmbeddingClient +``` diff --git a/python/packages/mistral/LICENSE b/python/packages/mistral/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/mistral/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/mistral/README.md b/python/packages/mistral/README.md new file mode 100644 index 0000000000..a3b3d567fa --- /dev/null +++ b/python/packages/mistral/README.md @@ -0,0 +1,42 @@ +# Get Started with Microsoft Agent Framework Mistral AI + +Please install this package as the extra for `agent-framework`: + +```bash +pip install agent-framework-mistral --pre +``` + +and see the [README](https://github.com/microsoft/agent-framework/tree/main/python/README.md) for more information. + +## Embedding Client + +The `MistralEmbeddingClient` provides embedding generation using Mistral AI models. + +### Quick Start + +```python +from agent_framework_mistral import MistralEmbeddingClient + +# Using environment variables (MISTRAL_API_KEY, MISTRAL_EMBEDDING_MODEL) +client = MistralEmbeddingClient() + +# Or passing parameters directly +client = MistralEmbeddingClient( + model="mistral-embed", + api_key="your-api-key", +) + +# Generate embeddings +result = await client.get_embeddings(["Hello, world!", "How are you?"]) +for embedding in result: + print(f"Dimensions: {embedding.dimensions}") + print(f"Vector: {embedding.vector[:5]}...") +``` + +### Configuration + +| Environment Variable | Description | +|---|---| +| `MISTRAL_API_KEY` | Your Mistral AI API key | +| `MISTRAL_EMBEDDING_MODEL` | Embedding model name (e.g., `mistral-embed`) | +| `MISTRAL_SERVER_URL` | Optional server URL override | diff --git a/python/packages/mistral/agent_framework_mistral/__init__.py b/python/packages/mistral/agent_framework_mistral/__init__.py new file mode 100644 index 0000000000..58d4677a82 --- /dev/null +++ b/python/packages/mistral/agent_framework_mistral/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft. 
All rights reserved. + +import importlib.metadata + +from ._embedding_client import MistralEmbeddingClient, MistralEmbeddingOptions, MistralEmbeddingSettings + +try: + __version__ = importlib.metadata.version(__name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" # Fallback for development mode + +__all__ = [ + "MistralEmbeddingClient", + "MistralEmbeddingOptions", + "MistralEmbeddingSettings", + "__version__", +] diff --git a/python/packages/mistral/agent_framework_mistral/_embedding_client.py b/python/packages/mistral/agent_framework_mistral/_embedding_client.py new file mode 100644 index 0000000000..74e175ee2d --- /dev/null +++ b/python/packages/mistral/agent_framework_mistral/_embedding_client.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import logging +import sys +from collections.abc import Sequence +from typing import Any, ClassVar, Generic, TypedDict + +from agent_framework import ( + BaseEmbeddingClient, + Embedding, + EmbeddingGenerationOptions, + GeneratedEmbeddings, + UsageDetails, + load_settings, +) +from agent_framework._settings import SecretString +from agent_framework.observability import EmbeddingTelemetryLayer +from mistralai import Mistral + +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + + +logger = logging.getLogger("agent_framework.mistral") + + +class MistralEmbeddingOptions(EmbeddingGenerationOptions, total=False): + """Mistral AI-specific embedding options. + + Extends EmbeddingGenerationOptions with Mistral-specific fields. + + Examples: + ..
code-block:: python + + from agent_framework_mistral import MistralEmbeddingOptions + + options: MistralEmbeddingOptions = { + "model": "mistral-embed", + "dimensions": 1024, + } + """ + + +MistralEmbeddingOptionsT = TypeVar( + "MistralEmbeddingOptionsT", + bound=TypedDict, # type: ignore[valid-type] + default="MistralEmbeddingOptions", + covariant=True, +) + + +class MistralEmbeddingSettings(TypedDict, total=False): + """Mistral AI embedding settings. + + Fields: + api_key: Mistral API key. Resolved from ``MISTRAL_API_KEY``. + embedding_model: Embedding model name. Resolved from ``MISTRAL_EMBEDDING_MODEL``. + server_url: Optional server URL override. Resolved from ``MISTRAL_SERVER_URL``. + """ + + api_key: str | None + embedding_model: str | None + server_url: str | None + + +class RawMistralEmbeddingClient( + BaseEmbeddingClient[str, list[float], MistralEmbeddingOptionsT], + Generic[MistralEmbeddingOptionsT], +): + """Raw Mistral AI embedding client without telemetry. + + Keyword Args: + model: The Mistral embedding model (e.g. "mistral-embed"). + Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``. + api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable. + server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL`` + environment variable, or the Mistral default. + client: Optional pre-configured ``Mistral`` client instance. + additional_properties: Additional properties stored on the client instance. + env_file_path: Path to ``.env`` file for settings. + env_file_encoding: Encoding for ``.env`` file. 
+ """ + + INJECTABLE: ClassVar[set[str]] = {"client"} + + def __init__( + self, + *, + model: str | None = None, + api_key: str | SecretString | None = None, + server_url: str | None = None, + client: Mistral | None = None, + additional_properties: dict[str, Any] | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize a raw Mistral AI embedding client.""" + mistral_settings = load_settings( + MistralEmbeddingSettings, + env_prefix="MISTRAL_", + required_fields=["embedding_model", "api_key"], + api_key=str(api_key) if isinstance(api_key, SecretString) else api_key, + embedding_model=model, + server_url=server_url, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + + self.model: str = mistral_settings["embedding_model"] # type: ignore[assignment] + resolved_api_key: str = mistral_settings["api_key"] # type: ignore[assignment] + resolved_server_url = mistral_settings.get("server_url") + + if client is not None: + self.client = client + else: + client_kwargs: dict[str, Any] = {"api_key": resolved_api_key} + if resolved_server_url: + client_kwargs["server_url"] = resolved_server_url + self.client = Mistral(**client_kwargs) + + self.server_url = resolved_server_url + super().__init__(additional_properties=additional_properties) + + def service_url(self) -> str: + """Get the URL of the service.""" + return self.server_url or "https://api.mistral.ai" + + async def get_embeddings( + self, + values: Sequence[str], + *, + options: MistralEmbeddingOptionsT | None = None, + ) -> GeneratedEmbeddings[list[float], MistralEmbeddingOptionsT]: + """Call the Mistral AI embeddings API. + + Args: + values: The text values to generate embeddings for. + options: Optional embedding generation options. + + Returns: + Generated embeddings with usage metadata. + + Raises: + ValueError: If model is not provided or values is empty. 
+ """ + if not values: + return GeneratedEmbeddings([], options=options) + + opts: dict[str, Any] = options or {} # type: ignore + model = opts.get("model") or self.model + if not model: + raise ValueError("model is required") + + kwargs: dict[str, Any] = {"model": model, "inputs": list(values)} + if dimensions := opts.get("dimensions"): + kwargs["output_dimension"] = dimensions + + response = await self.client.embeddings.create_async(**kwargs) + + embeddings: list[Embedding[list[float]]] = [] + if response and response.data: + for item in response.data: + vector = list(item.embedding) if item.embedding else [] + embeddings.append( + Embedding( + vector=vector, + dimensions=len(vector), + model=response.model or model, + ) + ) + + usage_dict: UsageDetails | None = None + if response and response.usage: + usage_dict = { + "input_token_count": response.usage.prompt_tokens, + "total_token_count": response.usage.total_tokens, + } + + return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) + + +class MistralEmbeddingClient( + EmbeddingTelemetryLayer[str, list[float], MistralEmbeddingOptionsT], + RawMistralEmbeddingClient[MistralEmbeddingOptionsT], + Generic[MistralEmbeddingOptionsT], +): + """Mistral AI embedding client with telemetry support. + + Keyword Args: + model: The Mistral embedding model (e.g. "mistral-embed"). + Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``. + api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable. + server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL`` + environment variable, or the Mistral default. + client: Optional pre-configured ``Mistral`` client instance. + otel_provider_name: Optional telemetry provider name override. + env_file_path: Path to ``.env`` file for settings. + env_file_encoding: Encoding for ``.env`` file. + + Examples: + .. 
code-block:: python + + from agent_framework_mistral import MistralEmbeddingClient + + # Using environment variables + # Set MISTRAL_API_KEY=your-key + # Set MISTRAL_EMBEDDING_MODEL=mistral-embed + client = MistralEmbeddingClient() + + # Or passing parameters directly + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="your-api-key", + ) + + # Generate embeddings + result = await client.get_embeddings(["Hello, world!"]) + print(result[0].vector) + """ + + OTEL_PROVIDER_NAME: ClassVar[str] = "mistralai" + + def __init__( + self, + *, + model: str | None = None, + api_key: str | SecretString | None = None, + server_url: str | None = None, + client: Mistral | None = None, + otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize a Mistral AI embedding client.""" + super().__init__( + model=model, + api_key=api_key, + server_url=server_url, + client=client, + additional_properties=additional_properties, + otel_provider_name=otel_provider_name, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) diff --git a/python/packages/mistral/agent_framework_mistral/py.typed b/python/packages/mistral/agent_framework_mistral/py.typed new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/python/packages/mistral/agent_framework_mistral/py.typed @@ -0,0 +1 @@ + diff --git a/python/packages/mistral/pyproject.toml b/python/packages/mistral/pyproject.toml new file mode 100644 index 0000000000..59c3b7fc16 --- /dev/null +++ b/python/packages/mistral/pyproject.toml @@ -0,0 +1,105 @@ +[project] +name = "agent-framework-mistral" +description = "Mistral AI integration for Microsoft Agent Framework." 
+authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "1.0.0b260421" +license-files = ["LICENSE"] +urls.homepage = "https://learn.microsoft.com/en-us/agent-framework/" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Framework :: Pydantic :: 2", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.1.0,<2", + "mistralai>=2.0.0,<3", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [] +markers = [ + "integration: marks tests as integration tests that require external services", +] +timeout = 120 + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W"] + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" +include = ["agent_framework_mistral"] +exclude = ['tests'] + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = 
true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_any_unimported = true + +[tool.bandit] +targets = ["agent_framework_mistral"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" + +[tool.poe.tasks.mypy] +help = "Run MyPy for this package." +cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mistral" + +[tool.poe.tasks.test] +help = "Run the default unit test suite for this package." +cmd = 'pytest -m "not integration" --cov=agent_framework_mistral --cov-report=term-missing:skip-covered tests' + +[tool.uv.build-backend] +module-name = "agent_framework_mistral" +module-root = "" + +[build-system] +requires = ["uv_build>=0.8.2,<0.9.0"] +build-backend = "uv_build" diff --git a/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py b/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py new file mode 100644 index 0000000000..cecd03b3e2 --- /dev/null +++ b/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import Embedding, GeneratedEmbeddings + +from agent_framework_mistral import MistralEmbeddingClient, MistralEmbeddingOptions + +# region: Unit Tests + + +def test_mistral_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None: + """Test construction with environment variables.""" + monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed") + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_cls.return_value = MagicMock() + client = MistralEmbeddingClient() + assert client.model == "mistral-embed" + + +def test_mistral_embedding_construction_with_params() -> None: + """Test construction with explicit parameters.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_cls.return_value = MagicMock() + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="test-key", + ) + assert client.model == "mistral-embed" + mock_cls.assert_called_once_with(api_key="test-key") + + +def test_mistral_embedding_construction_with_server_url() -> None: + """Test construction with custom server URL.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_cls.return_value = MagicMock() + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="test-key", + server_url="https://custom.mistral.ai", + ) + assert client.model == "mistral-embed" + assert client.server_url == "https://custom.mistral.ai" + mock_cls.assert_called_once_with( + api_key="test-key", + server_url="https://custom.mistral.ai", + ) + + +def test_mistral_embedding_construction_with_client() -> None: + """Test construction with a pre-configured client.""" + mock_client = MagicMock() + with patch("agent_framework_mistral._embedding_client.Mistral"): + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="test-key", + 
client=mock_client, + ) + assert client.client is mock_client + + +def test_mistral_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that missing model raises an error.""" + monkeypatch.delenv("MISTRAL_EMBEDDING_MODEL", raising=False) + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + from agent_framework.exceptions import SettingNotFoundError + + with pytest.raises(SettingNotFoundError): + MistralEmbeddingClient() + + +def test_mistral_embedding_construction_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that missing API key raises an error.""" + monkeypatch.delenv("MISTRAL_API_KEY", raising=False) + monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed") + from agent_framework.exceptions import SettingNotFoundError + + with pytest.raises(SettingNotFoundError): + MistralEmbeddingClient() + + +def test_mistral_embedding_service_url() -> None: + """Test service_url returns the correct URL.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_cls.return_value = MagicMock() + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="test-key", + ) + assert client.service_url() == "https://api.mistral.ai" + + +def test_mistral_embedding_service_url_custom() -> None: + """Test service_url returns custom URL when set.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_cls.return_value = MagicMock() + client = MistralEmbeddingClient( + model="mistral-embed", + api_key="test-key", + server_url="https://custom.mistral.ai", + ) + assert client.service_url() == "https://custom.mistral.ai" + + +async def test_mistral_embedding_get_embeddings() -> None: + """Test generating embeddings via the Mistral API.""" + mock_response = MagicMock() + mock_response.data = [ + MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"), + MagicMock(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"), + ] + 
mock_response.model = "mistral-embed" + mock_response.usage = MagicMock(prompt_tokens=10, total_tokens=10) + + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_client.embeddings = MagicMock() + mock_client.embeddings.create_async = AsyncMock(return_value=mock_response) + mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + result = await client.get_embeddings(["hello", "world"]) + + assert isinstance(result, GeneratedEmbeddings) + assert len(result) == 2 + assert result[0].vector == [0.1, 0.2, 0.3] + assert result[1].vector == [0.4, 0.5, 0.6] + assert result[0].model == "mistral-embed" + assert result.usage == {"input_token_count": 10, "total_token_count": 10} + + mock_client.embeddings.create_async.assert_called_once_with( + model="mistral-embed", + inputs=["hello", "world"], + ) + + +async def test_mistral_embedding_get_embeddings_empty_input() -> None: + """Test generating embeddings with empty input.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + result = await client.get_embeddings([]) + + assert isinstance(result, GeneratedEmbeddings) + assert len(result) == 0 + + +async def test_mistral_embedding_get_embeddings_with_dimensions() -> None: + """Test generating embeddings with custom dimensions option.""" + mock_response = MagicMock() + mock_response.data = [ + MagicMock(embedding=[0.1, 0.2], index=0, object="embedding"), + ] + mock_response.model = "mistral-embed" + mock_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) + + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_client.embeddings = MagicMock() + mock_client.embeddings.create_async = AsyncMock(return_value=mock_response) + 
mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + options: MistralEmbeddingOptions = {"dimensions": 512} + result = await client.get_embeddings(["hello"], options=options) + + assert len(result) == 1 + mock_client.embeddings.create_async.assert_called_once_with( + model="mistral-embed", + inputs=["hello"], + output_dimension=512, + ) + + +async def test_mistral_embedding_get_embeddings_no_model_raises() -> None: + """Test that missing model at call time raises ValueError.""" + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + client.model = None # type: ignore[assignment] + + with pytest.raises(ValueError, match="model is required"): + await client.get_embeddings(["hello"]) + + +async def test_mistral_embedding_get_embeddings_model_override() -> None: + """Test that model can be overridden via options.""" + mock_response = MagicMock() + mock_response.data = [ + MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"), + ] + mock_response.model = "custom-embed" + mock_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) + + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_client.embeddings = MagicMock() + mock_client.embeddings.create_async = AsyncMock(return_value=mock_response) + mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + options: MistralEmbeddingOptions = {"model": "custom-embed"} + result = await client.get_embeddings(["hello"], options=options) + + assert len(result) == 1 + assert result[0].model == "custom-embed" + mock_client.embeddings.create_async.assert_called_once_with( + model="custom-embed", + inputs=["hello"], + ) + + +async def 
test_mistral_embedding_get_embeddings_no_usage() -> None: + """Test handling response without usage information.""" + mock_response = MagicMock() + mock_response.data = [ + MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"), + ] + mock_response.model = "mistral-embed" + mock_response.usage = None + + with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls: + mock_client = MagicMock() + mock_client.embeddings = MagicMock() + mock_client.embeddings.create_async = AsyncMock(return_value=mock_response) + mock_cls.return_value = mock_client + + client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") + result = await client.get_embeddings(["hello"]) + + assert len(result) == 1 + assert result.usage is None + + +# region: Integration Tests + +skip_if_mistral_embedding_integration_tests_disabled = pytest.mark.skipif( + os.getenv("MISTRAL_EMBEDDING_MODEL", "") in ("", "test-model") or os.getenv("MISTRAL_API_KEY", "") == "", + reason="No real Mistral embedding model or API key provided; skipping integration tests.", +) + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_mistral_embedding_integration_tests_disabled +async def test_mistral_embedding_integration() -> None: + """Integration test for Mistral AI embedding client.""" + client = MistralEmbeddingClient() + result = await client.get_embeddings(["Hello, world!", "How are you?"]) + + assert isinstance(result, GeneratedEmbeddings) + assert len(result) == 2 + for embedding in result: + assert isinstance(embedding, Embedding) + assert isinstance(embedding.vector, list) + assert len(embedding.vector) > 0 + assert all(isinstance(v, float) for v in embedding.vector) + assert result.usage is not None + assert result.usage["input_token_count"] is not None + assert result.usage["input_token_count"] > 0 diff --git a/python/pyproject.toml b/python/pyproject.toml index 0d5dac2a0b..8b98cf695f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,6 +54,9 @@ 
package = false prerelease = "if-necessary-or-explicit" # Keep transitive litellm below the compromised 1.82.7/1.82.8 releases. constraint-dependencies = ["litellm<1.82.7"] +# Allow opentelemetry-semantic-conventions 0.61b0 for mistralai compatibility +# (mistralai pins <0.61 but 0.61b0 is compatible at runtime). +override-dependencies = ["opentelemetry-semantic-conventions>=0.60b1"] environments = [ "sys_platform == 'darwin'", "sys_platform == 'linux'", @@ -87,6 +90,7 @@ agent-framework-github-copilot = { workspace = true } agent-framework-hyperlight = { workspace = true } agent-framework-lab = { workspace = true } agent-framework-mem0 = { workspace = true } +agent-framework-mistral = { workspace = true } agent-framework-ollama = { workspace = true } agent-framework-openai = { workspace = true } agent-framework-orchestrations = { workspace = true } @@ -207,6 +211,7 @@ executionEnvironments = [ { root = "packages/lab/lightning/tests", reportPrivateUsage = "none" }, { root = "packages/lab/tau2/tests", reportPrivateUsage = "none" }, { root = "packages/mem0/tests", reportPrivateUsage = "none" }, + { root = "packages/mistral/tests", reportPrivateUsage = "none" }, { root = "packages/ollama/tests", reportPrivateUsage = "none" }, { root = "packages/orchestrations/tests", reportPrivateUsage = "none" }, { root = "packages/purview/tests", reportPrivateUsage = "none" }, diff --git a/python/uv.lock b/python/uv.lock index 73c18a6375..7701b10892 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -49,6 +49,7 @@ members = [ "agent-framework-hyperlight", "agent-framework-lab", "agent-framework-mem0", + "agent-framework-mistral", "agent-framework-ollama", "agent-framework-openai", "agent-framework-orchestrations", @@ -56,6 +57,7 @@ members = [ "agent-framework-redis", ] constraints = [{ name = "litellm", url = "https://files.pythonhosted.org/packages/57/77/0c6eca2cb049793ddf8ce9cdcd5123a35666c4962514788c4fc90edf1d3b/litellm-1.82.1-py3-none-any.whl" }] +overrides = [{ name = 
"opentelemetry-semantic-conventions", specifier = ">=0.60b1" }] [[package]] name = "a2a-sdk" @@ -565,7 +567,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "agent-framework-core", editable = "packages/core" }, - { name = "github-copilot-sdk", marker = "python_full_version >= '3.11'", specifier = ">=0.2.1,<=0.2.1" }, + { name = "github-copilot-sdk", marker = "python_full_version >= '3.11'", specifier = "<=0.2.1,>=0.2.1" }, ] [[package]] @@ -683,6 +685,21 @@ requires-dist = [ { name = "mem0ai", specifier = ">=1.0.0,<2" }, ] +[[package]] +name = "agent-framework-mistral" +version = "1.0.0b260421" +source = { editable = "packages/mistral" } +dependencies = [ + { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "mistralai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "mistralai", specifier = ">=2.0.0,<3" }, +] + [[package]] name = "agent-framework-ollama" version = "1.0.0b260421" @@ -2063,6 +2080,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, ] +[[package]] +name = "eval-type-backport" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/a3/cafafb4558fd638aadfe4121dc6cefb8d743368c085acb2f521df0f3d9d7/eval_type_backport-0.3.1.tar.gz", hash = "sha256:57e993f7b5b69d271e37482e62f74e76a0276c82490cf8e4f0dffeb6b332d5ed", size = 9445, upload-time = "2025-12-02T11:51:42.987Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cf/22/fdc2e30d43ff853720042fa15baa3e6122722be1a7950a98233ebb55cd71/eval_type_backport-0.3.1-py3-none-any.whl", hash = "sha256:279ab641905e9f11129f56a8a78f493518515b83402b860f6f06dd7c011fdfa8", size = 6063, upload-time = "2025-12-02T11:51:41.665Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -3107,6 +3133,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/99/33c7d78a3fb70d545fd5411ac67a651c81602cc09c9cf0df383733f068c5/jsonpath_ng-1.8.0-py3-none-any.whl", hash = "sha256:b8dde192f8af58d646fc031fac9c99fe4d00326afc4148f1f043c601a8cfe138", size = 67844, upload-time = "2026-02-28T00:53:19.637Z" }, ] +[[package]] +name = "jsonpath-python" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/db/2f4ecc24da35c6142b39c353d5b7c16eef955cc94b35a48d3fa47996d7c3/jsonpath_python-1.1.5.tar.gz", hash = "sha256:ceea2efd9e56add09330a2c9631ea3d55297b9619348c1055e5bfb9cb0b8c538", size = 87352, upload-time = "2026-03-17T06:16:40.597Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/50/1a313fb700526b134c71eb8a225d8b83be0385dbb0204337b4379c698cef/jsonpath_python-1.1.5-py3-none-any.whl", hash = "sha256:a60315404d70a65e76c9a782c84e50600480221d94a58af47b7b4d437351cb4b", size = 14090, upload-time = "2026-03-17T06:16:39.152Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -3769,6 +3804,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/1b/543ddaa2daf8593911a02a07a6a78366d4a6a0053ec86a557c19fa97b60e/microsoft_agents_hosting_core-0.3.1-py3-none-any.whl", hash = "sha256:a4b41556b15321b74f539c5a0a89f70955459b7ec57e9e4b24e61bba27f1cbbc", size = 94573, upload-time = "2025-09-09T23:19:53.855Z" }, ] +[[package]] +name = "mistralai" +version = "2.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eval-type-backport", marker = "sys_platform == 'darwin' or 
sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "jsonpath-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "opentelemetry-semantic-conventions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/f9/0e67c37f305b127d4df005f4a216d3570c4cf3330aa775415bde90d2f662/mistralai-2.4.2.tar.gz", hash = "sha256:7896ffa763e0be1ec05e5b436d2c21ae089e4b5438cda9033dcd1b25bc3021a2", size = 416708, upload-time = "2026-04-23T15:11:00.809Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/3e/d6df4c222e1845d16c75c5d12981b15fc4f6413a968cd7647691bf8d03d3/mistralai-2.4.2-py3-none-any.whl", hash = "sha256:caf57734078b5a6f2777157a8cd5ffe6a7d530078755f18b9884092d918299f4", size = 980037, upload-time = "2026-04-23T15:10:59.223Z" }, +] + [[package]] name = "ml-dtypes" version = "0.5.4"