diff --git a/python/packages/a2a/AGENTS.md b/python/packages/a2a/AGENTS.md index d27d6d1c11..46f07333bd 100644 --- a/python/packages/a2a/AGENTS.md +++ b/python/packages/a2a/AGENTS.md @@ -6,6 +6,8 @@ Agent-to-Agent (A2A) protocol support for inter-agent communication. - **`A2AAgent`** - Client to connect to remote A2A-compliant agents. - **`A2AExecutor`** - Bridge to expose Agent Framework agents via the A2A protocol. +- **`A2AServiceSessionId`** - Typed durable A2A continuation state shape stored in `AgentSession.service_session_id`. +- **`A2AAgentSession`** - Deprecated compatibility session wrapper; prefer `AgentSession` + `A2AServiceSessionId`. ## Usage @@ -50,7 +52,7 @@ app = Starlette( ## Import Path ```python -from agent_framework.a2a import A2AAgent, A2AExecutor +from agent_framework.a2a import A2AAgent, A2AExecutor, A2AServiceSessionId # or directly: -from agent_framework_a2a import A2AAgent, A2AExecutor +from agent_framework_a2a import A2AAgent, A2AExecutor, A2AServiceSessionId ``` diff --git a/python/packages/a2a/agent_framework_a2a/__init__.py b/python/packages/a2a/agent_framework_a2a/__init__.py index 4f1cf41ba0..0ad9589790 100644 --- a/python/packages/a2a/agent_framework_a2a/__init__.py +++ b/python/packages/a2a/agent_framework_a2a/__init__.py @@ -3,7 +3,7 @@ import importlib.metadata from ._a2a_executor import A2AExecutor -from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken +from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken, A2AServiceSessionId try: __version__ = importlib.metadata.version(__name__) @@ -15,5 +15,6 @@ "A2AAgentSession", "A2AContinuationToken", "A2AExecutor", + "A2AServiceSessionId", "__version__", ] diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index c2bf1d6ffb..01709fcf3c 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -3,9 +3,11 @@ from __future__ import annotations import base64 +import sys import uuid +import warnings from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence -from typing import Any, Final, Literal, TypeAlias, overload +from typing import Any, Final, Literal, TypeAlias, cast, overload import httpx from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card @@ -35,6 +37,7 @@ HistoryProvider, Message, ResponseStream, + ServiceSessionId, SessionContext, normalize_messages, prepend_agent_framework_to_user_agent, @@ -43,11 +46,24 @@ from agent_framework.observability import AgentTelemetryLayer from google.protobuf.json_format import MessageToDict -__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken"] +if sys.version_info >= (3, 11): + from typing import TypedDict # pragma: no cover +else: + from typing_extensions import TypedDict # pragma: no cover + +__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken", "A2AServiceSessionId"] from agent_framework_a2a._utils import get_uri_data +class A2AServiceSessionId(TypedDict): + """Durable A2A continuation state stored in ``AgentSession.service_session_id``.""" + + context_id: str + task_id: str | None + task_state: TaskState | None + + class A2AAgentSession(AgentSession): """Session for A2A-based agents. @@ -79,13 +95,27 @@ def __init__( task_id: Optional task ID from a previous interaction. task_state: Optional state of the most recent task. """ - super().__init__(service_session_id=context_id) + warnings.warn( + "A2AAgentSession is deprecated and will be removed in a future version. " + "Use AgentSession with service_session_id=A2AServiceSessionId(...) instead.", + DeprecationWarning, + stacklevel=2, + ) self.context_id: str | None = context_id self.task_id: str | None = task_id self.task_state: TaskState | None = task_state + service_session_id: str | ServiceSessionId | None = None + if context_id is not None: + service_session_id = A2AServiceSessionId( + context_id=context_id, + task_id=task_id, + task_state=task_state, + ) + super().__init__(service_session_id=service_session_id) def to_dict(self) -> dict[str, Any]: """Serialize session to a plain dict for storage/transfer.""" + self._sync_service_session_id() data = super().to_dict() if self.context_id is not None: data[self._CONTEXT_ID_KEY] = self.context_id @@ -115,9 +145,22 @@ def from_dict(cls, data: dict[str, Any]) -> A2AAgentSession: # Delegate state deserialization to the base class base_session = AgentSession.from_dict(data) + base_service_session = base_session.service_session_id + if isinstance(base_service_session, Mapping): + mapped_context_id = base_service_session.get("context_id") + if isinstance(mapped_context_id, str): + context_id = context_id or mapped_context_id + mapped_task_id = base_service_session.get("task_id") + if isinstance(mapped_task_id, str): + task_id = task_id or mapped_task_id + mapped_task_state = base_service_session.get("task_state") + if isinstance(mapped_task_state, int): + task_state = task_state if task_state is not None else cast(TaskState, mapped_task_state) + elif isinstance(base_service_session, str): + context_id = context_id or base_service_session session = cls( - context_id=context_id or base_session.service_session_id, + context_id=context_id, task_id=task_id, task_state=task_state, ) @@ -125,6 +168,17 @@ def from_dict(cls, data: dict[str, Any]) -> A2AAgentSession: session.state.update(base_session.state) return session + def _sync_service_session_id(self) -> None: + """Keep compatibility fields and ``service_session_id`` aligned.""" + if self.context_id is None: + self.service_session_id = None + return + self.service_session_id = A2AServiceSessionId( + context_id=self.context_id, + task_id=self.task_id, + task_state=self.task_state, + ) + class A2AContinuationToken(ContinuationToken): """Continuation token for A2A protocol long-running tasks.""" @@ -316,6 +370,83 @@ async def __aexit__( if self._http_client is not None and self._close_http_client: await self._http_client.aclose() + @staticmethod + def _build_service_session_id( + *, + context_id: str, + task_id: str | None, + task_state: TaskState | None, + ) -> A2AServiceSessionId: + return A2AServiceSessionId( + context_id=context_id, + task_id=task_id, + task_state=task_state, + ) + + def _extract_a2a_session_state( + self, + session: AgentSession | None, + ) -> tuple[str | None, str | None, TaskState | None]: + """Extract A2A continuation state from supported session shapes.""" + if session is None: + return None, None, None + if isinstance(session, A2AAgentSession): + return session.context_id, session.task_id, session.task_state + service_session_id = session.service_session_id + if service_session_id is None: + return None, None, None + if isinstance(service_session_id, str): + return service_session_id, None, None + if not isinstance(service_session_id, Mapping): + raise ValueError( + "A2AAgent requires service_session_id to be a string or mapping with context_id/task_id/task_state." + ) + + context_id = service_session_id.get("context_id") + if not isinstance(context_id, str) or not context_id: + raise ValueError("A2A service_session_id mapping must include a non-empty string 'context_id'.") + + task_id_value = service_session_id.get("task_id") + if task_id_value is None: + task_id = None + elif isinstance(task_id_value, str): + task_id = task_id_value + else: + raise ValueError("A2A service_session_id mapping field 'task_id' must be a string or None.") + + task_state_value = service_session_id.get("task_state") + if task_state_value is None: + task_state = None + elif isinstance(task_state_value, int): + task_state = cast(TaskState, task_state_value) + else: + raise ValueError("A2A service_session_id mapping field 'task_state' must be an integer enum value or None.") + return context_id, task_id, task_state + + def _apply_a2a_session_state( + self, + *, + session: AgentSession, + context_id: str, + task_id: str | None, + task_state: TaskState | None, + ) -> None: + """Persist durable A2A continuation state on the session.""" + session.service_session_id = self._build_service_session_id( + context_id=context_id, + task_id=task_id, + task_state=task_state, + ) + if isinstance(session, A2AAgentSession): + session.context_id = context_id + session.task_id = task_id + session.task_state = task_state + + def _get_otel_conversation_id(self, session: AgentSession | None) -> str | None: + """Return A2A context_id as OpenTelemetry conversation id.""" + context_id, _, _ = self._extract_a2a_session_state(session) + return context_id + @overload def run( self, @@ -577,20 +708,24 @@ async def _map_a2a_stream( session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment] # Persist A2A protocol state on the session for follow-up message linking. - if isinstance(session, A2AAgentSession) and (last_task_id or last_context_id): + if session is not None and (last_task_id or last_context_id): + existing_context_id, existing_task_id, existing_task_state = self._extract_a2a_session_state(session) + # Validate context_id consistency - if session.context_id is not None and last_context_id and session.context_id != last_context_id: + if existing_context_id is not None and last_context_id and existing_context_id != last_context_id: raise RuntimeError( f"The context_id returned from the A2A agent ('{last_context_id}') " - f"differs from the session's context_id ('{session.context_id}')." + f"differs from the session's context_id ('{existing_context_id}')." + ) + + persisted_context_id = existing_context_id or last_context_id + if persisted_context_id is not None: + self._apply_a2a_session_state( + session=session, + context_id=persisted_context_id, + task_id=last_task_id or existing_task_id, + task_state=last_task_state if last_task_state is not None else existing_task_state, ) - # Assign server-generated context_id if not already set - if session.context_id is None and last_context_id: - session.context_id = last_context_id - session.service_session_id = last_context_id - if last_task_id: - session.task_id = last_task_id - session.task_state = last_task_state await self._run_after_providers(session=session, context=session_context) @@ -789,19 +924,11 @@ def _prepare_message_for_a2a(self, message: Message, *, session: AgentSession | Keyword Args: session: Optional session to read A2A state from. If an ``A2AAgentSession``, context_id/task_id/task_state are used for - linking. A plain ``AgentSession`` provides service_session_id as - a fallback context_id. + linking. A plain ``AgentSession`` may provide either a string + or structured A2A ``service_session_id`` mapping. """ # Extract A2A state from the session - context_id: str | None = None - previous_task_id: str | None = None - task_state: TaskState | None = None - if isinstance(session, A2AAgentSession): - context_id = session.context_id - previous_task_id = session.task_id - task_state = session.task_state - elif session is not None: - context_id = session.service_session_id + context_id, previous_task_id, task_state = self._extract_a2a_session_state(session) parts: list[A2APart] = [] if not message.contents: diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index ea271f7379..1e180fcd12 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -29,9 +29,9 @@ SessionContext, ) from agent_framework.a2a import A2AAgent -from pytest import fixture, mark, raises +from pytest import fixture, mark, raises, warns -from agent_framework_a2a import A2AAgentSession, A2AContinuationToken +from agent_framework_a2a import A2AAgentSession, A2AContinuationToken, A2AServiceSessionId from agent_framework_a2a._utils import get_uri_data @@ -146,6 +146,12 @@ def test_a2a_agent_initialization_with_client(mock_a2a_client: MockA2AClient) -> assert agent.client == mock_a2a_client +def test_a2a_agent_session_emits_deprecation_warning() -> None: + """A2AAgentSession emits a deprecation warning on construction.""" + with warns(DeprecationWarning, match="A2AAgentSession is deprecated"): + A2AAgentSession() + + def test_a2a_agent_defaults_name_description_from_agent_card(mock_a2a_client: MockA2AClient) -> None: """Test A2AAgent defaults name and description from agent_card when not explicitly provided.""" mock_card = MagicMock(spec=AgentCard) @@ -2034,7 +2040,11 @@ async def test_context_id_assigned_from_response(mock_a2a_client: MockA2AClient) # context_id from the task response should be assigned assert session.context_id == "test-context" - assert session.service_session_id == "test-context" + assert session.service_session_id == A2AServiceSessionId( + context_id="test-context", + task_id="task-ctx", + task_state=TaskState.TASK_STATE_COMPLETED, + ) @mark.asyncio @@ -2056,7 +2066,11 @@ async def test_context_id_tracked_from_message_payload(mock_a2a_client: MockA2AC # context_id should be captured even without a task_id assert session.context_id == "server-ctx-123" - assert session.service_session_id == "server-ctx-123" + assert session.service_session_id == A2AServiceSessionId( + context_id="server-ctx-123", + task_id=None, + task_state=None, + ) assert session.task_id is None @@ -2096,21 +2110,24 @@ async def test_task_state_tracked_on_session(mock_a2a_client: MockA2AClient) -> @mark.asyncio -async def test_plain_agent_session_no_reference_tracking(mock_a2a_client: MockA2AClient) -> None: - """Test that a plain AgentSession works but does not get reference_task_ids tracking.""" +async def test_plain_agent_session_tracks_structured_service_session_id(mock_a2a_client: MockA2AClient) -> None: + """Plain AgentSession should persist A2A continuation state in structured service_session_id.""" agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-plain", [{"content": "Reply"}]) session = AgentSession() await agent.run("Hello", session=session) - # Plain session does not get task_id tracking - assert "a2a_task_id" not in session.state + assert session.service_session_id == A2AServiceSessionId( + context_id="test-context", + task_id="task-plain", + task_state=TaskState.TASK_STATE_COMPLETED, + ) - # Follow-up has no reference_task_ids (no tracking on plain session) + # Follow-up should use the tracked task_id in reference_task_ids mock_a2a_client.add_task_response("task-plain-2", [{"content": "Reply 2"}]) await agent.run("Follow up", session=session) - assert list(mock_a2a_client.last_message.reference_task_ids) == [] + assert list(mock_a2a_client.last_message.reference_task_ids) == ["task-plain"] @mark.asyncio @@ -2129,6 +2146,52 @@ async def test_a2a_agent_session_serialization() -> None: assert restored.context_id == "ctx-456" assert restored.task_id == "task-789" assert restored.task_state == TaskState.TASK_STATE_COMPLETED + assert restored.service_session_id == A2AServiceSessionId( + context_id="ctx-456", + task_id="task-789", + task_state=TaskState.TASK_STATE_COMPLETED, + ) + + +@mark.asyncio +async def test_plain_agent_session_structured_service_session_id_for_input_required( + mock_a2a_client: MockA2AClient, +) -> None: + """Structured service_session_id should drive INPUT_REQUIRED follow-up task_id behavior.""" + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) + session = AgentSession( + service_session_id=A2AServiceSessionId( + context_id="ctx-ir", + task_id="task-ir-123", + task_state=TaskState.TASK_STATE_INPUT_REQUIRED, + ) + ) + + mock_a2a_client.add_in_progress_task_response( + "task-ir-456", + context_id="ctx-ir", + state=TaskState.TASK_STATE_COMPLETED, + text="Thanks!", + ) + await agent.run("My name is Alice", session=session) + + last_msg = mock_a2a_client.last_message + assert last_msg.task_id == "task-ir-123" + assert list(last_msg.reference_task_ids) == [] + + +def test_a2a_agent_otel_conversation_id_uses_context_id() -> None: + """Telemetry conversation id should map to context_id for structured A2A sessions.""" + agent = A2AAgent(client=MagicMock(), http_client=None) + session = AgentSession( + service_session_id=A2AServiceSessionId( + context_id="ctx-otel", + task_id="task-otel", + task_state=TaskState.TASK_STATE_WORKING, + ) + ) + + assert agent._get_otel_conversation_id(session) == "ctx-otel" @mark.asyncio diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index e68b395bf2..c4d4d972cd 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -19,6 +19,7 @@ ChatResponseUpdate, Content, Message, + ServiceSessionId, SupportsAgentRun, SupportsChatGetResponse, ) @@ -58,7 +59,7 @@ def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) - self._stream_fn = stream_fn self._response_fn = response_fn self.last_session: AgentSession | None = None - self.last_service_session_id: str | None = None + self.last_service_session_id: str | ServiceSessionId | None = None @overload def get_response( @@ -244,7 +245,7 @@ async def _get_response() -> AgentResponse[Any]: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession(session_id=kwargs.get("session_id")) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id, service_session_id=service_session_id) diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 5b88d0cf5e..fa5562f4e3 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -757,7 +757,7 @@ async def _get_stream( session = session or self.create_session() # Ensure we're connected to the right session - await self._ensure_session(session.service_session_id) + await self._ensure_session(self._get_chat_conversation_id(session)) if not self._client: raise RuntimeError("Claude SDK client not initialized.") diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 333e75470f..0bc9c48432 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -244,14 +244,18 @@ async def _run_impl( """Non-streaming implementation of run.""" if not session: session = self.create_session() - if not session.service_session_id: + service_session_id = session.service_session_id + if service_session_id is None: session.service_session_id = await self._start_new_conversation() + service_session_id = session.service_session_id + if not isinstance(service_session_id, str): + raise AgentException("CopilotStudioAgent requires service_session_id to be a string") input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) - activities = self.client.ask_question(question, session.service_session_id) + activities = self.client.ask_question(question, service_session_id) response_messages: list[Message] = [] response_id: str | None = None @@ -272,14 +276,18 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: nonlocal session if not session: session = self.create_session() - if not session.service_session_id: + service_session_id = session.service_session_id + if service_session_id is None: session.service_session_id = await self._start_new_conversation() + service_session_id = session.service_session_id + if not isinstance(service_session_id, str): + raise AgentException("CopilotStudioAgent requires service_session_id to be a string") input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) - activities = self.client.ask_question(question, session.service_session_id) + activities = self.client.ask_question(question, service_session_id) async for message in self._process_activities(activities, streaming=True): yield AgentResponseUpdate( diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 48ef906ab3..3c8a067a10 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -61,6 +61,7 @@ agent_framework/ ### Sessions (`_sessions.py`) - **`AgentSession`** - Manages conversation state and session metadata +- **`ServiceSessionId`** - Mapping alias for structured service-owned continuation handles used in `AgentSession.service_session_id` - **`SessionContext`** - Context object for session-scoped data during agent runs - **`ContextProvider`** - Base class for context providers (RAG, memory systems) - **`HistoryProvider`** - Base class for conversation history storage diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 1d60836598..6ebd7e5801 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -173,6 +173,7 @@ FileHistoryProvider, HistoryProvider, InMemoryHistoryProvider, + ServiceSessionId, SessionContext, register_state_type, ) @@ -503,6 +504,7 @@ "SamplingApprovalCallback", "SecretString", "SelectiveToolCallCompactionStrategy", + "ServiceSessionId", "SessionContext", "SingleEdgeGroup", "Skill", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index f9a7a0c220..51798da63c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -37,6 +37,7 @@ HistoryProvider, InMemoryHistoryProvider, PerServiceCallHistoryPersistingMiddleware, + ServiceSessionId, SessionContext, is_local_history_conversation_id, ) @@ -231,7 +232,12 @@ def create_session(self, *, session_id: str | None = None): return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None): + def get_session( + self, + service_session_id: str | ServiceSessionId, + *, + session_id: str | None = None, + ): from agent_framework import AgentSession return AgentSession(service_session_id=service_session_id, session_id=session_id) @@ -307,7 +313,12 @@ def create_session(self, *, session_id: str | None = None) -> AgentSession: """Creates a new conversation session.""" ... - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, + service_session_id: str | ServiceSessionId, + *, + session_id: str | None = None, + ) -> AgentSession: """Gets or creates a session for a service-managed session ID.""" ... @@ -431,7 +442,12 @@ def create_session(self, *, session_id: str | None = None) -> AgentSession: """ return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, + service_session_id: str | ServiceSessionId, + *, + session_id: str | None = None, + ) -> AgentSession: """Get a session for a service-managed session ID. Only use this to create a session continuing that session id from a service. @@ -448,6 +464,41 @@ def get_session(self, service_session_id: str, *, session_id: str | None = None) """ return AgentSession(session_id=session_id, service_session_id=service_session_id) + def _get_chat_conversation_id(self, session: AgentSession | None) -> str | None: + """Get the chat conversation id to forward to generic chat clients. + + Args: + session: The active session for this run. + + Returns: + The conversation id when it is a string, otherwise None. + + Raises: + AgentInvalidRequestException: If the session contains a structured + service_session_id that this generic chat path cannot forward. + """ + service_session_id = session.service_session_id if session is not None else None + if service_session_id is None: + return None + if isinstance(service_session_id, str): + return service_session_id + raise AgentInvalidRequestException( + "This agent expects a string service_session_id for provider conversation continuation. " + "Received a structured service_session_id; use a compatible agent/session shape for this provider." + ) + + def _get_otel_conversation_id(self, session: AgentSession | None) -> str | None: + """Get the OTel conversation id for ``gen_ai.conversation.id``. + + Args: + session: The active session for this run. + + Returns: + A string conversation id, or None when no string id is available. + """ + service_session_id = session.service_session_id if session else None + return service_session_id if isinstance(service_session_id, str) else None + async def _run_after_providers( self, *, @@ -811,8 +862,10 @@ def _resolve_per_service_call_history_providers( return [] # A live service-managed session id takes precedence over the resolved conversation id. - if session and session.service_session_id: - conversation_id = session.service_session_id + # Structured values are validated by _get_chat_conversation_id before generic forwarding. + session_conversation_id = self._get_chat_conversation_id(session) + if session_conversation_id: + conversation_id = session_conversation_id # Without service-side storage the middleware persists locally and drives the function # loop with a local sentinel, which cannot be reconciled with an existing service-managed # conversation. When the service stores history, an existing conversation id is expected. @@ -1197,12 +1250,13 @@ async def _prepare_run_context( # Resolve conversation_id from the same combined view so an agent-level default is honored # when the runtime omits it (a live session id still takes precedence below). effective_conversation_id = effective_options.get("conversation_id") + session_conversation_id = self._get_chat_conversation_id(session) # Auto-inject InMemoryHistoryProvider when session is provided, no context providers # registered, and no service-side storage indicators if ( session is not None and not self.context_providers - and not session.service_session_id + and not session_conversation_id and not effective_conversation_id and not service_stores_history ): @@ -1289,7 +1343,7 @@ async def _prepare_run_context( # Build options dict from run() options merged with provided options run_opts: dict[str, Any] = { - "conversation_id": active_session.service_session_id + "conversation_id": self._get_chat_conversation_id(active_session) if active_session else opts.pop("conversation_id", None), "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 64f30083fd..19052cb69a 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -44,6 +44,7 @@ JsonDumps: TypeAlias = Callable[[Any], str | bytes] JsonLoads: TypeAlias = Callable[[str | bytes], Any] +ServiceSessionId: TypeAlias = Mapping[str, Any] def _default_json_dumps(value: Any) -> str: @@ -176,7 +177,7 @@ def __init__( self, *, session_id: str | None = None, - service_session_id: str | None = None, + service_session_id: str | ServiceSessionId | None = None, input_messages: list[Message], context_messages: dict[str, list[Message]] | None = None, instructions: list[str] | None = None, @@ -759,7 +760,7 @@ def __init__( self, *, session_id: str | None = None, - service_session_id: str | None = None, + service_session_id: str | ServiceSessionId | None = None, ): """Initialize the session. diff --git a/python/packages/core/agent_framework/a2a/__init__.py b/python/packages/core/agent_framework/a2a/__init__.py index 4c07dab5f2..dd2bac6595 100644 --- a/python/packages/core/agent_framework/a2a/__init__.py +++ b/python/packages/core/agent_framework/a2a/__init__.py @@ -8,6 +8,7 @@ Supported classes: - A2AAgent - A2AAgentSession +- A2AServiceSessionId - A2AExecutor """ @@ -16,7 +17,7 @@ IMPORT_PATH = "agent_framework_a2a" PACKAGE_NAME = "agent-framework-a2a" -_IMPORTS = ["A2AAgent", "A2AAgentSession", "A2AExecutor"] +_IMPORTS = ["A2AAgent", "A2AAgentSession", "A2AServiceSessionId", "A2AExecutor"] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/agent_framework/a2a/__init__.pyi b/python/packages/core/agent_framework/a2a/__init__.pyi index 7d3fa97804..fa81294621 100644 --- a/python/packages/core/agent_framework/a2a/__init__.pyi +++ b/python/packages/core/agent_framework/a2a/__init__.pyi @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -from agent_framework_a2a import A2AAgent, A2AAgentSession, A2AExecutor +from agent_framework_a2a import A2AAgent, A2AAgentSession, A2AExecutor, A2AServiceSessionId -__all__ = ["A2AAgent", "A2AAgentSession", "A2AExecutor"] +__all__ = ["A2AAgent", "A2AAgentSession", "A2AExecutor", "A2AServiceSessionId"] diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 15e83289f7..30142d7d85 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1767,13 +1767,22 @@ def _trace_agent_invocation( provider_name = str(self.otel_provider_name) merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + get_otel_conversation_id = cast( + "Callable[[AgentSession | None], str | None] | None", + getattr(self, "_get_otel_conversation_id", None), + ) + conversation_id = ( + get_otel_conversation_id(session) + if callable(get_otel_conversation_id) + else (session.service_session_id if (session and isinstance(session.service_session_id, str)) else None) + ) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, agent_id=getattr(self, "id", "unknown"), agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"), agent_description=getattr(self, "description", None), - thread_id=session.service_session_id if session else None, + thread_id=conversation_id, all_options=dict(merged_options), **merged_client_kwargs, ) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 4de14fa807..fea78034b3 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -30,6 +30,7 @@ InMemoryHistoryProvider, Message, ResponseStream, + ServiceSessionId, SessionContext, SlidingWindowStrategy, SupportsAgentRun, @@ -1132,7 +1133,7 @@ def __init__(self, messages: list[Message] | None = None) -> None: self.before_run_called = False self.after_run_called = False self.new_messages: list[Message] = [] - self.last_service_session_id: str | None = None + self.last_service_session_id: str | ServiceSessionId | None = None async def before_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: self.before_run_called = True @@ -2316,6 +2317,35 @@ async def test_agent_get_session_with_service_session_id( assert session.service_session_id == "test-thread-123" +@pytest.mark.asyncio +async def test_agent_get_session_with_structured_service_session_id( + chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool +): + """Test that get_session accepts structured service_session_id.""" + agent = Agent(client=chat_client_base, tools=[tool_tool]) + structured_service_session_id = {"context_id": "ctx-123", "task_id": "task-456", "task_state": "working"} + + session = agent.get_session(service_session_id=structured_service_session_id) + + assert session is not None + assert session.service_session_id == structured_service_session_id + + +@pytest.mark.asyncio +async def test_agent_run_rejects_structured_service_session_id_for_generic_chat_clients( + chat_client_base: SupportsChatGetResponse, +): + """Structured service_session_id must fail before generic chat client calls.""" + agent = Agent(client=chat_client_base) + session = agent.get_session(service_session_id={"context_id": "ctx-123"}) + + with pytest.raises( + AgentInvalidRequestException, + match="expects a string service_session_id", + ): + await agent.run("Hello", session=session) + + def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): """Test AgentSession.from_dict restores a session from serialized state.""" # Create serialized session state diff --git a/python/packages/core/tests/core/test_harness_agent.py b/python/packages/core/tests/core/test_harness_agent.py index 8213df5888..11305d983c 100644 --- a/python/packages/core/tests/core/test_harness_agent.py +++ b/python/packages/core/tests/core/test_harness_agent.py @@ -25,6 +25,7 @@ InMemoryHistoryProvider, Message, ResponseStream, + ServiceSessionId, SkillsProvider, TodoProvider, create_harness_agent, @@ -509,7 +510,12 @@ def __init__(self, name: str, description: str | None = None): def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, + service_session_id: str | ServiceSessionId, + *, + session_id: str | None = None, + ) -> AgentSession: return AgentSession(service_session_id=service_session_id, session_id=session_id) async def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any: diff --git a/python/packages/core/tests/core/test_harness_background_agents.py b/python/packages/core/tests/core/test_harness_background_agents.py index 0c2b97f3f6..98e6fa2a67 100644 --- a/python/packages/core/tests/core/test_harness_background_agents.py +++ b/python/packages/core/tests/core/test_harness_background_agents.py @@ -14,6 +14,7 @@ BackgroundTaskInfo, BackgroundTaskStatus, Message, + ServiceSessionId, ) from agent_framework._sessions import SessionContext @@ -46,7 +47,12 @@ def __init__( def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, + service_session_id: str | ServiceSessionId, + *, + session_id: str | None = None, + ) -> AgentSession: return AgentSession(service_session_id=service_session_id, session_id=session_id) async def run( diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index a15f4793ee..2fcf2f94bc 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -437,6 +437,11 @@ def test_service_session_id(self) -> None: session = AgentSession(service_session_id="svc-456") assert session.service_session_id == "svc-456" + def test_service_session_id_accepts_structured_mapping(self) -> None: + service_session_id = {"context_id": "ctx-123", "task_id": "task-456", "task_state": "working"} + session = AgentSession(service_session_id=service_session_id) + assert session.service_session_id == service_session_id + def test_to_dict(self) -> None: session = AgentSession(session_id="s1", service_session_id="svc1") session.state = {"key": "value"} @@ -466,6 +471,14 @@ def test_roundtrip(self) -> None: assert restored.session_id == "rt-1" assert restored.state == {"messages": ["a", "b"], "count": 42} + def test_roundtrip_with_structured_service_session_id(self) -> None: + service_session_id = {"context_id": "ctx-123", "task_id": "task-456", "task_state": "working"} + session = AgentSession(session_id="rt-2", service_session_id=service_session_id) + json_str = json.dumps(session.to_dict()) + restored = AgentSession.from_dict(json.loads(json_str)) + assert restored.session_id == "rt-2" + assert restored.service_session_id == service_session_id + def test_from_dict_missing_state(self) -> None: data = {"session_id": "s1"} session = AgentSession.from_dict(data) diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 8ac0744809..90ebfb67e8 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -3,7 +3,14 @@ from collections.abc import Awaitable from typing import Any, Literal, overload -from agent_framework import AgentResponse, AgentResponseUpdate, AgentRunInputs, AgentSession, ResponseStream +from agent_framework import ( + AgentResponse, + AgentResponseUpdate, + AgentRunInputs, + AgentSession, + ResponseStream, + ServiceSessionId, +) from agent_framework._workflows._agent_utils import resolve_agent_id @@ -46,7 +53,7 @@ def create_session(self, **kwargs: Any) -> AgentSession: # type: ignore[empty-b """Creates a new conversation session for the agent.""" ... - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str | ServiceSessionId, **kwargs: Any) -> AgentSession: return AgentSession() diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index d7f9c3b245..fbb9ac4eee 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -20,6 +20,7 @@ Executor, Message, ResponseStream, + ServiceSessionId, WorkflowBuilder, WorkflowContext, WorkflowRunState, @@ -368,7 +369,7 @@ async def test_agent_executor_full_conversation_round_trip_does_not_duplicate_hi class _SessionIdCapturingAgent(BaseAgent): """Records service_session_id of the session at run() time.""" - _captured_service_session_id: str | None = PrivateAttr(default="NOT_CAPTURED") + _captured_service_session_id: str | ServiceSessionId | None = PrivateAttr(default="NOT_CAPTURED") @overload def run( diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 9b3b8799b9..0cf69aa941 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -19,6 +19,7 @@ InMemoryHistoryProvider, Message, ResponseStream, + ServiceSessionId, SupportsAgentRun, UsageDetails, WorkflowAgent, @@ -1234,7 +1235,7 @@ def __init__(self, name: str, response_text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + def get_session(self, *, service_session_id: str | ServiceSessionId, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() @overload @@ -1344,7 +1345,7 @@ def __init__(self, name: str, response_text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + def get_session(self, *, service_session_id: str | ServiceSessionId, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() @overload @@ -1980,7 +1981,7 @@ def __init__( def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + def get_session(self, *, service_session_id: str | ServiceSessionId, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() def _next_request_id(self) -> str: diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 19d5804bc2..e8eabca078 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -315,6 +315,9 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: data = dict(data) # defensive copy — avoid mutating caller's dict session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None) session = super().from_dict(data) + service_session_id = session.service_session_id + if service_session_id is not None and not isinstance(service_session_id, str): + raise ValueError("durable sessions require service_session_id to be a string when present") durable_session_id: AgentSessionId | None = None # We need to create a DurableAgentSession from the base AgentSession if session_id_value is not None: @@ -325,7 +328,7 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: durable_session = cls( durable_session_id=durable_session_id, session_id=session.session_id, - service_session_id=session.service_session_id, + service_session_id=service_session_id, ) durable_session.state.update(session.state) return durable_session diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index b21cac6831..e6e9f5d027 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Literal, TypeVar -from agent_framework import AgentSession, SupportsAgentRun, normalize_messages +from agent_framework import AgentSession, ServiceSessionId, SupportsAgentRun, normalize_messages from agent_framework._types import AgentRunInputs from ._executors import DurableAgentExecutor @@ -137,8 +137,10 @@ def create_session(self, *, session_id: str | None = None) -> DurableAgentSessio """Create a new agent session via the provider.""" return self._executor.get_new_session(self.name) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: """Retrieve an existing session via the provider.""" + if not isinstance(service_session_id, str): + raise ValueError("DurableAIAgent requires service_session_id to be a string") return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id) def _normalize_messages(self, messages: AgentRunInputs | None) -> str: diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index eabf4231a7..218d132103 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -30,6 +30,7 @@ Message, RawAgent, ResponseStream, + ServiceSessionId, SupportsAgentRun, WorkflowAgent, WorkflowBuilder, @@ -3692,7 +3693,7 @@ def __init__( def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: return AgentSession() def _next_request_id(self) -> str: @@ -3826,7 +3827,9 @@ def __init__(self, name: str, text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None + ) -> AgentSession: return AgentSession() @overload diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 921231a973..6661d5b057 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -874,7 +874,12 @@ async def _get_or_create_session( try: if agent_session.service_session_id: - return await self._resume_session(agent_session.service_session_id, streaming, runtime_options) + service_session_id = agent_session.service_session_id + if not isinstance(service_session_id, str): + raise AgentException( + "GitHubCopilotAgent expects a string service_session_id for session resumption." + ) + return await self._resume_session(service_session_id, streaming, runtime_options) session = await self._create_session(streaming, runtime_options) agent_session.service_session_id = session.session_id diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index 39a7eef387..ec0095a5ef 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -2771,7 +2771,7 @@ async def test_integration_run_with_simple_prompt_returns_response() -> None: assert len(response.messages) > 0 assert "4" in response.text - if session.service_session_id and agent._client: + if isinstance(session.service_session_id, str) and agent._client: await agent._client.delete_session(session.service_session_id) @@ -2795,7 +2795,7 @@ async def test_integration_run_streaming_returns_updates() -> None: full_text = "".join(u.text for u in updates if u.text) assert len(full_text) > 0 - if session.service_session_id and agent._client: + if isinstance(session.service_session_id, str) and agent._client: await agent._client.delete_session(session.service_session_id) @@ -2818,7 +2818,7 @@ async def test_integration_run_with_function_tool_invokes_tool() -> None: assert len(response.messages) > 0 assert any(word in response.text.lower() for word in ["sunny", "25", "weather", "seattle"]) - if session.service_session_id and agent._client: + if isinstance(session.service_session_id, str) and agent._client: await agent._client.delete_session(session.service_session_id) @@ -2843,7 +2843,7 @@ async def test_integration_run_with_session_maintains_context() -> None: assert response2 is not None assert "alice" in response2.text.lower() - if session.service_session_id and agent._client: + if isinstance(session.service_session_id, str) and agent._client: await agent._client.delete_session(session.service_session_id) @@ -2862,7 +2862,7 @@ async def test_integration_run_with_session_resume_continues_conversation() -> N await agent.run("Remember this number: 42.", session=session1) session_id = session1.service_session_id - assert session_id is not None + assert isinstance(session_id, str) session2 = AgentSession() session2.service_session_id = session_id @@ -2893,5 +2893,5 @@ async def test_integration_run_with_shell_permissions_executes_command() -> None assert response is not None assert "hello" in response.text.lower() - if session.service_session_id and agent._client: + if isinstance(session.service_session_id, str) and agent._client: await agent._client.delete_session(session.service_session_id) diff --git a/python/packages/hosting-responses/tests/hosting_responses/test_channel.py b/python/packages/hosting-responses/tests/hosting_responses/test_channel.py index 274e797ad9..f46d108beb 100644 --- a/python/packages/hosting-responses/tests/hosting_responses/test_channel.py +++ b/python/packages/hosting-responses/tests/hosting_responses/test_channel.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from typing import Any -from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message +from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, ServiceSessionId from agent_framework_hosting import ( AgentFrameworkHost, HostedRunResult, @@ -62,7 +62,7 @@ def __init__(self, reply: Any = "hello", chunks: list[str] | None = None) -> Non def create_session(self, *, session_id: str | None = None) -> Any: return {"session_id": session_id} - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> Any: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> Any: return {"service_session_id": service_session_id, "session_id": session_id} def run(self, messages: Any = None, *, stream: bool = False, **kwargs: Any) -> Any: diff --git a/python/packages/hosting/tests/hosting/test_host.py b/python/packages/hosting/tests/hosting/test_host.py index 7aa6ef7955..90c2f51db0 100644 --- a/python/packages/hosting/tests/hosting/test_host.py +++ b/python/packages/hosting/tests/hosting/test_host.py @@ -10,7 +10,15 @@ from typing import Any, cast import pytest -from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Content, Message, ResponseStream +from agent_framework import ( + AgentResponse, + AgentResponseUpdate, + AgentSession, + Content, + Message, + ResponseStream, + ServiceSessionId, +) from agent_framework._workflows._events import WorkflowEvent from starlette.requests import Request from starlette.responses import JSONResponse @@ -74,7 +82,7 @@ def create_session(self, *, session_id: str | None = None) -> AgentSession: self.created_sessions.append(s) return s - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id, service_session_id=service_session_id) def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any: @@ -782,7 +790,9 @@ class _MultiModalAgent: def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None + ) -> AgentSession: return AgentSession(session_id=session_id, service_session_id=service_session_id) def run(self, *_args: Any, **_kwargs: Any) -> Any: @@ -864,7 +874,7 @@ def __init__(self, providers: Sequence[Any], *, reply: str = "ok") -> None: def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id, service_session_id=service_session_id) def run( diff --git a/python/packages/hosting/tests/hosting/test_host_disk.py b/python/packages/hosting/tests/hosting/test_host_disk.py index e6567b6617..607de72cbc 100644 --- a/python/packages/hosting/tests/hosting/test_host_disk.py +++ b/python/packages/hosting/tests/hosting/test_host_disk.py @@ -9,7 +9,7 @@ from typing import Any, cast import pytest -from agent_framework import AgentSession +from agent_framework import AgentSession, ServiceSessionId from agent_framework_hosting import AgentFrameworkHost, ChannelContext, ChannelContribution @@ -26,7 +26,7 @@ class _AgentStub: def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session(self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None) -> AgentSession: return AgentSession(service_session_id=service_session_id, session_id=session_id) def run(self, *_args: Any, **_kwargs: Any) -> Any: # pragma: no cover - unused diff --git a/python/packages/hosting/tests/hosting/test_isolation.py b/python/packages/hosting/tests/hosting/test_isolation.py index 571f275d93..c1c2a77655 100644 --- a/python/packages/hosting/tests/hosting/test_isolation.py +++ b/python/packages/hosting/tests/hosting/test_isolation.py @@ -18,7 +18,7 @@ from typing import Any import pytest -from agent_framework import AgentSession +from agent_framework import AgentSession, ServiceSessionId from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import BaseRoute, Route @@ -169,7 +169,9 @@ class _NoopAgent: def create_session(self, *, session_id: str | None = None) -> AgentSession: return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + def get_session( + self, service_session_id: str | ServiceSessionId, *, session_id: str | None = None + ) -> AgentSession: return AgentSession(service_session_id=service_session_id, session_id=session_id) def run(self, *_args: object, **_kwargs: object) -> Any: # pragma: no cover - never called diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index b605f4c4b2..5cdd2649b1 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -56,8 +56,10 @@ def _get_agent_session_id(context: AgentContext) -> str | None: 2. First message whose additional_properties contains 'conversation_id' 3. None: the downstream processor will generate a new UUID """ - if context.session and context.session.service_session_id: - return context.session.service_session_id + if context.session: + service_session_id = context.session.service_session_id + if isinstance(service_session_id, str) and service_session_id: + return service_session_id for message in context.messages: conversation_id = message.additional_properties.get("conversation_id") diff --git a/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py b/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py index 35bd1745e0..5561974b09 100644 --- a/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py +++ b/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py @@ -124,7 +124,7 @@ async def main() -> None: if chunk.text: print(chunk.text, end="", flush=True) finally: - if session.service_session_id is not None: + if isinstance(session.service_session_id, str): await project_client.beta.agents.delete_session( agent_name=agent_name, session_id=session.service_session_id,