Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/packages/a2a/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Agent-to-Agent (A2A) protocol support for inter-agent communication.

- **`A2AAgent`** - Client to connect to remote A2A-compliant agents.
- **`A2AExecutor`** - Bridge to expose Agent Framework agents via the A2A protocol.
- **`A2AServiceSessionId`** - Typed durable A2A continuation state shape stored in `AgentSession.service_session_id`.
- **`A2AAgentSession`** - Deprecated compatibility session wrapper; prefer `AgentSession` + `A2AServiceSessionId`.

## Usage

Expand Down Expand Up @@ -50,7 +52,7 @@ app = Starlette(
## Import Path

```python
from agent_framework.a2a import A2AAgent, A2AExecutor
from agent_framework.a2a import A2AAgent, A2AExecutor, A2AServiceSessionId
# or directly:
from agent_framework_a2a import A2AAgent, A2AExecutor
from agent_framework_a2a import A2AAgent, A2AExecutor, A2AServiceSessionId
```
3 changes: 2 additions & 1 deletion python/packages/a2a/agent_framework_a2a/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib.metadata

from ._a2a_executor import A2AExecutor
from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken
from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken, A2AServiceSessionId

try:
__version__ = importlib.metadata.version(__name__)
Expand All @@ -15,5 +15,6 @@
"A2AAgentSession",
"A2AContinuationToken",
"A2AExecutor",
"A2AServiceSessionId",
"__version__",
]
177 changes: 152 additions & 25 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

import base64
import sys
import uuid
import warnings
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Any, Final, Literal, TypeAlias, overload
from typing import Any, Final, Literal, TypeAlias, cast, overload

import httpx
from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card
Expand Down Expand Up @@ -35,6 +37,7 @@
HistoryProvider,
Message,
ResponseStream,
ServiceSessionId,
SessionContext,
normalize_messages,
prepend_agent_framework_to_user_agent,
Expand All @@ -43,11 +46,24 @@
from agent_framework.observability import AgentTelemetryLayer
from google.protobuf.json_format import MessageToDict

__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken"]
if sys.version_info >= (3, 11):
from typing import TypedDict # pragma: no cover
else:
from typing_extensions import TypedDict # pragma: no cover

__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken", "A2AServiceSessionId"]

from agent_framework_a2a._utils import get_uri_data


class A2AServiceSessionId(TypedDict):
"""Durable A2A continuation state stored in ``AgentSession.service_session_id``."""

context_id: str
task_id: str | None
task_state: TaskState | None


class A2AAgentSession(AgentSession):
"""Session for A2A-based agents.

Expand Down Expand Up @@ -79,13 +95,27 @@ def __init__(
task_id: Optional task ID from a previous interaction.
task_state: Optional state of the most recent task.
"""
super().__init__(service_session_id=context_id)
warnings.warn(
"A2AAgentSession is deprecated and will be removed in a future version. "
"Use AgentSession with service_session_id=A2AServiceSessionId(...) instead.",
DeprecationWarning,
stacklevel=2,
)
self.context_id: str | None = context_id
self.task_id: str | None = task_id
self.task_state: TaskState | None = task_state
service_session_id: str | ServiceSessionId | None = None
if context_id is not None:
service_session_id = A2AServiceSessionId(
context_id=context_id,
task_id=task_id,
task_state=task_state,
)
super().__init__(service_session_id=service_session_id)

def to_dict(self) -> dict[str, Any]:
"""Serialize session to a plain dict for storage/transfer."""
self._sync_service_session_id()
data = super().to_dict()
if self.context_id is not None:
data[self._CONTEXT_ID_KEY] = self.context_id
Expand Down Expand Up @@ -115,16 +145,40 @@ def from_dict(cls, data: dict[str, Any]) -> A2AAgentSession:

# Delegate state deserialization to the base class
base_session = AgentSession.from_dict(data)
base_service_session = base_session.service_session_id
if isinstance(base_service_session, Mapping):
mapped_context_id = base_service_session.get("context_id")
if isinstance(mapped_context_id, str):
context_id = context_id or mapped_context_id
mapped_task_id = base_service_session.get("task_id")
if isinstance(mapped_task_id, str):
task_id = task_id or mapped_task_id
mapped_task_state = base_service_session.get("task_state")
if isinstance(mapped_task_state, int):
task_state = task_state if task_state is not None else cast(TaskState, mapped_task_state)
elif isinstance(base_service_session, str):
context_id = context_id or base_service_session

session = cls(
context_id=context_id or base_session.service_session_id,
context_id=context_id,
task_id=task_id,
task_state=task_state,
)
session._session_id = base_session.session_id
session.state.update(base_session.state)
return session

def _sync_service_session_id(self) -> None:
"""Keep compatibility fields and ``service_session_id`` aligned."""
if self.context_id is None:
self.service_session_id = None
return
self.service_session_id = A2AServiceSessionId(
context_id=self.context_id,
task_id=self.task_id,
task_state=self.task_state,
)


class A2AContinuationToken(ContinuationToken):
"""Continuation token for A2A protocol long-running tasks."""
Expand Down Expand Up @@ -316,6 +370,83 @@ async def __aexit__(
if self._http_client is not None and self._close_http_client:
await self._http_client.aclose()

@staticmethod
def _build_service_session_id(
*,
context_id: str,
task_id: str | None,
task_state: TaskState | None,
) -> A2AServiceSessionId:
return A2AServiceSessionId(
context_id=context_id,
task_id=task_id,
task_state=task_state,
)

def _extract_a2a_session_state(
self,
session: AgentSession | None,
) -> tuple[str | None, str | None, TaskState | None]:
"""Extract A2A continuation state from supported session shapes."""
if session is None:
return None, None, None
if isinstance(session, A2AAgentSession):
return session.context_id, session.task_id, session.task_state
service_session_id = session.service_session_id
if service_session_id is None:
return None, None, None
if isinstance(service_session_id, str):
return service_session_id, None, None
if not isinstance(service_session_id, Mapping):
raise ValueError(
"A2AAgent requires service_session_id to be a string or mapping with context_id/task_id/task_state."
)

context_id = service_session_id.get("context_id")
if not isinstance(context_id, str) or not context_id:
raise ValueError("A2A service_session_id mapping must include a non-empty string 'context_id'.")

task_id_value = service_session_id.get("task_id")
if task_id_value is None:
task_id = None
elif isinstance(task_id_value, str):
task_id = task_id_value
else:
raise ValueError("A2A service_session_id mapping field 'task_id' must be a string or None.")

task_state_value = service_session_id.get("task_state")
if task_state_value is None:
task_state = None
elif isinstance(task_state_value, int):
task_state = cast(TaskState, task_state_value)
else:
raise ValueError("A2A service_session_id mapping field 'task_state' must be an integer enum value or None.")
return context_id, task_id, task_state

def _apply_a2a_session_state(
self,
*,
session: AgentSession,
context_id: str,
task_id: str | None,
task_state: TaskState | None,
) -> None:
"""Persist durable A2A continuation state on the session."""
session.service_session_id = self._build_service_session_id(
context_id=context_id,
task_id=task_id,
task_state=task_state,
)
if isinstance(session, A2AAgentSession):
session.context_id = context_id
session.task_id = task_id
session.task_state = task_state

def _get_otel_conversation_id(self, session: AgentSession | None) -> str | None:
"""Return A2A context_id as OpenTelemetry conversation id."""
context_id, _, _ = self._extract_a2a_session_state(session)
return context_id

@overload
def run(
self,
Expand Down Expand Up @@ -577,20 +708,24 @@ async def _map_a2a_stream(
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

# Persist A2A protocol state on the session for follow-up message linking.
if isinstance(session, A2AAgentSession) and (last_task_id or last_context_id):
if session is not None and (last_task_id or last_context_id):
existing_context_id, existing_task_id, existing_task_state = self._extract_a2a_session_state(session)

# Validate context_id consistency
if session.context_id is not None and last_context_id and session.context_id != last_context_id:
if existing_context_id is not None and last_context_id and existing_context_id != last_context_id:
raise RuntimeError(
f"The context_id returned from the A2A agent ('{last_context_id}') "
f"differs from the session's context_id ('{session.context_id}')."
f"differs from the session's context_id ('{existing_context_id}')."
)

persisted_context_id = existing_context_id or last_context_id
if persisted_context_id is not None:
self._apply_a2a_session_state(
session=session,
context_id=persisted_context_id,
task_id=last_task_id or existing_task_id,
task_state=last_task_state if last_task_state is not None else existing_task_state,
)
# Assign server-generated context_id if not already set
if session.context_id is None and last_context_id:
session.context_id = last_context_id
session.service_session_id = last_context_id
if last_task_id:
session.task_id = last_task_id
session.task_state = last_task_state

await self._run_after_providers(session=session, context=session_context)

Expand Down Expand Up @@ -789,19 +924,11 @@ def _prepare_message_for_a2a(self, message: Message, *, session: AgentSession |
Keyword Args:
session: Optional session to read A2A state from. If an
``A2AAgentSession``, context_id/task_id/task_state are used for
linking. A plain ``AgentSession`` provides service_session_id as
a fallback context_id.
linking. A plain ``AgentSession`` may provide either a string
or structured A2A ``service_session_id`` mapping.
"""
# Extract A2A state from the session
context_id: str | None = None
previous_task_id: str | None = None
task_state: TaskState | None = None
if isinstance(session, A2AAgentSession):
context_id = session.context_id
previous_task_id = session.task_id
task_state = session.task_state
elif session is not None:
context_id = session.service_session_id
context_id, previous_task_id, task_state = self._extract_a2a_session_state(session)

parts: list[A2APart] = []
if not message.contents:
Expand Down
Loading
Loading