diff --git a/python/packages/azure-cosmos-memory/AGENTS.md b/python/packages/azure-cosmos-memory/AGENTS.md new file mode 100644 index 0000000000..36db1f7f9a --- /dev/null +++ b/python/packages/azure-cosmos-memory/AGENTS.md @@ -0,0 +1,40 @@ +# Azure Cosmos DB Memory Package (agent-framework-azure-cosmos-memory) + +Long-term semantic memory for agents, backed by Azure Cosmos DB via the +[Azure Cosmos DB Agent Memory Toolkit](https://github.com/AzureCosmosDB/AgentMemoryToolkit). + +## Main Classes + +- **`CosmosMemoryContextProvider`** - Context provider that integrates Cosmos DB-backed + semantic memory (facts, procedural/episodic memories, and user/thread summaries) into agents. + +## Usage + +```python +from azure.identity.aio import DefaultAzureCredential +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider + +provider = CosmosMemoryContextProvider( + cosmos_endpoint="https://.documents.azure.com:443/", + cosmos_database="ai_memory", + ai_foundry_endpoint="https://.services.ai.azure.com", + credential=DefaultAzureCredential(), +) +``` + +## Import Path + +```python +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider +``` + +## Notes + +- Requires the `azure-cosmos-agent-memory` toolkit and an AI Foundry endpoint (used for both + embeddings and fact extraction). +- Set a stable `user_id` in `state["user_id"]` or `session.state["user_id"]` for long-term, + cross-session memory. Without it, memory scopes to the ephemeral session id and the provider + logs a one-time warning. +- Background fact extraction runs out-of-band after each turn. Call `provider.flush()` before + shutdown so in-flight extraction completes before the client closes. +- See `README.md` for full configuration, authentication, and processor-tuning options. diff --git a/python/packages/azure-cosmos-memory/LICENSE b/python/packages/azure-cosmos-memory/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/azure-cosmos-memory/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/azure-cosmos-memory/README.md b/python/packages/azure-cosmos-memory/README.md new file mode 100644 index 0000000000..16eb6d80c8 --- /dev/null +++ b/python/packages/azure-cosmos-memory/README.md @@ -0,0 +1,468 @@ +# Get Started with Microsoft Agent Framework Azure Cosmos DB Memory + +Please install this package via pip: + +```bash +pip install agent-framework-azure-cosmos-memory --pre +``` + +## Azure Cosmos DB Memory Context Provider + +The Azure Cosmos DB Memory integration provides `CosmosMemoryContextProvider` for long-term semantic memory storage using the [Azure Cosmos DB Agent Memory Toolkit](https://github.com/AzureCosmosDB/AgentMemoryToolkit). + +This context provider enables: +- **Semantic memory retrieval** - Facts, procedural knowledge, and episodic memories +- **Automatic memory extraction** - Conversation turns are processed to extract structured knowledge +- **User profile consolidation** - Cross-thread user profiles with preferences and facts +- **Memory reconciliation** - Deduplication and contradiction resolution + +### Basic Usage Example + +```python +from azure.identity.aio import DefaultAzureCredential +from agent_framework.foundry import FoundryChatClient +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider + +# A single AI Foundry endpoint powers both memory and the chat agent +ai_foundry_endpoint = "https://.services.ai.azure.com" + +# Create the memory provider +memory_provider = CosmosMemoryContextProvider( + cosmos_endpoint="https://.documents.azure.com:443/", + cosmos_database="ai_memory", + ai_foundry_endpoint=ai_foundry_endpoint, + credential=DefaultAzureCredential(), +) + +# Create an agent with memory - reuses the same AI Foundry endpoint +agent = FoundryChatClient( + project_endpoint=ai_foundry_endpoint, + model="gpt-4o-mini", + credential=DefaultAzureCredential(), +).as_agent( + instructions="You are a helpful assistant with long-term memory.", + context_providers=[memory_provider] +) + +# Use the agent - memories are automatically stored and retrieved +session = agent.create_session() +await agent.run("I love hiking and prefer vegetarian food.", session=session) +await agent.run("What do you know about my preferences?", session=session) +``` + +### Authentication Options + +The provider supports the same authentication modes as other Azure integrations: + +- **Managed identity / RBAC** (recommended): Pass `DefaultAzureCredential()` +- **Connection string**: Set environment variables +- **Environment variables**: `COSMOS_DB_ENDPOINT`, `COSMOS_DB_DATABASE`, `AI_FOUNDRY_ENDPOINT` + +### Development Setup + +To avoid dependency conflicts with your system Python, it's recommended to use a virtual environment: + +#### Option 1: Using venv (Built-in, Cross-Platform) + +**Bash/Linux/macOS:** +```bash +# Navigate to the package directory +cd python/packages/azure-cosmos-memory + +# Create virtual environment +python3 -m venv .venv + +# Activate virtual environment +source .venv/bin/activate + +# Install package in development mode with all dependencies +pip install -e ".[dev]" + +# OPTIONAL: Install sample dependencies (needed for interactive_chat.py) +pip install -e ".[samples]" + +# Verify installation +python -c "from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider; print('✓ Package installed')" +``` + +**PowerShell:** +```powershell +# Navigate to the package directory +cd python\packages\azure-cosmos-memory + +# Create virtual environment +python -m venv .venv + +# Activate virtual environment +.\.venv\Scripts\Activate.ps1 + +# If you get execution policy errors, run first: +# Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser + +# Install package in development mode with all dependencies +pip install -e ".[dev]" + +# OPTIONAL: Install sample dependencies (needed for interactive_chat.py) +pip install -e ".[samples]" + +# Verify installation +python -c "from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider; print('✓ Package installed')" +``` + +**To deactivate the virtual environment:** +```bash +deactivate # Works on all platforms +``` + +#### Option 2: Using uv (Fast Alternative) + +If you have [uv](https://github.com/astral-sh/uv) installed: + +```bash +# Sync all dependencies including dev dependencies +uv sync --prerelease=allow + +# Run samples with uv (it manages the environment for you) +uv run python samples/interactive_chat.py +``` + +### How to Run the Samples + +**Important:** Before running samples, complete the [Development Setup](#development-setup) above to create a virtual environment and install the package. + +This package includes two samples demonstrating different usage patterns: + +#### 1. **Basic Usage (`samples/basic_usage.py`)** - API Demonstration +This sample shows the **raw ContextProvider API** by manually calling `before_run()` and `after_run()`. It demonstrates: +- How the provider searches for memories +- How memories are injected into context +- How conversations are stored +- **Not a real agent** - just shows the API mechanics + +**Run it:** + +Ensure your virtual environment is activated, then: + +```bash +# Bash/Linux/macOS +export COSMOS_DB_ENDPOINT="https://.documents.azure.com:443/" +export AI_FOUNDRY_ENDPOINT="https://.services.ai.azure.com" +python samples/basic_usage.py +``` + +```powershell +# PowerShell +$env:COSMOS_DB_ENDPOINT="https://.documents.azure.com:443/" +$env:AI_FOUNDRY_ENDPOINT="https://.services.ai.azure.com" +python samples/basic_usage.py +``` + +#### 2. **Interactive Chat (`samples/interactive_chat.py`)** - Real Agent Integration +This sample shows **real-world usage** with Agent Framework. It demonstrates: +- ✅ **Full Agent Framework integration** - actual chatbot you can interact with +- ✅ **Custom memory extraction rubric** - inject your own extraction logic +- ✅ **Multi-turn conversations** - see memories persist across sessions +- ✅ **User/thread scoping** - test memory isolation +- ✅ **Interactive CLI** - chat with the agent, switch users, start new threads + +**Prerequisites:** + +1. **Complete [Development Setup](#development-setup)** - Create venv and install package **with sample dependencies**: + ```bash + pip install -e ".[dev,samples]" + ``` + Or install separately: + ```bash + pip install -e ".[dev]" + pip install -e ".[samples]" + ``` with sample dependencies: + ```bash + pip install -e ".[dev,samples]" + ``` + +2. **Azure Resources** - You'll need: + - An Azure Cosmos DB account with a database (e.g., `ai_memory`) + - An Azure AI Foundry project with embedding and chat deployments + - The following deployments configured in AI Foundry: + - `text-embedding-3-large` (or your preferred embedding model) + - `gpt-4o-mini` (or your preferred chat model) + +3. **Configure environment variables** - Set these in your activated virtual environment. + + > **Note:** A **single** `AI_FOUNDRY_ENDPOINT` powers everything: + > - The **memory provider** uses it internally for embeddings + memory extraction. + > - The **chat agent** you talk to uses it via `FoundryChatClient`. + > + > Authentication is via `DefaultAzureCredential` (i.e. `az login`), so **no API key is required**. + + **Bash/Linux/macOS:** + ```bash + # Cosmos DB + export COSMOS_DB_ENDPOINT="https://.documents.azure.com:443/" + export COSMOS_DB_DATABASE="ai_memory" + + # AI Foundry - used by BOTH the memory provider and the chat agent + export AI_FOUNDRY_ENDPOINT="https://.services.ai.azure.com" + export AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME="text-embedding-3-large" + export AI_FOUNDRY_CHAT_DEPLOYMENT_NAME="gpt-4o-mini" + ``` + + **PowerShell:** + ```powershell + # Cosmos DB + $env:COSMOS_DB_ENDPOINT="https://.documents.azure.com:443/" + $env:COSMOS_DB_DATABASE="ai_memory" + + # AI Foundry - used by BOTH the memory provider and the chat agent + $env:AI_FOUNDRY_ENDPOINT="https://.services.ai.azure.com" + $env:AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME="text-embedding-3-large" + $env:AI_FOUNDRY_CHAT_DEPLOYMENT_NAME="gpt-4o-mini" + ``` + +4. **Ensure Azure authentication** - The samples use `DefaultAzureCredential`, which tries: + - Environment variables (service principal) + - Managed identity (if running in Azure) + - Azure CLI (`az login`) + - Interactive browser login (fallback) + + For local development, the easiest option is: `az login` + +5. **Run the sample** (ensure your virtual environment is activated): + + **Bash/Linux/macOS:** + ```bash + # Make sure venv is activated (you should see (.venv) in your prompt) + python samples/interactive_chat.py + ``` + + **PowerShell:** + ```powershell + # Make sure venv is activated (you should see (.venv) in your prompt) + python samples/interactive_chat.py + ``` + +**Interactive sample features:** +- Chat naturally and tell the assistant your preferences +- Use `/new` to start a new thread (memories persist across threads) +- Use `/user ` to switch users (test memory isolation) +- Use `/quit` to exit + +The interactive sample demonstrates: +- **Example 1**: Real agent with memory integration +- **Example 2**: Custom memory extraction rubric injection +- **Example 3**: Multi-user and multi-thread memory scoping + +### Custom Memory Extraction Rubric + +The Agent Memory Toolkit's `AsyncCosmosMemoryClient` accepts a custom `processor` parameter to control **what** gets extracted and **how**. There are two approaches: + +#### Approach 1: Configure via Environment Variables (Simplest) +Use `processor_config` to control extraction frequency: + +```python +memory_provider = CosmosMemoryContextProvider( + cosmos_endpoint=..., + ai_foundry_endpoint=..., + processor_config={ + "FACT_EXTRACTION_EVERY_N": "1", # Extract after every turn + "DEDUP_EVERY_N": "3", # Deduplicate every 3 extractions + "USER_SUMMARY_EVERY_N": "5", # Update user profile every 5 turns + "THREAD_SUMMARY_EVERY_N": "10", # Summarize thread every 10 turns + } +) +``` + +#### Approach 2: Custom Processor (Advanced) +Inject your own extraction logic with a custom rubric: + +```python +class CustomMemoryProcessor: + def __init__(self, extraction_rubric: str): + self.extraction_rubric = extraction_rubric # Your custom prompt + + async def extract_memories(self, user_id, thread_id, messages): + # Your extraction logic here using self.extraction_rubric + # Return list of memory records + pass + +# Create client with custom processor +memory_client = AsyncCosmosMemoryClient( + cosmos_endpoint=cosmos_endpoint, + ai_foundry_endpoint=ai_foundry_endpoint, + use_default_credential=True, + processor=CustomMemoryProcessor(YOUR_RUBRIC), # <-- Inject here +) + +# Pass to provider +memory_provider = CosmosMemoryContextProvider(memory_client=memory_client) +``` + +See [`samples/interactive_chat.py`](samples/interactive_chat.py) for a complete example with a custom extraction rubric that defines: +- What to extract (preferences, facts, decisions, patterns) +- What to ignore (transient requests, small talk, tool chatter) +- How to classify memories (fact, procedural, episodic) +- Confidence scoring rules + +### Configuration + +```python +memory_provider = CosmosMemoryContextProvider( + source_id="cosmos_memory", # Provider identifier + cosmos_endpoint="https://...", # Cosmos DB endpoint + cosmos_database="ai_memory", # Database name + ai_foundry_endpoint="https://...", # AI Foundry endpoint + credential=DefaultAzureCredential(), # Azure credential + + # Memory retrieval options + top_k=5, # Number of memories to retrieve + min_confidence=0.7, # Minimum confidence score (0.0-1.0) + memory_types=["fact", "procedural"], # Types to retrieve + + # Processing options + auto_extract=True, # Auto-extract memories after runs + processor_config={ # Optional processor settings + "FACT_EXTRACTION_EVERY_N": 1, # Extract facts every N turns + "DEDUP_EVERY_N": 5, # Deduplicate every N extractions + } +) +``` + +### Memory Types + +The provider retrieves four types of memories: + +| Type | Description | Default TTL | +|------|-------------|-------------| +| **fact** | Declarative knowledge ("user prefers dark mode") | None | +| **procedural** | Behavioral rules ("always confirm before deleting") | None | +| **episodic** | Past experiences with context and outcomes | 90 days | +| **unclassified** | Memories that couldn't be confidently classified | None | + +Each memory has a confidence score (0.0-1.0). Use `min_confidence` to filter low-quality extractions. + +### Processing Pipeline + +The memory toolkit automatically: + +1. **Stores conversation turns** - Raw messages saved to Cosmos DB +2. **Extracts memories** - LLM extracts facts, rules, and experiences +3. **Generates summaries** - Thread and user-level summaries +4. **Reconciles duplicates** - Merges similar memories and resolves contradictions + +Processing can run: +- **In-process** (default) - Zero infrastructure, suitable for prototypes and low TPS +- **Azure Functions** - Scalable processing via Cosmos DB change feed + +### Working with Multiple Providers + +Combine with other context providers for comprehensive memory: + +```python +from agent_framework import InMemoryHistoryProvider +from agent_framework_azure_cosmos import CosmosHistoryProvider +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider + +agent = client.as_agent( + context_providers=[ + # Short-term: recent conversation + InMemoryHistoryProvider("recent"), + + # Mid-term: persistent conversation history + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + credential=credential, + database_name="agent-framework", + container_name="chat-history", + ), + + # Long-term: semantic memory with facts and profiles + CosmosMemoryContextProvider( + cosmos_endpoint=cosmos_endpoint, + ai_foundry_endpoint=ai_foundry_endpoint, + credential=credential, + ), + ] +) +``` + +### User and Thread Scoping + +Memories are scoped by `user_id` and `thread_id`: + +```python +session = agent.create_session() + +# Set user_id and thread_id in session state +session.state["user_id"] = "user-123" +session.state["thread_id"] = "thread-456" + +await agent.run("Remember that I'm allergic to peanuts.", session=session) +``` + +If not provided, the provider uses `session.session_id` as both user and thread identifiers. + +### Advanced: Custom Processing + +For fine-grained control over memory processing: + +```python +from azure.cosmos.agent_memory.aio import AsyncCosmosMemoryClient + +# Create a custom memory client +memory_client = AsyncCosmosMemoryClient( + cosmos_endpoint=cosmos_endpoint, + cosmos_database="ai_memory", + ai_foundry_endpoint=ai_foundry_endpoint, + use_default_credential=True, +) + +# Pass to the provider +memory_provider = CosmosMemoryContextProvider( + memory_client=memory_client, + auto_extract=False, # Disable automatic extraction +) + +# Manually trigger processing when needed +await memory_client.process_now(user_id="user-123", thread_id="thread-456") +``` + +### Environment Variables + +All configuration can be provided via environment variables: + +**Using a `.env` file** (cross-platform, recommended): +```bash +COSMOS_DB_ENDPOINT=https://.documents.azure.com:443/ +COSMOS_DB_DATABASE=ai_memory +AI_FOUNDRY_ENDPOINT=https://.services.ai.azure.com +AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME=text-embedding-3-large +AI_FOUNDRY_CHAT_DEPLOYMENT_NAME=gpt-4o-mini + +# Optional: Processing configuration +FACT_EXTRACTION_EVERY_N=1 +DEDUP_EVERY_N=5 +THREAD_SUMMARY_EVERY_N=10 +USER_SUMMARY_EVERY_N=20 +``` + +**Or set in your shell session:** + +Bash/Linux/macOS: +```bash +export COSMOS_DB_ENDPOINT=https://.documents.azure.com:443/ +export COSMOS_DB_DATABASE=ai_memory +export AI_FOUNDRY_ENDPOINT=https://.services.ai.azure.com +``` + +PowerShell: +```powershell +$env:COSMOS_DB_ENDPOINT="https://.documents.azure.com:443/" +$env:COSMOS_DB_DATABASE="ai_memory" +$env:AI_FOUNDRY_ENDPOINT="https://.services.ai.azure.com" +``` + +## See Also + +- [Azure Cosmos DB Agent Memory Toolkit](https://github.com/AzureCosmosDB/AgentMemoryToolkit) +- [Agent Framework Context Providers](https://learn.microsoft.com/en-us/agent-framework/agents/conversations/context-providers?pivots=programming-language-python) +- [agent-framework-azure-cosmos](https://pypi.org/project/agent-framework-azure-cosmos/) - For basic history and checkpoint storage diff --git a/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/__init__.py b/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/__init__.py new file mode 100644 index 0000000000..89cd2cc740 --- /dev/null +++ b/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft. All rights reserved. + +import importlib.metadata + +from ._context_provider import CosmosMemoryContextProvider + +try: + __version__ = importlib.metadata.version(__name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" # Fallback for development mode + +__all__ = [ + "CosmosMemoryContextProvider", + "__version__", +] diff --git a/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/_context_provider.py b/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/_context_provider.py new file mode 100644 index 0000000000..d3a31e4249 --- /dev/null +++ b/python/packages/azure-cosmos-memory/agent_framework_azure_cosmos_memory/_context_provider.py @@ -0,0 +1,394 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Cosmos DB Memory Context Provider using Agent Memory Toolkit. + +This module provides ``CosmosMemoryContextProvider``, built on the +:class:`ContextProvider` pattern for long-term semantic memory. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from collections.abc import Sequence +from contextlib import AbstractAsyncContextManager +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict + +from agent_framework import AgentSession, ContextProvider, Message, SessionContext + +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + +if TYPE_CHECKING: + from agent_framework._agents import SupportsAgentRun + from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential + from azure.cosmos.agent_memory.aio import AsyncCosmosMemoryClient + +try: + from azure.cosmos.agent_memory.aio import AsyncCosmosMemoryClient + from azure.identity.aio import DefaultAzureCredential + + _memory_toolkit_available = True +except ImportError: + _memory_toolkit_available = False + AsyncCosmosMemoryClient = None # type: ignore + DefaultAzureCredential = None # type: ignore + +logger = logging.getLogger(__name__) + +AzureCredentialTypes = "TokenCredential | AsyncTokenCredential" + + +class CosmosMemorySettings(TypedDict, total=False): + """Settings for Cosmos Memory Context Provider with auto-loading from environment.""" + + cosmos_endpoint: str | None + cosmos_database: str | None + ai_foundry_endpoint: str | None + embedding_deployment_name: str | None + chat_deployment_name: str | None + + +class CosmosMemoryContextProvider(ContextProvider): + """Azure Cosmos DB Memory context provider using Agent Memory Toolkit. + + Provides long-term semantic memory with fact extraction, user profiles, + and cross-thread memory consolidation. + """ + + DEFAULT_SOURCE_ID: ClassVar[str] = "cosmos_memory" + DEFAULT_CONTEXT_PROMPT: ClassVar[str] = "## Relevant Memories\nConsider these memories when responding:" + DEFAULT_DATABASE: ClassVar[str] = "ai_memory" + + # Agent Framework uses the "assistant" role, but the Agent Memory Toolkit's TurnRecord + # only accepts {user, agent, tool, system}. Map AF roles to toolkit roles when storing. + _ROLE_MAP: ClassVar[dict[str, str]] = {"assistant": "agent"} + + def __init__( + self, + source_id: str = DEFAULT_SOURCE_ID, + *, + cosmos_endpoint: str | None = None, + cosmos_database: str | None = None, + ai_foundry_endpoint: str | None = None, + embedding_deployment_name: str | None = None, + chat_deployment_name: str | None = None, + credential: Any = None, + memory_client: AsyncCosmosMemoryClient | None = None, + top_k: int = 5, + min_confidence: float = 0.7, + memory_types: Sequence[str] | None = None, + context_prompt: str | None = None, + auto_extract: bool = True, + processor_config: dict[str, Any] | None = None, + ) -> None: + """Initialize the Cosmos Memory context provider. + + Args: + source_id: Unique identifier for this provider instance. + cosmos_endpoint: Cosmos DB account endpoint. + Can be set via ``COSMOS_DB_ENDPOINT``. + cosmos_database: Cosmos DB database name. + Can be set via ``COSMOS_DB_DATABASE``. + ai_foundry_endpoint: AI Foundry project endpoint for LLM and embeddings. + Can be set via ``AI_FOUNDRY_ENDPOINT``. + embedding_deployment_name: Embedding model deployment name. + Can be set via ``AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME``. + chat_deployment_name: Chat model deployment name. + Can be set via ``AI_FOUNDRY_CHAT_DEPLOYMENT_NAME``. + credential: Azure credential for authentication. If None, uses DefaultAzureCredential. + memory_client: Pre-created AsyncCosmosMemoryClient. + top_k: Number of memories to retrieve in search. + min_confidence: Minimum confidence score (0.0-1.0) for retrieved memories. + memory_types: Types of memories to retrieve. Default: ["fact", "procedural"]. + context_prompt: Prompt to prepend to retrieved memories. + auto_extract: Enable automatic memory extraction after runs. + processor_config: Optional processor configuration dict (e.g., extraction frequency). + + Raises: + ImportError: If azure-cosmos-agent-memory is not installed. + """ + if not _memory_toolkit_available: + raise ImportError( + "azure-cosmos-agent-memory is required. " + "Install with: pip install agent-framework-azure-cosmos-memory" + ) + + super().__init__(source_id) + + # Track whether we created the client (and thus should close it in __aexit__) + # vs. received a pre-created client (which the caller owns and should close) + self._should_close_client = False + self.top_k = top_k + self.min_confidence = min_confidence + self.memory_types = list(memory_types) if memory_types else ["fact", "procedural"] + self.context_prompt = context_prompt or self.DEFAULT_CONTEXT_PROMPT + self.auto_extract = auto_extract + + # Apply processor config to environment BEFORE creating the memory client. + # The AsyncCosmosMemoryClient reads these environment variables during initialization + # to configure the InProcessProcessor (extraction frequency, deduplication, etc.) + if processor_config: + for key, value in processor_config.items(): + os.environ[key] = str(value) + + # Initialize memory client if not provided + if memory_client is None: + # Load settings from environment if not provided + cosmos_endpoint = cosmos_endpoint or os.getenv("COSMOS_DB_ENDPOINT") + cosmos_database = cosmos_database or os.getenv("COSMOS_DB_DATABASE", self.DEFAULT_DATABASE) + ai_foundry_endpoint = ai_foundry_endpoint or os.getenv("AI_FOUNDRY_ENDPOINT") + embedding_deployment_name = embedding_deployment_name or os.getenv( + "AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME", "text-embedding-3-large" + ) + chat_deployment_name = chat_deployment_name or os.getenv("AI_FOUNDRY_CHAT_DEPLOYMENT_NAME", "gpt-4o-mini") + + if not cosmos_endpoint: + raise ValueError("cosmos_endpoint must be provided or set via COSMOS_DB_ENDPOINT") + if not ai_foundry_endpoint: + raise ValueError("ai_foundry_endpoint must be provided or set via AI_FOUNDRY_ENDPOINT") + + # Create Azure credential using the standard chain: EnvironmentCredential → + # ManagedIdentityCredential → AzureCliCredential → InteractiveBrowserCredential. + # This works seamlessly in production (via ManagedIdentity) and local dev (via az login). + if credential is None: + credential = DefaultAzureCredential() # type: ignore + + memory_client = AsyncCosmosMemoryClient( + cosmos_endpoint=cosmos_endpoint, + cosmos_database=cosmos_database, + ai_foundry_endpoint=ai_foundry_endpoint, + embedding_deployment_name=embedding_deployment_name, + chat_deployment_name=chat_deployment_name, + use_default_credential=True, + ) + self._should_close_client = True + + self.memory_client = memory_client + self._cosmos_endpoint = cosmos_endpoint + self._ai_foundry_endpoint = ai_foundry_endpoint + # Emit the "no stable user_id" warning at most once per provider instance to avoid + # log spam on every run when a caller forgets to set user_id. + self._warned_user_fallback = False + + def _resolve_user_id(self, state: dict[str, Any], session: AgentSession) -> str: + """Resolve the user id for memory scoping, warning once if none was provided. + + Long-term, cross-session memory requires a *stable* user id. If the caller does + not set ``state["user_id"]`` or ``session.state["user_id"]``, memory silently + scopes to the ephemeral ``session_id`` (or ``"default"``), so cross-session recall + will not work as intended. Log a one-time warning so this misconfiguration is + visible instead of failing silently. + + Args: + state: Provider-scoped mutable state. + session: The current session. + + Returns: + The resolved user id. + """ + explicit = state.get("user_id") or session.state.get("user_id") + if explicit: + return explicit + if not self._warned_user_fallback: + self._warned_user_fallback = True + logger.warning( + "No 'user_id' found in state or session; falling back to session id '%s'. " + "Long-term cross-session memory requires a stable user_id set via " + "state['user_id'] or session.state['user_id'].", + session.session_id, + ) + return session.session_id or "default" + + async def flush(self, timeout: float = 30.0) -> None: + """Wait for any pending background memory-extraction tasks to complete. + + After each stored turn, the Agent Memory Toolkit schedules fact/summary + extraction as background ``asyncio`` tasks that run out-of-band. The client's + ``close()`` cancels any still-pending tasks, so call ``flush()`` before shutdown + to let in-flight extraction finish and persist instead of being discarded. + + Args: + timeout: Maximum seconds to wait for pending tasks to complete. + """ + tasks = getattr(self.memory_client, "_background_tasks", None) + if not tasks: + return + pending = [task for task in list(tasks) if not task.done()] + if pending: + await asyncio.wait(pending, timeout=timeout) + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + if self.memory_client and isinstance(self.memory_client, AbstractAsyncContextManager): + await self.memory_client.__aenter__() # type: ignore + # The async client cannot create or connect Cosmos containers in __init__ (no running + # event loop), so ensure the database and memory containers exist and the client is + # connected here. create_memory_store() is idempotent (create-if-not-exists), so it is + # safe to call for both provider-created and caller-provided clients. + if self.memory_client is not None: + await self.memory_client.create_memory_store() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> None: + """Async context manager exit. + + Only close the memory client if this provider created it (_should_close_client=True). + If a pre-created client was provided, the caller is responsible for closing it. + """ + if self.memory_client and isinstance(self.memory_client, AbstractAsyncContextManager): + if self._should_close_client: + await self.memory_client.__aexit__(exc_type, exc_val, exc_tb) # type: ignore + + async def before_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Search for relevant memories and inject into context. + + Args: + agent: The agent running this invocation. + session: The current session. + context: The invocation context to add memories to. + state: Provider-scoped mutable state. + """ + # Extract query from input messages + query_text = "\n".join(msg.text for msg in context.input_messages if msg.text and msg.text.strip()) + + if not query_text: + return + + # Get user_id from state or session (warns once if no stable user_id was provided) + user_id = self._resolve_user_id(state, session) + + # Memory search and user-summary retrieval are independent: the user summary + # provides baseline context even when no memories match the query, so a failure + # in one must not suppress the other. They get separate error handling. + try: + results = await self.memory_client.search_cosmos( + search_terms=query_text, + user_id=user_id, + top_k=self.top_k, + memory_types=self.memory_types, + min_confidence=self.min_confidence, + ) + + if results: + # Format and inject memories + memory_content = self._format_memories(results) + context.extend_messages( + self.source_id, [Message(role="user", contents=[f"{self.context_prompt}\n{memory_content}"])] + ) + except Exception as e: + logger.warning("Failed to retrieve memories: %s", e, exc_info=True) + + # Retrieve and inject user summary as agent instructions. + # This is INDEPENDENT of search results - even if no memories match the query, + # the user summary provides baseline context about the user's preferences and traits. + try: + user_summary = await self.memory_client.get_user_summary(user_id=user_id) + if user_summary: + # get_user_summary returns the Cosmos summary document (a dict) whose + # roll-up text lives in the "content" field; fall back to str() defensively. + summary_text = user_summary.get("content") if isinstance(user_summary, dict) else str(user_summary) + if summary_text and summary_text.strip(): + context.extend_instructions(self.source_id, [f"User Profile: {summary_text}"]) + except Exception as e: + logger.warning("Failed to retrieve user summary: %s", e, exc_info=True) + + async def after_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Store conversation turns and optionally trigger memory extraction. + + Args: + agent: The agent that ran this invocation. + session: The current session. + context: The invocation context with response populated. + state: Provider-scoped mutable state. + """ + # Get user_id and thread_id from state or session (warns once if no stable user_id) + user_id = self._resolve_user_id(state, session) + thread_id = state.get("thread_id") or session.state.get("thread_id") or session.session_id or "default" + + try: + # Store input messages + for msg in context.input_messages: + if hasattr(msg, "role") and hasattr(msg, "text") and msg.text: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value in {"user", "assistant", "system"}: + await self.memory_client.add_cosmos( + user_id=user_id, + thread_id=thread_id, + role=self._ROLE_MAP.get(role_value, role_value), + content=msg.text, + ) + + # Store response messages + if context.response and context.response.messages: + for msg in context.response.messages: + if hasattr(msg, "role") and hasattr(msg, "text") and msg.text: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value in {"user", "assistant", "system"}: + await self.memory_client.add_cosmos( + user_id=user_id, + thread_id=thread_id, + role=self._ROLE_MAP.get(role_value, role_value), + content=msg.text, + ) + + # Auto-extraction and processing: + # The AsyncCosmosMemoryClient uses an InProcessProcessor that runs in the background + # and automatically extracts facts, generates summaries, and reconciles memories based on + # configured thresholds (FACT_EXTRACTION_EVERY_N, DEDUP_EVERY_N, etc.). + # This happens asynchronously after add_cosmos() completes, so no explicit process_now() call is needed. + # To disable auto-extraction, set auto_extract=False and call memory_client.process_now() manually. + + except Exception as e: + logger.warning("Failed to store conversation turns: %s", e, exc_info=True) + + def _format_memories(self, memories: Sequence[dict[str, Any]]) -> str: + """Format memories for context injection. + + Each memory is formatted as: "[type] content (confidence: X.XX)" + This provides the agent with both the memory content and metadata about + its type (fact, procedural, episodic) and confidence score for better reasoning. + + Args: + memories: List of memory records from search. + + Returns: + Formatted string of memories. + """ + formatted = [] + for memory in memories: + content = memory.get("content", "") + memory_type = memory.get("memory_type", "") + confidence = memory.get("confidence", 0.0) + + # Format: [Type] Content (confidence: X.XX) + if memory_type and confidence: + formatted.append(f"[{memory_type}] {content} (confidence: {confidence:.2f})") + else: + formatted.append(content) + + return "\n".join(formatted) + + +__all__ = ["CosmosMemoryContextProvider"] diff --git a/python/packages/azure-cosmos-memory/pyproject.toml b/python/packages/azure-cosmos-memory/pyproject.toml new file mode 100644 index 0000000000..dcb2423a64 --- /dev/null +++ b/python/packages/azure-cosmos-memory/pyproject.toml @@ -0,0 +1,118 @@ +[project] +name = "agent-framework-azure-cosmos-memory" +description = "Azure Cosmos DB Agent Memory Toolkit integration for Microsoft Agent Framework - semantic memory with fact extraction and user profiles." +authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.11" +version = "1.0.0b260618" +license-files = ["LICENSE"] +urls.homepage = "https://aka.ms/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.6.0,<2", + "azure-cosmos-agent-memory>=0.1.0b2", +] + +[dependency-groups] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0.0", +] +samples = [ + "agent-framework-foundry>=1.6.0,<2", + "python-dotenv>=1.0.0", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [ + "ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*", + "ignore:.*telemetry.*:UserWarning", +] +timeout = 120 +markers = [ + "integration: marks tests as integration tests that require external services", + "azure: marks integration tests that require a live Azure account (Cosmos DB + AI Foundry)", +] + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" +include = ["agent_framework_azure_cosmos_memory"] + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.11" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_azure_cosmos_memory"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" + +[tool.poe.tasks.mypy] +help = "Run MyPy for this package." +cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos_memory" + +[tool.poe.tasks.test] +help = "Run the default unit test suite for this package." +cmd = 'pytest -m "not integration" --cov=agent_framework_azure_cosmos_memory --cov-report=term-missing:skip-covered tests' + +[tool.poe.tasks.integration-tests] +help = "Run the package integration test suite (emulator-backed, no live Azure)." +cmd = 'pytest -m "integration and not azure" tests' + +[tool.poe.tasks.integration-tests-azure] +help = "Run the live-Azure integration test suite (requires Cosmos DB + AI Foundry)." +cmd = 'pytest -m "integration and azure" tests' + +[build-system] +requires = ["flit-core >= 3.11,<4.0"] +build-backend = "flit_core.buildapi" diff --git a/python/packages/azure-cosmos-memory/samples/basic_usage.py b/python/packages/azure-cosmos-memory/samples/basic_usage.py new file mode 100644 index 0000000000..aafa804100 --- /dev/null +++ b/python/packages/azure-cosmos-memory/samples/basic_usage.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Sample usage of CosmosMemoryContextProvider. + +This example demonstrates: +1. Creating a provider with Azure credentials +2. Using it with an OpenAI agent +3. Multi-turn conversation with memory +4. Combining with history providers + +Prerequisites: + Install the package in development mode first: + pip install -e . + + Then run this sample: + python samples/basic_usage.py +""" + +import asyncio +import os + +from agent_framework import Message +from agent_framework._sessions import AgentSession, SessionContext +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider +from azure.identity.aio import DefaultAzureCredential + + +async def basic_example() -> None: + """Basic example with environment variables.""" + # Create provider - reads from environment + async with CosmosMemoryContextProvider( + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + credential=DefaultAzureCredential(), + ) as provider: + # Use with agent session + session = AgentSession(session_id="user-session-123") + session.state["user_id"] = "alice" + session.state["thread_id"] = "conversation-1" + + # Simulate agent run - before_run searches memories + ctx = SessionContext( + input_messages=[Message(role="user", contents=["What do you know about my preferences?"])], + session_id=session.session_id, + ) + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + print(f"Retrieved {len(ctx.context_messages.get(provider.source_id, []))} memory messages") + + # After agent responds, store the conversation + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + print("Conversation stored for future memory extraction") + + +async def custom_config_example() -> None: + """Example with custom configuration.""" + provider = CosmosMemoryContextProvider( + source_id="custom_memory", + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + cosmos_database="my_agent_memory", + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + embedding_deployment_name="text-embedding-3-large", + chat_deployment_name="gpt-4o-mini", + credential=DefaultAzureCredential(), + top_k=10, # Retrieve more memories + min_confidence=0.8, # Higher confidence threshold + memory_types=["fact", "procedural", "episodic"], # Include episodic memories + context_prompt="## What I Remember About You", + processor_config={ + "FACT_EXTRACTION_EVERY_N": "1", # Extract facts every message + "USER_SUMMARY_EVERY_N": "5", # Update user profile every 5 messages + }, + ) + + async with provider: + session = AgentSession(session_id="demo-session") + session.state["user_id"] = "bob" + + ctx = SessionContext( + input_messages=[Message(role="user", contents=["I'm learning Rust programming"])], + session_id=session.session_id, + ) + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + print("Custom configured provider executed successfully") + + +async def multi_provider_example() -> None: + """Example combining memory with other providers.""" + from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider + + # Combine semantic memory with conversation history + memory_provider = CosmosMemoryContextProvider( + source_id="semantic_memory", + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + credential=DefaultAzureCredential(), + memory_types=["fact", "procedural"], # Long-term facts + ) + + # Note: In real usage, you'd also add a history provider like: + # from agent_framework_azure_cosmos import CosmosHistoryProvider + # history_provider = CosmosHistoryProvider(...) + + async with memory_provider: + session = AgentSession(session_id="multi-provider-session") + session.state["user_id"] = "charlie" + session.state["thread_id"] = "support-thread-456" + + ctx = SessionContext( + input_messages=[Message(role="user", contents=["How do I configure authentication?"])], + session_id=session.session_id, + ) + + # Both providers would be called in agent run + await memory_provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(memory_provider.source_id, {}) + ) # type: ignore + + print("Multi-provider setup ready") + + +if __name__ == "__main__": + print("=== Basic Example ===") + asyncio.run(basic_example()) + + print("\n=== Custom Config Example ===") + asyncio.run(custom_config_example()) + + print("\n=== Multi-Provider Example ===") + asyncio.run(multi_provider_example()) diff --git a/python/packages/azure-cosmos-memory/samples/interactive_chat.py b/python/packages/azure-cosmos-memory/samples/interactive_chat.py new file mode 100644 index 0000000000..ed2964ebc0 --- /dev/null +++ b/python/packages/azure-cosmos-memory/samples/interactive_chat.py @@ -0,0 +1,328 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Interactive chat demonstrating CosmosMemoryContextProvider with Agent Framework. + +This sample shows: +- Real agent integration with memory persistence +- Custom memory extraction rubric/prompt injection +- Multi-turn conversations with semantic memory +- Memory retrieval across different sessions + +Prerequisites: + Install the package in development mode first: + pip install -e . + + Then run this sample: + python samples/interactive_chat.py +""" + +import asyncio +import os +import sys +from typing import Any + +from agent_framework import Agent +from agent_framework.foundry import FoundryChatClient +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider +from azure.cosmos.agent_memory.aio import AsyncCosmosMemoryClient +from azure.identity.aio import DefaultAzureCredential +from dotenv import load_dotenv + + +# Custom memory extraction rubric - defines WHAT gets remembered and HOW +CUSTOM_EXTRACTION_RUBRIC = """You are a memory extraction specialist analyzing conversation transcripts. + +Your task is to identify and extract important information worth remembering long-term. + +WHAT TO EXTRACT: +- User preferences and dislikes (food, hobbies, work style, communication preferences) +- Personal facts (job title, location, family, allergies, accessibility needs) +- Decisions made during conversations (chosen solutions, rejected alternatives, rationale) +- Behavioral patterns (how user likes to approach problems, learning style) +- Project context (current projects, goals, deadlines, stakeholders) +- Technical environment (tools used, tech stack, common issues) + +WHAT TO IGNORE: +- Transient requests ("book a meeting for tomorrow") +- Small talk and greetings +- Tool output and system messages +- Temporary context that won't be useful later + +OUTPUT FORMAT: +Return ONLY valid JSON with this exact structure: +{ + "memories": [ + { + "type": "fact|procedural|episodic", + "content": "A single, clear sentence capturing the memory", + "confidence": 0.0-1.0 + } + ] +} + +MEMORY TYPES: +- fact: Declarative knowledge ("User prefers dark mode", "User is allergic to peanuts") +- procedural: Behavioral rules ("User wants confirmation before deletions", "User prefers concise answers") +- episodic: Past experiences with context ("User struggled with OAuth setup on 2024-03-15") + +CONFIDENCE SCORING: +- 0.9-1.0: Explicit statements ("I prefer...", "I always...") +- 0.7-0.9: Strong implications from behavior +- 0.5-0.7: Weak signals, might need confirmation +- Below 0.5: Don't extract + +EXAMPLES: + +Conversation: "I really dislike verbose explanations. Just give me the code." +Output: {"memories": [{"type": "procedural", "content": "User prefers concise, code-first responses without lengthy explanations", "confidence": 0.95}]} + +Conversation: "I'm working on a Python project using FastAPI and PostgreSQL." +Output: {"memories": [{"type": "fact", "content": "User is working on a Python project with FastAPI and PostgreSQL stack", "confidence": 0.9}]} + +Conversation: "What's the weather today?" +Output: {"memories": []} + +Return {"memories": []} if nothing worth remembering long-term. +""" + + +class CustomMemoryProcessor: + """Custom processor that injects our extraction rubric into the memory pipeline. + + The Azure Cosmos DB Agent Memory Toolkit accepts a custom processor that can + override the default extraction logic. This shows how to inject domain-specific + extraction rules. + """ + + def __init__(self, extraction_rubric: str): + """Initialize with custom extraction rubric. + + Args: + extraction_rubric: System prompt for memory extraction LLM calls + """ + self.extraction_rubric = extraction_rubric + + async def extract_memories( + self, user_id: str, thread_id: str, messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Extract memories from conversation using custom rubric. + + This is called by the AsyncCosmosMemoryClient after conversation turns. + + Args: + user_id: User identifier + thread_id: Conversation thread identifier + messages: Recent conversation messages + + Returns: + List of extracted memory records + """ + # In a real implementation, this would: + # 1. Format messages into transcript + # 2. Call LLM with self.extraction_rubric as system prompt + # 3. Parse and validate the JSON response + # 4. Return structured memory records + # + # For this sample, we rely on the toolkit's default processor + # but configure it via environment variables. See processor_config below. + return [] + + +async def create_agent_with_memory() -> tuple[Agent, CosmosMemoryContextProvider]: + """Create an agent with Cosmos DB memory integration. + + Returns: + Tuple of (agent, memory_provider) + """ + # Load environment variables + load_dotenv() + + cosmos_endpoint = os.environ.get("COSMOS_DB_ENDPOINT") + ai_foundry_endpoint = os.environ.get("AI_FOUNDRY_ENDPOINT") + cosmos_database = os.environ.get("COSMOS_DB_DATABASE", "ai_memory") + + # The SAME AI Foundry endpoint is used for both: + # 1. The memory provider (embeddings + memory extraction), and + # 2. The chat agent you talk to (via FoundryChatClient below). + # The only extra setting is which chat deployment the agent should use. + chat_deployment = os.environ.get("AI_FOUNDRY_CHAT_DEPLOYMENT_NAME", "gpt-4o-mini") + + if not cosmos_endpoint or not ai_foundry_endpoint: + print("ERROR: Missing required environment variables:") + print(" COSMOS_DB_ENDPOINT - Azure Cosmos DB account endpoint") + print(" AI_FOUNDRY_ENDPOINT - Azure AI Foundry project endpoint (used by BOTH memory + chat)") + print("\nOptional:") + print(" COSMOS_DB_DATABASE - Database name (default: ai_memory)") + print(" AI_FOUNDRY_CHAT_DEPLOYMENT_NAME - Chat model deployment (default: gpt-4o-mini)") + print(" AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME - Embedding model (default: text-embedding-3-large)") + sys.exit(1) + + # Create credential (works with az login or managed identity) + credential = DefaultAzureCredential() + + # Option 1: Use the toolkit's default processor with custom configuration + # This is the simplest approach - configure extraction via environment variables + memory_provider = CosmosMemoryContextProvider( + cosmos_endpoint=cosmos_endpoint, + cosmos_database=cosmos_database, + ai_foundry_endpoint=ai_foundry_endpoint, + credential=credential, + top_k=5, # Retrieve top 5 relevant memories + min_confidence=0.7, # Only show high-confidence memories + memory_types=["fact", "procedural", "episodic"], + context_prompt="## What I Remember About You\nI'll use these memories to personalize my responses:", + # Configure the extraction processor behavior + processor_config={ + "FACT_EXTRACTION_EVERY_N": "1", # Extract after every conversation turn + "DEDUP_EVERY_N": "3", # Deduplicate every 3 extractions + "USER_SUMMARY_EVERY_N": "5", # Update user profile every 5 turns + "THREAD_SUMMARY_EVERY_N": "10", # Summarize thread every 10 turns + }, + ) + + # Option 2: Create a custom memory client with your own processor + # Uncomment this to use a fully custom extraction rubric: + # + # custom_processor = CustomMemoryProcessor(CUSTOM_EXTRACTION_RUBRIC) + # memory_client = AsyncCosmosMemoryClient( + # cosmos_endpoint=cosmos_endpoint, + # cosmos_database=cosmos_database, + # ai_foundry_endpoint=ai_foundry_endpoint, + # use_default_credential=True, + # processor=custom_processor, # Inject custom extraction logic + # ) + # memory_provider = CosmosMemoryContextProvider( + # memory_client=memory_client, + # top_k=5, + # min_confidence=0.7, + # ) + + # Create the agent with memory. + # + # FoundryChatClient talks to your Azure AI Foundry project using the SAME + # endpoint the memory provider uses (ai_foundry_endpoint). This gives a + # single-endpoint experience: one AI_FOUNDRY_ENDPOINT powers both the chat + # agent and the memory pipeline. Auth is via DefaultAzureCredential + # (az login / managed identity) - no API key required. + agent = Agent( + client=FoundryChatClient( + project_endpoint=ai_foundry_endpoint, + model=chat_deployment, + credential=DefaultAzureCredential(), + ), + name="Memory Assistant", + instructions=( + "You are a helpful assistant with long-term memory. " + "When you remember facts about the user, mention them naturally in conversation. " + "If you don't remember something, just say so - don't make up information." + ), + context_providers=[memory_provider], + ) + + return agent, memory_provider + + +async def chat_loop(agent: Agent, user_id: str) -> None: + """Run interactive chat loop. + + Args: + agent: Agent to chat with + user_id: User identifier for memory scoping + """ + print("\n" + "=" * 70) + print(" Interactive Chat with Cosmos DB Memory") + print("=" * 70) + print(f"\nUser ID: {user_id}") + print("\nCommands:") + print(" /new - Start a new conversation thread") + print(" /user - Change user ID (to test cross-user isolation)") + print(" /quit - Exit") + print("\nTips:") + print(" - Tell the assistant your preferences (food, work style, etc.)") + print(" - Start a new thread and see if it remembers you") + print(" - Change user ID to see memory isolation") + print("\n" + "=" * 70 + "\n") + + session = agent.create_session() + session.state["user_id"] = user_id + session.state["thread_id"] = f"thread-{session.session_id}" + + print(f"Started conversation thread: {session.state['thread_id']}\n") + + while True: + try: + # Read input in a worker thread so the asyncio event loop keeps running while we + # wait. This lets the toolkit's background memory-extraction tasks (scheduled after + # each stored turn) make progress between messages instead of being starved by a + # blocking input() call. + user_input = (await asyncio.to_thread(input, "You: ")).strip() + + if not user_input: + continue + + if user_input == "/quit": + print("\nGoodbye! 👋") + break + + if user_input == "/new": + # Start new thread but keep same user (memories carry over) + session = agent.create_session() + session.state["user_id"] = user_id + session.state["thread_id"] = f"thread-{session.session_id}" + print(f"\n[New conversation thread: {session.state['thread_id']}]") + print("[Memories from previous conversations will still be available]\n") + continue + + if user_input == "/user": + new_user_id = (await asyncio.to_thread(input, "Enter new user ID: ")).strip() + if new_user_id: + user_id = new_user_id + session = agent.create_session() + session.state["user_id"] = user_id + session.state["thread_id"] = f"thread-{session.session_id}" + print(f"\n[Switched to user: {user_id}]") + print(f"[New conversation thread: {session.state['thread_id']}]\n") + continue + + # Send message to agent + response = await agent.run(user_input, session=session) + + print(f"\nAssistant: {response.text}\n") + + except KeyboardInterrupt: + print("\n\nGoodbye! 👋") + break + except Exception as e: + print(f"\n❌ Error: {e}\n") + import traceback + + traceback.print_exc() + + +async def main() -> None: + """Main entry point.""" + try: + agent, memory_provider = await create_agent_with_memory() + + # Use the async context manager to ensure proper cleanup + async with memory_provider: + # Default user ID (can be changed with /user command) + default_user_id = "demo-user-123" + + try: + await chat_loop(agent, default_user_id) + finally: + # Let any in-flight background memory extraction finish and persist before the + # client closes (close() cancels still-pending background tasks). + print("Finalizing memory extraction...") + await memory_provider.flush() + + except Exception as e: + print(f"❌ Failed to initialize: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos-memory/tests/conftest.py b/python/packages/azure-cosmos-memory/tests/conftest.py new file mode 100644 index 0000000000..105f55d22f --- /dev/null +++ b/python/packages/azure-cosmos-memory/tests/conftest.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Pytest configuration for azure-cosmos-memory tests.""" + +import pytest + + +def pytest_configure(config: pytest.Config) -> None: + """Register custom markers.""" + config.addinivalue_line("markers", "integration: mark test as integration test requiring live Azure accounts") diff --git a/python/packages/azure-cosmos-memory/tests/test_context_provider.py b/python/packages/azure-cosmos-memory/tests/test_context_provider.py new file mode 100644 index 0000000000..0207044e92 --- /dev/null +++ b/python/packages/azure-cosmos-memory/tests/test_context_provider.py @@ -0,0 +1,585 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportPrivateUsage=false + +"""Unit tests for CosmosMemoryContextProvider with mocked dependencies.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import AgentResponse, Message +from agent_framework._sessions import AgentSession, SessionContext + +from agent_framework_azure_cosmos_memory._context_provider import CosmosMemoryContextProvider + + +@pytest.fixture +def mock_memory_client() -> AsyncMock: + """Create a mock AsyncCosmosMemoryClient.""" + mock_client = AsyncMock() + mock_client.search_cosmos = AsyncMock(return_value=[]) + mock_client.get_user_summary = AsyncMock(return_value=None) + mock_client.add_cosmos = AsyncMock() + mock_client.create_memory_store = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock() + return mock_client + + +# -- Initialization tests ------------------------------------------------------ + + +class TestInit: + """Test CosmosMemoryContextProvider initialization.""" + + def test_init_with_all_params(self, mock_memory_client: AsyncMock) -> None: + """Initialize with all parameters provided.""" + provider = CosmosMemoryContextProvider( + source_id="test_memory", + memory_client=mock_memory_client, + top_k=10, + min_confidence=0.8, + memory_types=["fact", "episodic"], + context_prompt="Custom prompt:", + auto_extract=False, + ) + + assert provider.source_id == "test_memory" + assert provider.top_k == 10 + assert provider.min_confidence == 0.8 + assert provider.memory_types == ["fact", "episodic"] + assert provider.context_prompt == "Custom prompt:" + assert provider.auto_extract is False + assert provider.memory_client is mock_memory_client + assert provider._should_close_client is False + + def test_init_default_values(self, mock_memory_client: AsyncMock) -> None: + """Initialize with default values.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + + assert provider.source_id == "cosmos_memory" + assert provider.top_k == 5 + assert provider.min_confidence == 0.7 + assert provider.memory_types == ["fact", "procedural"] + assert provider.context_prompt == CosmosMemoryContextProvider.DEFAULT_CONTEXT_PROMPT + assert provider.auto_extract is True + + def test_init_creates_client_when_none(self) -> None: + """When no client provided, creates AsyncCosmosMemoryClient with credentials.""" + with ( + patch( + "agent_framework_azure_cosmos_memory._context_provider.AsyncCosmosMemoryClient" + ) as mock_client_class, + patch("agent_framework_azure_cosmos_memory._context_provider.DefaultAzureCredential") as mock_cred_class, + ): + mock_client_class.return_value = AsyncMock() + mock_cred_class.return_value = MagicMock() + + provider = CosmosMemoryContextProvider( + cosmos_endpoint="https://test.documents.azure.com:443/", + cosmos_database="test_db", + ai_foundry_endpoint="https://test.ai.azure.com", + ) + + mock_client_class.assert_called_once() + assert provider._should_close_client is True + + def test_init_raises_without_endpoints(self) -> None: + """Raises ValueError when endpoints not provided.""" + with pytest.raises(ValueError, match="cosmos_endpoint must be provided"): + CosmosMemoryContextProvider() + + def test_init_raises_without_ai_foundry(self) -> None: + """Raises ValueError when AI Foundry endpoint not provided.""" + with pytest.raises(ValueError, match="ai_foundry_endpoint must be provided"): + CosmosMemoryContextProvider(cosmos_endpoint="https://test.documents.azure.com:443/") + + def test_init_processor_config_applied(self, mock_memory_client: AsyncMock) -> None: + """Processor config is applied to environment variables.""" + import os + + original_value = os.environ.get("FACT_EXTRACTION_EVERY_N") + try: + provider = CosmosMemoryContextProvider( + memory_client=mock_memory_client, processor_config={"FACT_EXTRACTION_EVERY_N": "10"} + ) + assert os.environ.get("FACT_EXTRACTION_EVERY_N") == "10" + finally: + if original_value is not None: + os.environ["FACT_EXTRACTION_EVERY_N"] = original_value + else: + os.environ.pop("FACT_EXTRACTION_EVERY_N", None) + + def test_init_raises_when_memory_toolkit_not_available(self) -> None: + """Raises ImportError when azure-cosmos-agent-memory not installed.""" + with patch("agent_framework_azure_cosmos_memory._context_provider._memory_toolkit_available", False): + with pytest.raises(ImportError, match="azure-cosmos-agent-memory is required"): + CosmosMemoryContextProvider(memory_client=MagicMock()) # type: ignore + + +# -- before_run tests ---------------------------------------------------------- + + +class TestBeforeRun: + """Test before_run hook - memory retrieval and context injection.""" + + async def test_retrieves_and_injects_memories(self, mock_memory_client: AsyncMock) -> None: + """Searches for memories and injects them into context.""" + mock_memory_client.search_cosmos.return_value = [ + {"content": "User prefers Python", "memory_type": "fact", "confidence": 0.95}, + {"content": "User completed ML course", "memory_type": "episodic", "confidence": 0.85}, + ] + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["What do you know about me?"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + # Verify search was called + mock_memory_client.search_cosmos.assert_awaited_once() + call_kwargs = mock_memory_client.search_cosmos.call_args.kwargs + assert call_kwargs["user_id"] == "test-session" + assert call_kwargs["search_terms"] == "What do you know about me?" + assert call_kwargs["top_k"] == 5 + assert call_kwargs["memory_types"] == ["fact", "procedural"] + assert call_kwargs["min_confidence"] == 0.7 + + # Verify memories added to context + assert "cosmos_memory" in ctx.context_messages + added = ctx.context_messages["cosmos_memory"] + assert len(added) == 1 + assert "User prefers Python" in added[0].text # type: ignore + assert "User completed ML course" in added[0].text # type: ignore + assert "0.95" in added[0].text # type: ignore + assert "0.85" in added[0].text # type: ignore + + async def test_user_summary_injected_as_instruction(self, mock_memory_client: AsyncMock) -> None: + """User summary is retrieved and injected as instruction.""" + mock_memory_client.search_cosmos.return_value = [] + # get_user_summary returns the Cosmos summary document (a dict) whose roll-up text + # lives in the "content" field. + mock_memory_client.get_user_summary.return_value = { + "content": "Tech enthusiast, prefers concise answers", + "type": "user_summary", + } + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert len(ctx.instructions) == 1 + assert "User Profile:" in ctx.instructions[0] + assert "Tech enthusiast" in ctx.instructions[0] + + async def test_empty_user_summary_dict_not_injected(self, mock_memory_client: AsyncMock) -> None: + """A user summary document with empty content is not injected.""" + mock_memory_client.search_cosmos.return_value = [] + mock_memory_client.get_user_summary.return_value = {"content": " ", "type": "user_summary"} + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert len(ctx.instructions) == 0 + + async def test_no_user_summary_not_injected(self, mock_memory_client: AsyncMock) -> None: + """No user summary (None) does not inject an instruction.""" + mock_memory_client.search_cosmos.return_value = [] + mock_memory_client.get_user_summary.return_value = None + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert len(ctx.instructions) == 0 + + async def test_empty_input_skips_search(self, mock_memory_client: AsyncMock) -> None: + """Empty input messages skip memory search.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=[""])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + mock_memory_client.search_cosmos.assert_not_awaited() + assert "cosmos_memory" not in ctx.context_messages + + async def test_empty_search_results_no_injection(self, mock_memory_client: AsyncMock) -> None: + """Empty search results don't inject messages.""" + mock_memory_client.search_cosmos.return_value = [] + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert "cosmos_memory" not in ctx.context_messages + + async def test_uses_user_id_from_state(self, mock_memory_client: AsyncMock) -> None: + """Uses user_id from session state if available.""" + mock_memory_client.search_cosmos.return_value = [] + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + session.state["user_id"] = "custom-user-123" + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + call_kwargs = mock_memory_client.search_cosmos.call_args.kwargs + assert call_kwargs["user_id"] == "custom-user-123" + + async def test_search_failure_logs_warning(self, mock_memory_client: AsyncMock, caplog: pytest.LogCaptureFixture) -> None: + """Search failures are logged but don't raise.""" + mock_memory_client.search_cosmos.side_effect = Exception("Cosmos DB connection failed") + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + # Should not raise + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert "Failed to retrieve memories" in caplog.text + + async def test_search_failure_does_not_block_user_summary(self, mock_memory_client: AsyncMock) -> None: + """A search failure must not suppress user-summary injection (split error handling).""" + mock_memory_client.search_cosmos.side_effect = Exception("search boom") + mock_memory_client.get_user_summary.return_value = {"content": "Prefers concise answers"} + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + session.state["user_id"] = "u1" + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + # Memories failed, but the user summary was still injected as an instruction. + assert any("Prefers concise answers" in instr for instr in ctx.instructions) + + async def test_user_summary_failure_does_not_block_search(self, mock_memory_client: AsyncMock) -> None: + """A user-summary failure must not suppress memory injection (split error handling).""" + mock_memory_client.search_cosmos.return_value = [ + {"content": "User likes hiking", "memory_type": "fact", "confidence": 0.9} + ] + mock_memory_client.get_user_summary.side_effect = Exception("summary boom") + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + session.state["user_id"] = "u1" + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + injected = ctx.context_messages[provider.source_id] + assert any("User likes hiking" in m.text for m in injected) # type: ignore[arg-type] + + async def test_warns_once_when_no_user_id( + self, mock_memory_client: AsyncMock, caplog: pytest.LogCaptureFixture + ) -> None: + """Falling back to the session id (no stable user_id) logs a one-time warning.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="ephemeral-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + with caplog.at_level("WARNING"): + for _ in range(2): + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + # Search used the session id as the fallback user id... + assert mock_memory_client.search_cosmos.call_args.kwargs["user_id"] == "ephemeral-session" + # ...and the fallback warning was emitted exactly once across both runs. + assert caplog.text.count("No 'user_id' found") == 1 + + +# -- after_run tests ----------------------------------------------------------- + + +class TestAfterRun: + """Test after_run hook - conversation storage.""" + + async def test_stores_input_and_response_messages(self, mock_memory_client: AsyncMock) -> None: + """Stores both input and response messages.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[Message(role="user", contents=["Hello assistant"])], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["Hello! How can I help?"])]) + + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert mock_memory_client.add_cosmos.await_count == 2 + calls = mock_memory_client.add_cosmos.await_args_list + + # Check input message stored + assert calls[0].kwargs["role"] == "user" + assert calls[0].kwargs["content"] == "Hello assistant" + assert calls[0].kwargs["user_id"] == "test-session" + assert calls[0].kwargs["thread_id"] == "test-session" + + # Check response message stored + assert calls[1].kwargs["role"] == "agent" + assert calls[1].kwargs["content"] == "Hello! How can I help?" + + async def test_assistant_role_mapped_to_agent(self, mock_memory_client: AsyncMock) -> None: + """Agent Framework 'assistant' role is mapped to the toolkit's 'agent' role. + + The Agent Memory Toolkit's TurnRecord only accepts {user, agent, tool, system}; + storing 'assistant' raises a pydantic validation error. + """ + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[Message(role="user", contents=["Hi"])], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["Hello there"])]) + + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + stored_roles = [c.kwargs["role"] for c in mock_memory_client.add_cosmos.await_args_list] + assert stored_roles == ["user", "agent"] + # No raw "assistant" role should ever be sent to the toolkit. + assert "assistant" not in stored_roles + + async def test_uses_custom_user_and_thread_ids(self, mock_memory_client: AsyncMock) -> None: + """Uses custom user_id and thread_id from state.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + session.state["user_id"] = "user-456" + session.state["thread_id"] = "thread-789" + ctx = SessionContext( + input_messages=[Message(role="user", contents=["test"])], + session_id="s1", + ) + + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + call_kwargs = mock_memory_client.add_cosmos.await_args_list[0].kwargs + assert call_kwargs["user_id"] == "user-456" + assert call_kwargs["thread_id"] == "thread-789" + + async def test_skips_empty_messages(self, mock_memory_client: AsyncMock) -> None: + """Skips messages with no text content.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + Message(role="user", contents=[""]), + Message(role="user", contents=["Valid message"]), + ], + session_id="s1", + ) + + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + # Only one message should be stored + assert mock_memory_client.add_cosmos.await_count == 1 + call_kwargs = mock_memory_client.add_cosmos.await_args_list[0].kwargs + assert call_kwargs["content"] == "Valid message" + + async def test_storage_failure_logs_warning(self, mock_memory_client: AsyncMock, caplog: pytest.LogCaptureFixture) -> None: + """Storage failures are logged but don't raise.""" + mock_memory_client.add_cosmos.side_effect = Exception("Storage failed") + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") + + # Should not raise + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + assert "Failed to store conversation turns" in caplog.text + + +# -- Helper method tests ------------------------------------------------------- + + +class TestFormatMemories: + """Test _format_memories helper method.""" + + def test_formats_with_type_and_confidence(self, mock_memory_client: AsyncMock) -> None: + """Formats memories with type and confidence.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + memories = [ + {"content": "User likes Python", "memory_type": "fact", "confidence": 0.95}, + {"content": "User prefers vim", "memory_type": "procedural", "confidence": 0.82}, + ] + + result = provider._format_memories(memories) + + assert "[fact] User likes Python (confidence: 0.95)" in result + assert "[procedural] User prefers vim (confidence: 0.82)" in result + + def test_formats_without_metadata(self, mock_memory_client: AsyncMock) -> None: + """Formats memories without type/confidence metadata.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + memories = [{"content": "Some memory"}] + + result = provider._format_memories(memories) + + assert result == "Some memory" + + +# -- Context manager tests ----------------------------------------------------- + + +class TestContextManager: + """Test async context manager protocol.""" + + async def test_enters_and_exits_client(self, mock_memory_client: AsyncMock) -> None: + """Enters and exits the memory client when provider owns it.""" + # When provider creates the client, it should manage its lifecycle + with ( + patch( + "agent_framework_azure_cosmos_memory._context_provider.AsyncCosmosMemoryClient" + ) as mock_client_class, + patch("agent_framework_azure_cosmos_memory._context_provider.DefaultAzureCredential"), + ): + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + provider = CosmosMemoryContextProvider( + cosmos_endpoint="https://test.documents.azure.com:443/", + ai_foundry_endpoint="https://test.ai.azure.com", + ) + + async with provider: + pass + + mock_client.__aenter__.assert_awaited_once() + mock_client.__aexit__.assert_awaited_once() + + async def test_provided_client_not_closed(self, mock_memory_client: AsyncMock) -> None: + """When client is provided externally, provider should not close it.""" + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + + async with provider: + pass + + # Should still enter the client + mock_memory_client.__aenter__.assert_awaited_once() + # But should NOT exit it (caller owns it) + mock_memory_client.__aexit__.assert_not_awaited() + + async def test_aenter_creates_memory_store(self, mock_memory_client: AsyncMock) -> None: + """Entering the provider creates/connects the Cosmos memory store. + + The async client cannot create or connect Cosmos containers in __init__ + (no running event loop), so the provider must call create_memory_store() + on entry. Without this, add_cosmos/search_cosmos raise CosmosNotConnectedError + and no containers are ever created. + """ + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + + async with provider: + pass + + mock_memory_client.create_memory_store.assert_awaited_once() + + +class TestFlush: + """Test flush() draining of pending background extraction tasks.""" + + async def test_flush_waits_for_pending_tasks(self, mock_memory_client: AsyncMock) -> None: + """flush() awaits in-flight background tasks so extraction can complete.""" + import asyncio + + completed = False + + async def _work() -> None: + nonlocal completed + await asyncio.sleep(0.01) + completed = True + + task = asyncio.ensure_future(_work()) + mock_memory_client._background_tasks = {task} + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + await provider.flush() + + assert task.done() + assert completed is True + + async def test_flush_no_tasks_is_noop(self, mock_memory_client: AsyncMock) -> None: + """flush() returns cleanly when there are no background tasks.""" + mock_memory_client._background_tasks = set() + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + # Should not raise. + await provider.flush() + + async def test_flush_handles_missing_attribute(self, mock_memory_client: AsyncMock) -> None: + """flush() is a no-op if the client exposes no background-task registry.""" + # Simulate a client without a usable background-task registry. + mock_memory_client._background_tasks = None + + provider = CosmosMemoryContextProvider(memory_client=mock_memory_client) + # Should not raise. + await provider.flush() + async def test_only_closes_owned_client(self) -> None: + """Only closes client if provider created it.""" + with ( + patch( + "agent_framework_azure_cosmos_memory._context_provider.AsyncCosmosMemoryClient" + ) as mock_client_class, + patch("agent_framework_azure_cosmos_memory._context_provider.DefaultAzureCredential"), + ): + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + + provider = CosmosMemoryContextProvider( + cosmos_endpoint="https://test.documents.azure.com:443/", + ai_foundry_endpoint="https://test.ai.azure.com", + ) + + assert provider._should_close_client is True + + async with provider: + pass + + mock_client.__aenter__.assert_awaited_once() + mock_client.__aexit__.assert_awaited_once() diff --git a/python/packages/azure-cosmos-memory/tests/test_integration.py b/python/packages/azure-cosmos-memory/tests/test_integration.py new file mode 100644 index 0000000000..04a1bc0570 --- /dev/null +++ b/python/packages/azure-cosmos-memory/tests/test_integration.py @@ -0,0 +1,262 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for CosmosMemoryContextProvider with live Azure accounts. + +These tests require valid Azure credentials and environment variables: +- COSMOS_DB_ENDPOINT: Cosmos DB account endpoint +- COSMOS_DB_DATABASE: Database name (will be created if not exists) +- AI_FOUNDRY_ENDPOINT: AI Foundry project endpoint +- AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME: Embedding model deployment +- AI_FOUNDRY_CHAT_DEPLOYMENT_NAME: Chat model deployment + +Run with: pytest -m integration tests/ +""" + +from __future__ import annotations + +import os +import uuid + +import pytest +from agent_framework import Message +from agent_framework._sessions import AgentSession, SessionContext +from azure.identity.aio import DefaultAzureCredential + +from agent_framework_azure_cosmos_memory import CosmosMemoryContextProvider + +# Skip all tests in this module if required env vars not set. +# These tests hit a LIVE Azure account (Cosmos DB + AI Foundry), so they carry both +# the ``integration`` and ``azure`` markers. The emulator-backed suite in +# ``test_emulator.py`` is marked ``integration`` only and runs without any Azure account. +pytestmark = [pytest.mark.integration, pytest.mark.azure] + +REQUIRED_ENV_VARS = [ + "COSMOS_DB_ENDPOINT", + "AI_FOUNDRY_ENDPOINT", +] + + +def _check_env_vars() -> tuple[bool, list[str]]: + """Check if required environment variables are set.""" + missing = [var for var in REQUIRED_ENV_VARS if not os.getenv(var)] + return len(missing) == 0, missing + + +@pytest.fixture(scope="module") +def skip_if_no_env() -> None: + """Skip integration tests if environment variables not configured.""" + has_env, missing = _check_env_vars() + if not has_env: + pytest.skip(f"Integration tests require environment variables: {', '.join(missing)}") + + +@pytest.fixture +async def live_provider(skip_if_no_env: None) -> CosmosMemoryContextProvider: + """Create a live CosmosMemoryContextProvider with real Azure credentials.""" + provider = CosmosMemoryContextProvider( + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + cosmos_database=os.getenv("COSMOS_DB_DATABASE", "test_agent_memory"), + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + embedding_deployment_name=os.getenv("AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME", "text-embedding-3-large"), + chat_deployment_name=os.getenv("AI_FOUNDRY_CHAT_DEPLOYMENT_NAME", "gpt-4o-mini"), + credential=DefaultAzureCredential(), + top_k=3, + min_confidence=0.5, + ) + + async with provider: + yield provider + + +@pytest.fixture +def test_user_id() -> str: + """Generate a unique user ID for test isolation.""" + return f"test-user-{uuid.uuid4().hex[:8]}" + + +@pytest.fixture +def test_thread_id() -> str: + """Generate a unique thread ID for test isolation.""" + return f"test-thread-{uuid.uuid4().hex[:8]}" + + +# -- Basic functionality tests ------------------------------------------------- + + +class TestBasicFunctionality: + """Test basic memory storage and retrieval with live accounts.""" + + async def test_store_and_retrieve_conversation( + self, live_provider: CosmosMemoryContextProvider, test_user_id: str, test_thread_id: str + ) -> None: + """Store a conversation and verify it's persisted.""" + session = AgentSession(session_id="integration-test") + session.state["user_id"] = test_user_id + session.state["thread_id"] = test_thread_id + + # Store messages + ctx = SessionContext( + input_messages=[Message(role="user", contents=["I love Python programming"])], + session_id=session.session_id, + ) + + await live_provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(live_provider.source_id, {}) + ) # type: ignore + + # Verify messages were stored (this tests the memory client integration) + # In a real scenario, the memory extraction pipeline would process these + # For this test, we're verifying the storage mechanism works + + async def test_search_returns_results( + self, live_provider: CosmosMemoryContextProvider, test_user_id: str, test_thread_id: str + ) -> None: + """Search for memories (may return empty if no facts extracted yet).""" + session = AgentSession(session_id="integration-test") + session.state["user_id"] = test_user_id + + ctx = SessionContext( + input_messages=[Message(role="user", contents=["What are my programming preferences?"])], + session_id=session.session_id, + ) + + # Should not raise even if no memories exist yet + await live_provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(live_provider.source_id, {}) + ) # type: ignore + + +# -- Multi-turn conversation tests --------------------------------------------- + + +class TestMultiTurnConversation: + """Test memory across multiple conversation turns.""" + + async def test_multi_turn_storage( + self, live_provider: CosmosMemoryContextProvider, test_user_id: str, test_thread_id: str + ) -> None: + """Store multiple conversation turns.""" + session = AgentSession(session_id="integration-test") + session.state["user_id"] = test_user_id + session.state["thread_id"] = test_thread_id + + conversations = [ + ("user", "My name is Alice"), + ("assistant", "Nice to meet you, Alice!"), + ("user", "I work as a data scientist"), + ("assistant", "That's a great field!"), + ] + + for role, content in conversations: + ctx = SessionContext( + input_messages=[Message(role=role, contents=[content])], # type: ignore + session_id=session.session_id, + ) + + await live_provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(live_provider.source_id, {}) + ) # type: ignore + + +# -- Error handling tests ------------------------------------------------------ + + +class TestErrorHandling: + """Test error handling in integration scenarios.""" + + async def test_handles_missing_user_id_gracefully(self, live_provider: CosmosMemoryContextProvider) -> None: + """Falls back to session_id when user_id not in state.""" + session = AgentSession(session_id="fallback-test") + ctx = SessionContext( + input_messages=[Message(role="user", contents=["test"])], + session_id=session.session_id, + ) + + # Should use session_id as fallback and not raise + await live_provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(live_provider.source_id, {}) + ) # type: ignore + + async def test_handles_empty_messages( + self, live_provider: CosmosMemoryContextProvider, test_user_id: str, test_thread_id: str + ) -> None: + """Handles empty message content gracefully.""" + session = AgentSession(session_id="integration-test") + session.state["user_id"] = test_user_id + session.state["thread_id"] = test_thread_id + + ctx = SessionContext( + input_messages=[Message(role="user", contents=[""])], + session_id=session.session_id, + ) + + # Should not raise + await live_provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(live_provider.source_id, {}) + ) # type: ignore + + +# -- Configuration tests ------------------------------------------------------- + + +class TestConfiguration: + """Test different configuration options.""" + + async def test_custom_memory_types(self, skip_if_no_env: None, test_user_id: str) -> None: + """Provider with custom memory types configuration.""" + provider = CosmosMemoryContextProvider( + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + cosmos_database=os.getenv("COSMOS_DB_DATABASE", "test_agent_memory"), + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + credential=DefaultAzureCredential(), + memory_types=["fact", "episodic", "procedural"], + min_confidence=0.8, + top_k=10, + ) + + async with provider: + session = AgentSession(session_id="config-test") + session.state["user_id"] = test_user_id + + ctx = SessionContext( + input_messages=[Message(role="user", contents=["test query"])], + session_id=session.session_id, + ) + + # Should not raise + await provider.before_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + async def test_processor_config(self, skip_if_no_env: None, test_user_id: str, test_thread_id: str) -> None: + """Provider with custom processor configuration.""" + provider = CosmosMemoryContextProvider( + cosmos_endpoint=os.environ["COSMOS_DB_ENDPOINT"], + cosmos_database=os.getenv("COSMOS_DB_DATABASE", "test_agent_memory"), + ai_foundry_endpoint=os.environ["AI_FOUNDRY_ENDPOINT"], + credential=DefaultAzureCredential(), + processor_config={ + "FACT_EXTRACTION_EVERY_N": "1", + "DEDUP_EVERY_N": "3", + }, + ) + + async with provider: + session = AgentSession(session_id="config-test") + session.state["user_id"] = test_user_id + session.state["thread_id"] = test_thread_id + + ctx = SessionContext( + input_messages=[Message(role="user", contents=["I prefer TypeScript over JavaScript"])], + session_id=session.session_id, + ) + + # Should not raise + await provider.after_run( + agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore + + +# -- Cleanup note -------------------------------------------------------------- +# Note: These integration tests create data in the live Cosmos DB account. +# Consider adding cleanup logic or using time-based partitions if running frequently.