diff --git a/README.md b/README.md index 0b14cc4..0049342 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,14 @@ curl http://127.0.0.1:8000/.well-known/agent-card.json - Request-scoped model selection through `metadata.shared.model` - OpenCode-oriented JSON-RPC extensions for session and model/provider queries +## A2A Protocol Support + +- Default protocol line: `0.3` +- Declared supported protocol lines: `0.3`, `1.0` +- `0.3` is the stable interoperability baseline for the current runtime surface. +- `1.0` currently covers version negotiation plus protocol-aware JSON-RPC and REST error shaping, while transport payloads, enums, pagination, signatures, and interface-level protocol declarations still follow the shipped SDK baseline. +- The detailed compatibility matrix and machine-readable support boundary are documented in [`docs/guide.md`](docs/guide.md). + ## Peering Node / Outbound Access `opencode-a2a` supports a "Peering Node" architecture where a single process handles both inbound (Server) and outbound (Client) A2A traffic. diff --git a/docs/extension-specifications.md b/docs/extension-specifications.md index b55413e..862009a 100644 --- a/docs/extension-specifications.md +++ b/docs/extension-specifications.md @@ -87,6 +87,7 @@ URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extens URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extension-specifications.md#a2a-compatibility-profile-v1` - Scope: compatibility profile describing core baselines, extension retention, and service behaviors +- Includes machine-readable protocol compatibility summary for the currently declared `0.3` / `1.0` support boundary - Public Agent Card: capability declaration only - Authenticated extended card: full compatibility profile payload - Transport: Agent Card extension params @@ -96,6 +97,7 @@ URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extens URI: `https://github.com/Intelligent-Internet/opencode-a2a/blob/main/docs/extension-specifications.md#a2a-wire-contract-v1` - Scope: wire-level contract for supported methods, endpoints, and error semantics +- Includes the same machine-readable protocol compatibility summary published by the compatibility profile - Public Agent Card: capability declaration only - Authenticated extended card: full wire contract payload - Transport: Agent Card extension params diff --git a/docs/guide.md b/docs/guide.md index 1a5925e..509c8e1 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -257,6 +257,31 @@ Consumer guidance: - Discover custom JSON-RPC methods from Agent Card / OpenAPI before calling them. - Treat `supported_methods` in `error.data` as the runtime truth for the current deployment, especially when a deployment-conditional method is disabled. +## Protocol Version Negotiation + +- The runtime accepts `A2A-Version` from either the HTTP header or the query parameter of A2A transport requests. +- If both are omitted, the runtime falls back to the configured default protocol version. +- Current defaults declare `default_protocol_version=0.3` and `supported_protocol_versions=["0.3", "1.0"]`. +- Unsupported or invalid versions are rejected before request routing: + - JSON-RPC returns a unified `VERSION_NOT_SUPPORTED` error envelope. + - REST returns HTTP `400` with the same contract fields. +- Error shaping now follows the negotiated major line: + - `0.3` keeps the existing legacy `error.data={...}` and flat REST error payloads. + - `1.0` keeps standard JSON-RPC error codes for standard failures, but moves A2A-specific JSON-RPC errors to `google.rpc.ErrorInfo`-style `error.data[]` details and REST errors to AIP-193 `error.details[]`. +- The current transport payloads still follow the SDK-owned request/response shapes; version negotiation is introduced first so later issues can evolve error and payload compatibility without scattering version checks across handlers. + +Current compatibility matrix: + +| Area | `0.3` | `1.0` | Current note | +| --- | --- | --- | --- | +| Version negotiation | Supported | Supported | The runtime accepts `A2A-Version` and routes requests before handler dispatch. | +| Agent Card / interface version discovery | Default card protocol only | Partial | The service publishes `default_protocol_version` and `supported_protocol_versions`, but `AgentInterface.protocolVersion` cannot yet be declared with `a2a-sdk==0.3.25`. | +| Transport payloads and enums | Supported | Partial | Request/response payloads, enums, and schema details still follow the SDK-owned `0.3` baseline. | +| Error model | Supported | Partial | `0.3` keeps legacy `error.data={...}` / flat REST payloads; `1.0` uses protocol-aware JSON-RPC details and AIP-193-style REST errors. | +| Pagination and list semantics | Supported | Partial | Cursor/list behavior is stable, but the declared shape still follows the `0.3` SDK baseline. | +| Push notification surfaces | Supported | Partial | Core task push-notification routes are available, but no extra `1.0`-specific compatibility layer is declared yet. | +| Signatures and authenticated data | Supported | Partial | Security schemes and authenticated extended card discovery follow the shipped SDK schema rather than a dedicated `1.0` compatibility layer. | + ## Compatibility Profile The service also publishes a machine-readable compatibility profile through Agent Card and OpenAPI metadata. @@ -271,6 +296,13 @@ Its purpose is to declare: Current profile shape: - `profile_id=opencode-a2a-single-tenant-coding-v1` +- `default_protocol_version` +- `supported_protocol_versions` +- `protocol_compatibility` + - `versions["0.3"].status=supported` + - `versions["1.0"].status=partial` + - `versions[*].supported_features[]` + - `versions[*].known_gaps[]` - Deployment semantics are declared under `deployment`: - `id=single_tenant_shared_workspace` - `single_tenant=true` @@ -306,6 +338,7 @@ Retention guidance: - Treat `a2a.interrupt.*` methods as shared extensions. - Treat `opencode.sessions.*`, `opencode.providers.*`, and `opencode.models.*` as provider-private OpenCode extensions rather than portable A2A baseline capabilities. - Treat `opencode.sessions.shell` as deployment-conditional and discover it from the declared profile and current wire contract before calling it. +- Treat `protocol_compatibility` as the runtime truth for which protocol line is fully supported versus only partially adapted. ## Multipart Input Example diff --git a/src/opencode_a2a/client/client.py b/src/opencode_a2a/client/client.py index fbf5c06..8928c77 100644 --- a/src/opencode_a2a/client/client.py +++ b/src/opencode_a2a/client/client.py @@ -131,7 +131,10 @@ async def send_message( async for event in client.send_message( request, context=build_call_context( - self._settings.bearer_token, extra_headers, self._settings.basic_auth + self._settings.bearer_token, + extra_headers, + self._settings.basic_auth, + self._settings.protocol_version, ), request_metadata=request_metadata, extensions=extensions, @@ -203,7 +206,10 @@ async def get_task( metadata=request_metadata or {}, ), context=build_call_context( - self._settings.bearer_token, extra_headers, self._settings.basic_auth + self._settings.bearer_token, + extra_headers, + self._settings.basic_auth, + self._settings.protocol_version, ), ) except ( @@ -231,7 +237,10 @@ async def cancel_task( return await client.cancel_task( TaskIdParams(id=task_id, metadata=request_metadata or {}), context=build_call_context( - self._settings.bearer_token, extra_headers, self._settings.basic_auth + self._settings.bearer_token, + extra_headers, + self._settings.basic_auth, + self._settings.protocol_version, ), ) except ( @@ -259,7 +268,10 @@ async def resubscribe_task( async for event in client.resubscribe( TaskIdParams(id=task_id, metadata=request_metadata or {}), context=build_call_context( - self._settings.bearer_token, extra_headers, self._settings.basic_auth + self._settings.bearer_token, + extra_headers, + self._settings.basic_auth, + self._settings.protocol_version, ), ): yield event @@ -293,7 +305,9 @@ async def _build_client(self) -> Client: client = factory.create( card, interceptors=build_client_interceptors( - self._settings.bearer_token, self._settings.basic_auth + self._settings.bearer_token, + self._settings.basic_auth, + self._settings.protocol_version, ), ) except ValueError as exc: diff --git a/src/opencode_a2a/client/config.py b/src/opencode_a2a/client/config.py index b202674..3b9e70a 100644 --- a/src/opencode_a2a/client/config.py +++ b/src/opencode_a2a/client/config.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Any +from ..protocol_versions import normalize_protocol_version from .auth import validate_basic_auth from .polling import PollingFallbackPolicy, validate_polling_fallback_policy @@ -70,6 +71,13 @@ def _coerce_optional_str(name: str, value: Any) -> str | None: raise ValueError(f"{name} must be a string, got {value!r}") +def _coerce_optional_protocol_version(name: str, value: Any) -> str | None: + normalized = _coerce_optional_str(name, value) + if normalized is None: + return None + return normalize_protocol_version(normalized) + + def _normalize_transport(value: str) -> str: normalized = value.strip().lower() if normalized in {"jsonrpc", "json-rpc", "json_rpc"}: @@ -110,6 +118,7 @@ class A2AClientSettings: card_fetch_timeout: float = 5.0 bearer_token: str | None = None basic_auth: str | None = None + protocol_version: str | None = None supported_transports: tuple[str, ...] = ( "JSONRPC", "HTTP+JSON", @@ -172,6 +181,19 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: ) if basic_auth is not None: validate_basic_auth(basic_auth) + protocol_version = _coerce_optional_protocol_version( + "A2A_CLIENT_PROTOCOL_VERSION", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_PROTOCOL_VERSION", + "a2a_client_protocol_version", + "A2A_PROTOCOL_VERSION", + "a2a_protocol_version", + ), + default=None, + ), + ) supported_transports = _parse_transports( _read_setting( raw_settings, @@ -260,6 +282,7 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: card_fetch_timeout=card_fetch_timeout, bearer_token=bearer_token, basic_auth=basic_auth, + protocol_version=protocol_version, supported_transports=supported_transports, polling_fallback_enabled=polling_fallback_enabled, polling_fallback_initial_interval_seconds=polling_fallback_initial_interval_seconds, diff --git a/src/opencode_a2a/client/request_context.py b/src/opencode_a2a/client/request_context.py index 44ef93c..bf9586d 100644 --- a/src/opencode_a2a/client/request_context.py +++ b/src/opencode_a2a/client/request_context.py @@ -7,6 +7,7 @@ from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from ..protocol_versions import normalize_protocol_version from .auth import encode_basic_auth @@ -41,12 +42,16 @@ async def intercept( def build_default_headers( bearer_token: str | None, basic_auth: str | None = None, + protocol_version: str | None = None, ) -> dict[str, str]: + headers: dict[str, str] = {} if bearer_token: - return {"Authorization": f"Bearer {bearer_token}"} - if basic_auth: - return {"Authorization": f"Basic {encode_basic_auth(basic_auth)}"} - return {} + headers["Authorization"] = f"Bearer {bearer_token}" + elif basic_auth: + headers["Authorization"] = f"Basic {encode_basic_auth(basic_auth)}" + if protocol_version: + headers["A2A-Version"] = normalize_protocol_version(protocol_version) + return headers def split_request_metadata( @@ -59,6 +64,10 @@ def split_request_metadata( if value is not None: extra_headers["Authorization"] = str(value) continue + if isinstance(key, str) and key.lower() == "a2a-version": + if value is not None: + extra_headers["A2A-Version"] = normalize_protocol_version(str(value)) + continue request_metadata[key] = value return request_metadata or None, extra_headers or None @@ -67,8 +76,9 @@ def build_call_context( bearer_token: str | None, extra_headers: Mapping[str, str] | None, basic_auth: str | None = None, + protocol_version: str | None = None, ) -> ClientCallContext | None: - merged_headers = build_default_headers(bearer_token, basic_auth) + merged_headers = build_default_headers(bearer_token, basic_auth, protocol_version) if extra_headers: merged_headers.update(extra_headers) if not merged_headers: @@ -84,8 +94,9 @@ def build_call_context( def build_client_interceptors( bearer_token: str | None, basic_auth: str | None = None, + protocol_version: str | None = None, ) -> list[ClientCallInterceptor]: - return [HeaderInterceptor(build_default_headers(bearer_token, basic_auth))] + return [HeaderInterceptor(build_default_headers(bearer_token, basic_auth, protocol_version))] __all__ = [ diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index 98c19a1..f7300b6 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -3,10 +3,14 @@ import json from typing import Annotated, Any, Literal -from pydantic import BeforeValidator, Field, model_validator +from pydantic import BeforeValidator, Field, field_validator, model_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from opencode_a2a import __version__ +from opencode_a2a.protocol_versions import ( + normalize_protocol_version, + normalize_protocol_versions, +) from opencode_a2a.sandbox_policy import SandboxPolicy SandboxMode = Literal[ @@ -97,7 +101,11 @@ class Settings(BaseSettings): a2a_title: str = Field(default="OpenCode A2A", alias="A2A_TITLE") a2a_description: str = Field(default="OpenCode A2A runtime", alias="A2A_DESCRIPTION") a2a_version: str = Field(default=__version__, alias="A2A_VERSION") - a2a_protocol_version: str = Field(default="0.3.0", alias="A2A_PROTOCOL_VERSION") + a2a_protocol_version: str = Field(default="0.3", alias="A2A_PROTOCOL_VERSION") + a2a_supported_protocol_versions: DeclaredStringList = Field( + default=("0.3", "1.0"), + alias="A2A_SUPPORTED_PROTOCOL_VERSIONS", + ) a2a_log_level: str = Field(default="WARNING", alias="A2A_LOG_LEVEL") a2a_log_payloads: bool = Field(default=False, alias="A2A_LOG_PAYLOADS") a2a_log_body_limit: int = Field(default=0, alias="A2A_LOG_BODY_LIMIT") @@ -180,6 +188,10 @@ class Settings(BaseSettings): ) a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN") a2a_client_basic_auth: str | None = Field(default=None, alias="A2A_CLIENT_BASIC_AUTH") + a2a_client_protocol_version: str | None = Field( + default=None, + alias="A2A_CLIENT_PROTOCOL_VERSION", + ) a2a_client_cache_ttl_seconds: float = Field( default=900.0, ge=0.0, @@ -212,4 +224,37 @@ def _validate_sandbox_policy(self) -> Settings: raise ValueError( "A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database" ) + if self.a2a_protocol_version not in self.a2a_supported_protocol_versions: + supported_display = ", ".join(self.a2a_supported_protocol_versions) + raise ValueError( + "A2A_PROTOCOL_VERSION must be present in A2A_SUPPORTED_PROTOCOL_VERSIONS. " + f"Declared supported versions: {supported_display}" + ) return self + + @field_validator("a2a_protocol_version", mode="before") + @classmethod + def _normalize_a2a_protocol_version(cls, value: Any) -> str: + if not isinstance(value, str): + raise TypeError("A2A_PROTOCOL_VERSION must be a string.") + return normalize_protocol_version(value) + + @field_validator("a2a_client_protocol_version", mode="before") + @classmethod + def _normalize_a2a_client_protocol_version(cls, value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise TypeError("A2A_CLIENT_PROTOCOL_VERSION must be a string.") + normalized = value.strip() + if not normalized: + return None + return normalize_protocol_version(normalized) + + @field_validator("a2a_supported_protocol_versions") + @classmethod + def _normalize_supported_protocol_versions( + cls, + value: tuple[str, ...], + ) -> tuple[str, ...]: + return normalize_protocol_versions(value) diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index ad367b9..32a4f03 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -40,6 +40,13 @@ def _extension_spec_uri(fragment: str) -> str: SERVICE_BEHAVIOR_CLASSIFICATION = "service-level-semantic-enhancement" CANCEL_IDEMPOTENCY_BEHAVIOR = "return_current_terminal_task" TERMINAL_RESUBSCRIBE_BEHAVIOR = "replay_terminal_task_once_then_close" +V1_PARTIAL_COMPATIBILITY_GAPS: tuple[str, ...] = ( + "AgentInterface.protocolVersion cannot be declared with a2a-sdk==0.3.25.", + ( + "Transport payloads, enums, pagination, signatures, and push-notification " + "surfaces still follow the SDK-owned 0.3 baseline." + ), +) @dataclass(frozen=True) @@ -1225,7 +1232,17 @@ def build_compatibility_profile_params( *, protocol_version: str, runtime_profile: RuntimeProfile, + supported_protocol_versions: tuple[str, ...] | list[str] | None = None, + default_protocol_version: str | None = None, ) -> dict[str, Any]: + declared_default_protocol_version = default_protocol_version or protocol_version + declared_supported_protocol_versions = list( + supported_protocol_versions or (declared_default_protocol_version,) + ) + protocol_compatibility = build_protocol_compatibility_params( + supported_protocol_versions=declared_supported_protocol_versions, + default_protocol_version=declared_default_protocol_version, + ) capability_snapshot = build_capability_snapshot(runtime_profile=runtime_profile) service_behaviors = build_service_behavior_contract_params() method_retention: dict[str, dict[str, Any]] = { @@ -1295,6 +1312,9 @@ def build_compatibility_profile_params( ) return { **runtime_profile.summary_dict(protocol_version=protocol_version), + "default_protocol_version": declared_default_protocol_version, + "supported_protocol_versions": declared_supported_protocol_versions, + "protocol_compatibility": protocol_compatibility, "core": { "jsonrpc_methods": list(CORE_JSONRPC_METHODS), "http_endpoints": list(CORE_HTTP_ENDPOINTS), @@ -1372,20 +1392,95 @@ def build_compatibility_profile_params( "Treat declared service behaviors as stable server-level semantic " "enhancements layered on top of the core A2A method baseline." ), + ( + "Treat protocol_compatibility as the runtime truth for which major line " + "is fully supported versus partially adapted." + ), ], } +def build_protocol_compatibility_params( + *, + supported_protocol_versions: tuple[str, ...] | list[str], + default_protocol_version: str, +) -> dict[str, Any]: + declared_supported_versions = list(supported_protocol_versions) + versions: dict[str, dict[str, Any]] = { + "0.3": { + "enabled": "0.3" in declared_supported_versions, + "default": default_protocol_version == "0.3", + "status": "supported", + "supported_features": [ + "Default compatibility line for the current deployment.", + "A2A-Version negotiation fallback and explicit 0.3 routing.", + "Legacy JSON-RPC and REST error envelopes.", + ( + "SDK-owned transport payloads, enums, pagination, signatures, and " + "push-notification surfaces." + ), + ], + "known_gaps": [], + }, + "1.0": { + "enabled": "1.0" in declared_supported_versions, + "default": default_protocol_version == "1.0", + "status": "partial", + "supported_features": [ + "A2A-Version negotiation and request routing.", + "Protocol-aware JSON-RPC error shaping.", + "Protocol-aware REST error shaping.", + ], + "known_gaps": list(V1_PARTIAL_COMPATIBILITY_GAPS), + }, + } + + for version in declared_supported_versions: + if version in versions: + continue + versions[version] = { + "enabled": True, + "default": default_protocol_version == version, + "status": "custom", + "supported_features": [ + "Supported by deployment configuration.", + "Version-specific compatibility details are not yet declared.", + ], + "known_gaps": [ + "This protocol line does not yet have a dedicated compatibility summary.", + ], + } + + return { + "default_protocol_version": default_protocol_version, + "supported_protocol_versions": declared_supported_versions, + "versions": versions, + } + + def build_wire_contract_params( *, protocol_version: str, runtime_profile: RuntimeProfile, + supported_protocol_versions: tuple[str, ...] | list[str] | None = None, + default_protocol_version: str | None = None, ) -> dict[str, Any]: + declared_default_protocol_version = default_protocol_version or protocol_version + declared_supported_protocol_versions = list( + supported_protocol_versions or (declared_default_protocol_version,) + ) + protocol_compatibility = build_protocol_compatibility_params( + supported_protocol_versions=declared_supported_protocol_versions, + default_protocol_version=declared_default_protocol_version, + ) capability_snapshot = build_capability_snapshot(runtime_profile=runtime_profile) service_behaviors = build_service_behavior_contract_params() return { "protocol_version": protocol_version, + "default_protocol_version": declared_default_protocol_version, + "supported_protocol_versions": declared_supported_protocol_versions, + "protocol_compatibility": protocol_compatibility, "profile": runtime_profile.summary_dict(protocol_version=protocol_version), "preferred_transport": "HTTP+JSON", "additional_transports": ["JSON-RPC"], diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index bad2124..f02c9bd 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -2,12 +2,15 @@ import logging from collections.abc import Awaitable, Callable +from dataclasses import replace +from functools import partial from typing import Any, cast from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.types import ( A2AError, InvalidRequestError, + JSONRPCError, JSONRPCRequest, ) from fastapi.responses import JSONResponse @@ -21,6 +24,7 @@ build_extension_method_registry, ) from .error_responses import ( + adapt_jsonrpc_error_for_protocol, invalid_params_error, method_not_supported_error, ) @@ -178,9 +182,26 @@ def __init__( self._extension_handler_context ) + def _generate_protocol_error_response( + self, + request_id: str | int | None, + error: JSONRPCError | A2AError, + *, + protocol_version: str, + ) -> JSONResponse: + return self._generate_error_response( + request_id, + adapt_jsonrpc_error_for_protocol(protocol_version, error), + ) + async def _handle_requests(self, request: Request) -> Response: # Fast path: sniff method first then either handle here or delegate. request_id: str | int | None = None + negotiated_protocol_version = getattr( + request.state, + "a2a_protocol_version", + self._protocol_version, + ) try: body = await request.json() if isinstance(body, dict): @@ -189,9 +210,10 @@ async def _handle_requests(self, request: Request) -> Response: request_id = None if not self._allowed_content_length(request): - return self._generate_error_response( + return self._generate_protocol_error_response( request_id, A2AError(root=InvalidRequestError(message="Payload too large")), + protocol_version=negotiated_protocol_version, ) base_request = JSONRPCRequest.model_validate(body) @@ -205,24 +227,33 @@ async def _handle_requests(self, request: Request) -> Response: return await super()._handle_requests(request) if base_request.id is None: return Response(status_code=204) - - return self._generate_error_response( + return self._generate_protocol_error_response( base_request.id, method_not_supported_error( method=base_request.method, supported_methods=self._supported_methods, - protocol_version=self._protocol_version, + protocol_version=negotiated_protocol_version, ), + protocol_version=negotiated_protocol_version, ) params = base_request.params or {} if not isinstance(params, dict): - return self._generate_error_response( + return self._generate_protocol_error_response( base_request.id, invalid_params_error("params must be an object"), + protocol_version=negotiated_protocol_version, ) - return await extension_spec.handler( + request_context = replace( self._extension_handler_context, + protocol_version=negotiated_protocol_version, + error_response=partial( + self._generate_protocol_error_response, + protocol_version=negotiated_protocol_version, + ), + ) + return await extension_spec.handler( + request_context, base_request, params, request, diff --git a/src/opencode_a2a/jsonrpc/error_responses.py b/src/opencode_a2a/jsonrpc/error_responses.py index b170b0d..4b5ab7c 100644 --- a/src/opencode_a2a/jsonrpc/error_responses.py +++ b/src/opencode_a2a/jsonrpc/error_responses.py @@ -1,9 +1,185 @@ from __future__ import annotations +import json +from collections.abc import Mapping from typing import Any from a2a.types import A2AError, InvalidParamsError, JSONRPCError +from ..protocol_versions import normalize_protocol_version + +A2A_ERROR_DOMAIN = "a2a-protocol.org" +GOOGLE_RPC_ERROR_INFO_TYPE = "type.googleapis.com/google.rpc.ErrorInfo" +STANDARD_JSONRPC_ERROR_MESSAGES = { + -32700: "Invalid JSON payload", + -32600: "Request payload validation error", + -32601: "Method not found", + -32602: "Invalid parameters", + -32603: "Internal error", +} +STANDARD_JSONRPC_ERROR_CODES = frozenset(STANDARD_JSONRPC_ERROR_MESSAGES) + + +def protocol_uses_v1_error_format(protocol_version: str | None) -> bool: + if protocol_version is None: + return False + return normalize_protocol_version(protocol_version).startswith("1.") + + +def _to_upper_snake_case(name: str) -> str: + normalized: list[str] = [] + previous_was_lower = False + for char in name: + if char.isupper() and previous_was_lower: + normalized.append("_") + if char in {" ", "-"}: + normalized.append("_") + previous_was_lower = False + continue + normalized.append(char.upper()) + previous_was_lower = char.islower() + return "".join(normalized).strip("_") + + +def _to_lower_camel_case(name: str) -> str: + if "_" not in name: + return name + head, *tail = [part for part in name.split("_") if part] + return head + "".join(part[:1].upper() + part[1:] for part in tail) + + +def _camelize(value: Any) -> Any: + if isinstance(value, Mapping): + return {_to_lower_camel_case(str(key)): _camelize(item) for key, item in value.items()} + if isinstance(value, list): + return [_camelize(item) for item in value] + return value + + +def _stringify_metadata_value(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, bool | int | float): + return str(value) + return json.dumps(value, ensure_ascii=False, separators=(",", ":"), sort_keys=True) + + +def _build_error_info_detail( + *, + reason: str, + metadata: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "@type": GOOGLE_RPC_ERROR_INFO_TYPE, + "reason": _to_upper_snake_case(reason), + "domain": A2A_ERROR_DOMAIN, + } + if metadata: + payload["metadata"] = { + _to_lower_camel_case(str(key)): _stringify_metadata_value(value) + for key, value in metadata.items() + if value is not None + } + return payload + + +def _build_context_detail(type_name: str, payload: Mapping[str, Any]) -> dict[str, Any]: + return { + "@type": f"type.googleapis.com/opencode_a2a.{type_name}", + **_camelize(dict(payload)), + } + + +def _reason_from_error(error: object) -> str | None: + data = getattr(error, "data", None) + if isinstance(data, Mapping): + data_type = data.get("type") + if isinstance(data_type, str) and data_type.strip(): + return data_type + class_name = type(error).__name__ + if class_name.endswith("Error") and class_name != "JSONRPCError": + return class_name[:-5] + return None + + +def _metadata_from_error(error: object) -> dict[str, Any]: + data = getattr(error, "data", None) + if not isinstance(data, Mapping): + return {} + return {str(key): value for key, value in data.items() if key != "type"} + + +def adapt_jsonrpc_error_for_protocol( + protocol_version: str, + error: JSONRPCError | A2AError, +) -> JSONRPCError | A2AError: + if not protocol_uses_v1_error_format(protocol_version): + return error + + root_error = error.root if isinstance(error, A2AError) else error + root_data = getattr(root_error, "data", None) + + if root_error.code in STANDARD_JSONRPC_ERROR_CODES: + adapted_data = None + if isinstance(root_data, Mapping): + adapted_data = _camelize( + {str(key): value for key, value in root_data.items() if key != "type"} + ) + elif root_data is not None: + adapted_data = root_data + return JSONRPCError( + code=root_error.code, + message=STANDARD_JSONRPC_ERROR_MESSAGES[root_error.code], + data=adapted_data, + ) + + reason = _reason_from_error(root_error) + metadata = _metadata_from_error(root_error) + details: list[dict[str, Any]] = [] + if reason is not None: + details.append(_build_error_info_detail(reason=reason, metadata=metadata)) + if metadata: + details.append(_build_context_detail("ErrorContext", metadata)) + + message = root_error.message + if message is None: + message = STANDARD_JSONRPC_ERROR_MESSAGES.get(root_error.code, "Internal error") + + return JSONRPCError( + code=root_error.code, + message=message, + data=details or None, + ) + + +def build_http_error_body( + *, + protocol_version: str, + status_code: int, + status: str, + message: str, + legacy_payload: dict[str, Any], + reason: str | None = None, + metadata: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + if not protocol_uses_v1_error_format(protocol_version): + return legacy_payload + + details: list[dict[str, Any]] = [] + if reason is not None: + details.append(_build_error_info_detail(reason=reason, metadata=metadata)) + if metadata: + details.append(_build_context_detail("HttpErrorContext", dict(metadata))) + + error_payload: dict[str, Any] = { + "code": status_code, + "status": status, + "message": message, + } + if details: + error_payload["details"] = details + return {"error": error_payload} + def invalid_params_error( message: str, @@ -31,6 +207,24 @@ def method_not_supported_error( ) +def version_not_supported_error( + *, + requested_version: str, + supported_protocol_versions: list[str], + default_protocol_version: str, +) -> JSONRPCError: + return JSONRPCError( + code=-32001, + message=f"Unsupported A2A version: {requested_version}", + data={ + "type": "VERSION_NOT_SUPPORTED", + "requested_version": requested_version, + "supported_protocol_versions": supported_protocol_versions, + "default_protocol_version": default_protocol_version, + }, + ) + + def session_forbidden_error(code: int, *, session_id: str) -> JSONRPCError: return JSONRPCError( code=code, @@ -148,13 +342,19 @@ def upstream_payload_error( __all__ = [ + "A2A_ERROR_DOMAIN", + "GOOGLE_RPC_ERROR_INFO_TYPE", + "adapt_jsonrpc_error_for_protocol", + "build_http_error_body", "interrupt_not_found_error", "interrupt_type_mismatch_error", "invalid_params_error", "method_not_supported_error", + "protocol_uses_v1_error_format", "session_forbidden_error", "session_not_found_error", "upstream_http_error", "upstream_payload_error", "upstream_unreachable_error", + "version_not_supported_error", ] diff --git a/src/opencode_a2a/protocol_versions.py b/src/opencode_a2a/protocol_versions.py new file mode 100644 index 0000000..9991d2e --- /dev/null +++ b/src/opencode_a2a/protocol_versions.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass + +_PROTOCOL_VERSION_PATTERN = re.compile(r"^(?P\d+)\.(?P\d+)(?:\.\d+)?$") + + +class UnsupportedProtocolVersionError(ValueError): + def __init__( + self, + requested_version: str, + *, + supported_protocol_versions: tuple[str, ...], + default_protocol_version: str, + ) -> None: + self.requested_version = requested_version + self.supported_protocol_versions = supported_protocol_versions + self.default_protocol_version = default_protocol_version + supported_display = ", ".join(supported_protocol_versions) + super().__init__( + f"Unsupported A2A protocol version {requested_version!r}. " + f"Supported versions: {supported_display}." + ) + + +@dataclass(frozen=True) +class NegotiatedProtocolVersion: + requested_version: str + negotiated_version: str + explicit: bool + + +def normalize_protocol_version(value: str) -> str: + normalized = value.strip() + if not normalized: + raise ValueError("Protocol version must be a non-empty string.") + match = _PROTOCOL_VERSION_PATTERN.fullmatch(normalized) + if match is None: + raise ValueError("Protocol version must use Major.Minor or Major.Minor.Patch format.") + return f"{match.group('major')}.{match.group('minor')}" + + +def normalize_protocol_versions(values: Iterable[str]) -> tuple[str, ...]: + normalized_versions: list[str] = [] + seen: set[str] = set() + for value in values: + normalized = normalize_protocol_version(str(value)) + if normalized in seen: + continue + seen.add(normalized) + normalized_versions.append(normalized) + if not normalized_versions: + raise ValueError("At least one supported protocol version must be declared.") + return tuple(normalized_versions) + + +def negotiate_protocol_version( + *, + header_value: str | None, + query_value: str | None, + default_protocol_version: str, + supported_protocol_versions: Iterable[str], +) -> NegotiatedProtocolVersion: + normalized_default = normalize_protocol_version(default_protocol_version) + normalized_supported = normalize_protocol_versions(supported_protocol_versions) + + raw_header = (header_value or "").strip() + raw_query = (query_value or "").strip() + explicit = bool(raw_header or raw_query) + raw_requested = raw_header or raw_query or normalized_default + + try: + normalized_requested = normalize_protocol_version(raw_requested) + except ValueError as exc: + raise UnsupportedProtocolVersionError( + raw_requested, + supported_protocol_versions=normalized_supported, + default_protocol_version=normalized_default, + ) from exc + + if normalized_requested not in normalized_supported: + raise UnsupportedProtocolVersionError( + normalized_requested, + supported_protocol_versions=normalized_supported, + default_protocol_version=normalized_default, + ) + + return NegotiatedProtocolVersion( + requested_version=normalized_requested, + negotiated_version=normalized_requested, + explicit=explicit, + ) + + +__all__ = [ + "NegotiatedProtocolVersion", + "UnsupportedProtocolVersionError", + "negotiate_protocol_version", + "normalize_protocol_version", + "normalize_protocol_versions", +] diff --git a/src/opencode_a2a/server/agent_card.py b/src/opencode_a2a/server/agent_card.py index 824988d..7b4ad86 100644 --- a/src/opencode_a2a/server/agent_card.py +++ b/src/opencode_a2a/server/agent_card.py @@ -226,10 +226,14 @@ def _build_agent_extensions( compatibility_profile_params = build_compatibility_profile_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) wire_contract_params = build_wire_contract_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) return [ diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 9460594..e0157fa 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -504,6 +504,12 @@ def build(self, request: Request) -> ServerCallContext: identity = getattr(request.state, "user_identity", None) if identity: context.state["identity"] = identity + negotiated_protocol_version = getattr(request.state, "a2a_protocol_version", None) + if negotiated_protocol_version: + context.state["a2a_protocol_version"] = negotiated_protocol_version + requested_protocol_version = getattr(request.state, "a2a_requested_protocol_version", None) + if requested_protocol_version: + context.state["a2a_requested_protocol_version"] = requested_protocol_version return context diff --git a/src/opencode_a2a/server/client_manager.py b/src/opencode_a2a/server/client_manager.py index d496e74..f987f93 100644 --- a/src/opencode_a2a/server/client_manager.py +++ b/src/opencode_a2a/server/client_manager.py @@ -21,6 +21,9 @@ def __init__(self, settings) -> None: # noqa: ANN001 "A2A_CLIENT_USE_CLIENT_PREFERENCE": settings.a2a_client_use_client_preference, "A2A_CLIENT_BEARER_TOKEN": settings.a2a_client_bearer_token, "A2A_CLIENT_BASIC_AUTH": settings.a2a_client_basic_auth, + "A2A_CLIENT_PROTOCOL_VERSION": ( + settings.a2a_client_protocol_version or settings.a2a_protocol_version + ), "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, } ) diff --git a/src/opencode_a2a/server/middleware.py b/src/opencode_a2a/server/middleware.py index d3afc4f..7d8a1d6 100644 --- a/src/opencode_a2a/server/middleware.py +++ b/src/opencode_a2a/server/middleware.py @@ -5,6 +5,7 @@ import logging import secrets from contextvars import ContextVar, Token +from typing import cast from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, @@ -16,6 +17,16 @@ from starlette.responses import StreamingResponse from ..execution.metrics import emit_metric +from ..jsonrpc.error_responses import ( + adapt_jsonrpc_error_for_protocol, + build_http_error_body, + version_not_supported_error, +) +from ..protocol_versions import ( + UnsupportedProtocolVersionError, + negotiate_protocol_version, + normalize_protocol_version, +) from .request_parsing import ( _decode_payload_preview, _detect_sensitive_extension_method, @@ -85,6 +96,106 @@ def install_runtime_middlewares( public_card_etag: str, extended_card_etag: str, ) -> None: + def _requires_protocol_negotiation(request: Request) -> bool: + if request.url.path == "/" and request.method == "POST": + return True + if request.url.path.startswith("/v1/"): + return True + return False + + def _extract_jsonrpc_request_id(payload: object) -> str | int | None: + if not isinstance(payload, dict): + return None + request_id = payload.get("id") + if isinstance(request_id, str | int): + return request_id + return None + + def _error_protocol_version(request: Request) -> str: + negotiated = getattr(request.state, "a2a_protocol_version", None) + if isinstance(negotiated, str) and negotiated.strip(): + return negotiated + raw_value = request.headers.get("A2A-Version") or request.query_params.get("A2A-Version") + if isinstance(raw_value, str) and raw_value.strip(): + try: + return normalize_protocol_version(raw_value) + except ValueError: + return raw_value.strip() + return cast(str, settings.a2a_protocol_version) + + @app.middleware("http") + async def negotiate_a2a_protocol_version(request: Request, call_next): + token: Token | None = None + if not _requires_protocol_negotiation(request): + return await call_next(request) + + try: + negotiated = negotiate_protocol_version( + header_value=request.headers.get("A2A-Version"), + query_value=request.query_params.get("A2A-Version"), + default_protocol_version=settings.a2a_protocol_version, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + ) + except UnsupportedProtocolVersionError as error: + if request.url.path == "/" and request.method == "POST": + try: + body, token = await _get_request_body(request) + payload = _parse_json_body(body) + except _RequestBodyTooLargeError as request_error: + return _request_body_too_large_response( + path=request.url.path, + method=request.method, + error=request_error, + protocol_version=_error_protocol_version(request), + ) + return JSONResponse( + { + "jsonrpc": "2.0", + "id": _extract_jsonrpc_request_id(payload), + "error": adapt_jsonrpc_error_for_protocol( + error.requested_version, + version_not_supported_error( + requested_version=error.requested_version, + supported_protocol_versions=list(error.supported_protocol_versions), + default_protocol_version=error.default_protocol_version, + ), + ).model_dump(mode="json", exclude_none=True), + }, + status_code=200, + ) + return JSONResponse( + build_http_error_body( + protocol_version=error.requested_version, + status_code=400, + status="INVALID_ARGUMENT", + message="Unsupported A2A version", + legacy_payload={ + "error": "Unsupported A2A version", + "type": "VERSION_NOT_SUPPORTED", + "requested_version": error.requested_version, + "supported_protocol_versions": list(error.supported_protocol_versions), + "default_protocol_version": error.default_protocol_version, + }, + reason="VERSION_NOT_SUPPORTED", + metadata={ + "requested_version": error.requested_version, + "supported_protocol_versions": list(error.supported_protocol_versions), + "default_protocol_version": error.default_protocol_version, + }, + ), + status_code=400, + ) + finally: + if token is not None: + _REQUEST_BODY_BYTES.reset(token) + + request.state.a2a_protocol_version = negotiated.negotiated_version + request.state.a2a_requested_protocol_version = negotiated.requested_version + request.state.a2a_protocol_version_explicit = negotiated.explicit + response = await call_next(request) + response.headers["A2A-Version"] = negotiated.negotiated_version + return response + async def _get_request_body(request: Request) -> tuple[bytes, Token | None]: cached = _REQUEST_BODY_BYTES.get() if cached is not None: @@ -201,6 +312,7 @@ async def enforce_request_body_limit(request: Request, call_next): path=request.url.path, method=request.method, error=error, + protocol_version=_error_protocol_version(request), ) finally: if token is not None: @@ -222,13 +334,25 @@ async def guard_rest_payload_shape(request: Request, call_next): payload ): return JSONResponse( - { - "error": ( + build_http_error_body( + protocol_version=_error_protocol_version(request), + status_code=400, + status="INVALID_ARGUMENT", + message=( "Invalid HTTP+JSON payload for REST endpoint. " "Use message.content with ROLE_* role values, or call " "POST / with method=message/send or method=message/stream." - ) - }, + ), + legacy_payload={ + "error": ( + "Invalid HTTP+JSON payload for REST endpoint. " + "Use message.content with ROLE_* role values, or call " + "POST / with method=message/send or method=message/stream." + ) + }, + reason="INVALID_HTTP_JSON_PAYLOAD", + metadata={"path": request.url.path}, + ), status_code=400, ) return await call_next(request) @@ -237,6 +361,7 @@ async def guard_rest_payload_shape(request: Request, call_next): path=request.url.path, method=request.method, error=error, + protocol_version=_error_protocol_version(request), ) finally: if token is not None: @@ -342,6 +467,7 @@ async def log_payloads(request: Request, call_next): path=request.url.path, method=request.method, error=error, + protocol_version=_error_protocol_version(request), ) finally: if token is not None: diff --git a/src/opencode_a2a/server/openapi.py b/src/opencode_a2a/server/openapi.py index 0ac3f59..a9a944a 100644 --- a/src/opencode_a2a/server/openapi.py +++ b/src/opencode_a2a/server/openapi.py @@ -588,10 +588,14 @@ def _patch_jsonrpc_openapi_contract( compatibility_profile = build_compatibility_profile_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) wire_contract = build_wire_contract_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) capability_snapshot = build_capability_snapshot(runtime_profile=runtime_profile) original_openapi = app.openapi diff --git a/src/opencode_a2a/server/request_parsing.py b/src/opencode_a2a/server/request_parsing.py index 28f2c93..4d76559 100644 --- a/src/opencode_a2a/server/request_parsing.py +++ b/src/opencode_a2a/server/request_parsing.py @@ -11,6 +11,7 @@ SESSION_QUERY_METHODS, WORKSPACE_CONTROL_METHODS, ) +from ..jsonrpc.error_responses import build_http_error_body logger = logging.getLogger(__name__) @@ -103,6 +104,7 @@ def _request_body_too_large_response( path: str, method: str, error: _RequestBodyTooLargeError, + protocol_version: str = "0.3", ) -> JSONResponse: logger.warning( "A2A request %s %s rejected: body_size=%s exceeds max_request_body_bytes=%s", @@ -112,6 +114,14 @@ def _request_body_too_large_response( error.limit, ) return JSONResponse( - {"error": "Request body too large", "max_bytes": error.limit}, + build_http_error_body( + protocol_version=protocol_version, + status_code=413, + status="RESOURCE_EXHAUSTED", + message="Request body too large", + legacy_payload={"error": "Request body too large", "max_bytes": error.limit}, + reason="REQUEST_BODY_TOO_LARGE", + metadata={"max_bytes": error.limit, "actual_size": error.actual_size}, + ), status_code=413, ) diff --git a/tests/client/test_client_config.py b/tests/client/test_client_config.py index 04daaf6..5e9963b 100644 --- a/tests/client/test_client_config.py +++ b/tests/client/test_client_config.py @@ -20,6 +20,7 @@ def test_load_settings_from_mapping() -> None: "A2A_CLIENT_USE_CLIENT_PREFERENCE": "true", "A2A_CLIENT_BEARER_TOKEN": "peer-token", "A2A_CLIENT_BASIC_AUTH": "user:pass", + "A2A_CLIENT_PROTOCOL_VERSION": "1.0.0", "A2A_CLIENT_SUPPORTED_TRANSPORTS": "json-rpc,http-json", "A2A_CLIENT_POLLING_FALLBACK_ENABLED": "true", "A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0.75", @@ -35,6 +36,7 @@ def test_load_settings_from_mapping() -> None: assert settings.use_client_preference is True assert settings.bearer_token == "peer-token" assert settings.basic_auth == "user:pass" + assert settings.protocol_version == "1.0" assert settings.supported_transports == ("JSONRPC", "HTTP+JSON") assert settings.polling_fallback_enabled is True assert settings.polling_fallback_initial_interval_seconds == 0.75 @@ -66,6 +68,12 @@ def test_load_settings_accepts_base64_basic_auth() -> None: assert settings == A2AClientSettings(basic_auth=encoded) +def test_load_settings_can_fallback_to_general_protocol_version() -> None: + settings = load_settings({"A2A_PROTOCOL_VERSION": "0.3.0"}) + + assert settings.protocol_version == "0.3" + + def test_load_settings_invalid_basic_auth_raises() -> None: with pytest.raises(ValueError, match="username:password"): load_settings({"A2A_CLIENT_BASIC_AUTH": "not-basic-auth"}) diff --git a/tests/client/test_request_context.py b/tests/client/test_request_context.py index 3bbc37e..5ee1863 100644 --- a/tests/client/test_request_context.py +++ b/tests/client/test_request_context.py @@ -16,11 +16,18 @@ def test_split_request_metadata_and_default_headers() -> None: request_metadata, extra_headers = split_request_metadata( - {"authorization": "Bearer explicit-token", "trace_id": "trace-1"} + { + "authorization": "Bearer explicit-token", + "A2A-Version": "1.0.0", + "trace_id": "trace-1", + } ) assert request_metadata == {"trace_id": "trace-1"} - assert extra_headers == {"Authorization": "Bearer explicit-token"} + assert extra_headers == { + "Authorization": "Bearer explicit-token", + "A2A-Version": "1.0", + } assert build_default_headers("peer-token") == {"Authorization": "Bearer peer-token"} @@ -47,6 +54,13 @@ def test_build_default_headers_prefers_bearer_over_basic_auth() -> None: } +def test_build_default_headers_includes_protocol_version() -> None: + assert build_default_headers("peer-token", protocol_version="1.0.0") == { + "Authorization": "Bearer peer-token", + "A2A-Version": "1.0", + } + + def test_build_call_context_without_headers_returns_none() -> None: assert build_call_context(None, None) is None diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py index d837777..303c270 100644 --- a/tests/config/test_settings.py +++ b/tests/config/test_settings.py @@ -70,6 +70,23 @@ def test_settings_valid(): assert settings.a2a_task_store_backend == "database" assert settings.a2a_task_store_database_url == "sqlite+aiosqlite:///./opencode-a2a.db" assert settings.a2a_version == __version__ + assert settings.a2a_protocol_version == "0.3" + assert settings.a2a_supported_protocol_versions == ("0.3", "1.0") + + +def test_settings_normalize_protocol_versions() -> None: + env = { + "A2A_BEARER_TOKEN": "test-token", + "A2A_PROTOCOL_VERSION": "0.3.0", + "A2A_SUPPORTED_PROTOCOL_VERSIONS": "0.3.0,1.0.0,1.0", + "A2A_CLIENT_PROTOCOL_VERSION": "1.0.0", + } + with mock.patch.dict(os.environ, env, clear=True): + settings = Settings() + + assert settings.a2a_protocol_version == "0.3" + assert settings.a2a_supported_protocol_versions == ("0.3", "1.0") + assert settings.a2a_client_protocol_version == "1.0" def test_settings_allow_explicit_memory_backend() -> None: diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index 55e65d3..70d74fc 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -84,10 +84,14 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: expected_compatibility_profile = build_compatibility_profile_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) expected_wire_contract = build_wire_contract_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) assert session_binding.params == expected_session_binding, ( @@ -120,6 +124,10 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: assert wire_contract.params == expected_wire_contract, ( "Wire contract extension drifted from contracts.extensions SSOT." ) + assert ( + compatibility_profile.params["protocol_compatibility"] + == wire_contract.params["protocol_compatibility"] + ), "Protocol compatibility summary drifted between compatibility profile and wire contract." def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: @@ -170,10 +178,14 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: expected_compatibility_profile = build_compatibility_profile_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) expected_wire_contract = build_wire_contract_params( protocol_version=settings.a2a_protocol_version, runtime_profile=runtime_profile, + supported_protocol_versions=settings.a2a_supported_protocol_versions, + default_protocol_version=settings.a2a_protocol_version, ) assert session_binding == expected_session_binding, ( @@ -206,6 +218,9 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: assert wire_contract == expected_wire_contract, ( "OpenAPI wire contract drifted from contracts.extensions SSOT." ) + assert ( + compatibility_profile["protocol_compatibility"] == wire_contract["protocol_compatibility"] + ), "OpenAPI protocol compatibility summary drifted between profile and wire contract." json_request_schema = ( post.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index 4d08c79..891af96 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -582,6 +582,8 @@ async def test_agent_a2a_call_uses_server_side_basic_auth_headers( a2a_client_use_client_preference=False, a2a_client_bearer_token=None, a2a_client_basic_auth="user:pass", + a2a_client_protocol_version=None, + a2a_protocol_version="0.3", a2a_client_supported_transports=("JSONRPC", "HTTP+JSON"), a2a_client_cache_ttl_seconds=60.0, a2a_client_cache_maxsize=1, diff --git a/tests/jsonrpc/test_error_responses.py b/tests/jsonrpc/test_error_responses.py index e864f53..8206899 100644 --- a/tests/jsonrpc/test_error_responses.py +++ b/tests/jsonrpc/test_error_responses.py @@ -1,8 +1,10 @@ from __future__ import annotations -from a2a.types import InvalidParamsError +from a2a.types import A2AError, InvalidParamsError, UnsupportedOperationError from opencode_a2a.jsonrpc.error_responses import ( + GOOGLE_RPC_ERROR_INFO_TYPE, + adapt_jsonrpc_error_for_protocol, interrupt_not_found_error, interrupt_type_mismatch_error, invalid_params_error, @@ -12,6 +14,7 @@ upstream_http_error, upstream_payload_error, upstream_unreachable_error, + version_not_supported_error, ) @@ -19,7 +22,7 @@ def test_jsonrpc_error_mapping_helpers_preserve_business_contract_fields() -> No unsupported = method_not_supported_error( method="unsupported.method", supported_methods=["message/send", "tasks/get"], - protocol_version="0.3.0", + protocol_version="0.3", ) assert unsupported.code == -32601 assert unsupported.data["type"] == "METHOD_NOT_SUPPORTED" @@ -99,3 +102,84 @@ def test_invalid_error_helper_wraps_a2a_error() -> None: assert isinstance(invalid.root, InvalidParamsError) assert invalid.root.message == "bad field" assert invalid.root.data == {"type": "INVALID_FIELD", "field": "request"} + + +def test_version_not_supported_error_includes_supported_versions() -> None: + error = version_not_supported_error( + requested_version="2.0", + supported_protocol_versions=["0.3", "1.0"], + default_protocol_version="0.3", + ) + + assert error.code == -32001 + assert error.message == "Unsupported A2A version: 2.0" + assert error.data == { + "type": "VERSION_NOT_SUPPORTED", + "requested_version": "2.0", + "supported_protocol_versions": ["0.3", "1.0"], + "default_protocol_version": "0.3", + } + + +def test_adapt_standard_jsonrpc_error_for_v1_uses_standard_message_and_camel_case_data() -> None: + adapted = adapt_jsonrpc_error_for_protocol( + "1.0", + method_not_supported_error( + method="unsupported.method", + supported_methods=["message/send", "tasks/get"], + protocol_version="1.0", + ), + ) + + assert adapted.message == "Method not found" + assert adapted.data == { + "method": "unsupported.method", + "supportedMethods": ["message/send", "tasks/get"], + "protocolVersion": "1.0", + } + + +def test_adapt_a2a_specific_error_for_v1_uses_error_info_details() -> None: + adapted = adapt_jsonrpc_error_for_protocol( + "1.0", + version_not_supported_error( + requested_version="1.1", + supported_protocol_versions=["0.3", "1.0"], + default_protocol_version="0.3", + ), + ) + + assert adapted.code == -32001 + assert adapted.data[0] == { + "@type": GOOGLE_RPC_ERROR_INFO_TYPE, + "reason": "VERSION_NOT_SUPPORTED", + "domain": "a2a-protocol.org", + "metadata": { + "requestedVersion": "1.1", + "supportedProtocolVersions": '["0.3","1.0"]', + "defaultProtocolVersion": "0.3", + }, + } + assert adapted.data[1] == { + "@type": "type.googleapis.com/opencode_a2a.ErrorContext", + "requestedVersion": "1.1", + "supportedProtocolVersions": ["0.3", "1.0"], + "defaultProtocolVersion": "0.3", + } + + +def test_adapt_a2a_root_error_for_v1_uses_error_type_reason() -> None: + adapted = adapt_jsonrpc_error_for_protocol( + "1.0", + A2AError(root=UnsupportedOperationError()), + ) + + assert adapted.code == -32004 + assert adapted.message == "This operation is not supported" + assert adapted.data == [ + { + "@type": GOOGLE_RPC_ERROR_INFO_TYPE, + "reason": "UNSUPPORTED_OPERATION", + "domain": "a2a-protocol.org", + } + ] diff --git a/tests/jsonrpc/test_jsonrpc_unsupported_method.py b/tests/jsonrpc/test_jsonrpc_unsupported_method.py index 0641e4f..e7bd411 100644 --- a/tests/jsonrpc/test_jsonrpc_unsupported_method.py +++ b/tests/jsonrpc/test_jsonrpc_unsupported_method.py @@ -36,6 +36,88 @@ async def test_unsupported_method_returns_unified_error() -> None: assert data["protocol_version"] == settings.a2a_protocol_version +@pytest.mark.asyncio +async def test_unsupported_method_uses_requested_protocol_version() -> None: + settings = make_settings(a2a_bearer_token="test-token") + app = create_app(settings) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={ + "Authorization": "Bearer test-token", + "A2A-Version": "1.0", + }, + json={"jsonrpc": "2.0", "id": 123, "method": "unsupported.method", "params": {}}, + ) + + assert response.status_code == 200 + assert response.headers["A2A-Version"] == "1.0" + body = response.json() + assert body["error"]["message"] == "Method not found" + assert body["error"]["data"] == { + "method": "unsupported.method", + "supportedMethods": body["error"]["data"]["supportedMethods"], + "protocolVersion": "1.0", + } + assert "message/send" in body["error"]["data"]["supportedMethods"] + + +@pytest.mark.asyncio +async def test_unsupported_v1_minor_version_returns_v1_error_details() -> None: + settings = make_settings(a2a_bearer_token="test-token") + app = create_app(settings) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/?A2A-Version=1.1", + headers={"Authorization": "Bearer test-token"}, + json={"jsonrpc": "2.0", "id": 124, "method": "message/send", "params": {}}, + ) + + assert response.status_code == 200 + body = response.json() + assert body["error"]["code"] == -32001 + assert body["error"]["data"][0] == { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "VERSION_NOT_SUPPORTED", + "domain": "a2a-protocol.org", + "metadata": { + "requestedVersion": "1.1", + "supportedProtocolVersions": '["0.3","1.0"]', + "defaultProtocolVersion": "0.3", + }, + } + + +@pytest.mark.asyncio +async def test_unsupported_version_returns_version_error() -> None: + settings = make_settings(a2a_bearer_token="test-token") + app = create_app(settings) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/?A2A-Version=2.0", + headers={"Authorization": "Bearer test-token"}, + json={"jsonrpc": "2.0", "id": 123, "method": "message/send", "params": {}}, + ) + + assert response.status_code == 200 + body = response.json() + assert body["jsonrpc"] == "2.0" + assert body["id"] == 123 + assert body["error"]["code"] == -32001 + assert body["error"]["data"] == { + "type": "VERSION_NOT_SUPPORTED", + "requested_version": "2.0", + "supported_protocol_versions": ["0.3", "1.0"], + "default_protocol_version": "0.3", + } + + @pytest.mark.asyncio async def test_unsupported_method_notification_returns_204() -> None: settings = make_settings(a2a_bearer_token="test-token") diff --git a/tests/profile/test_profile_runtime.py b/tests/profile/test_profile_runtime.py index 238f08a..d64586e 100644 --- a/tests/profile/test_profile_runtime.py +++ b/tests/profile/test_profile_runtime.py @@ -26,7 +26,7 @@ def test_profile_runtime_splits_deployment_runtime_features_and_health_payload() assert profile.summary_dict(protocol_version=settings.a2a_protocol_version) == { "profile_id": "opencode-a2a-single-tenant-coding-v1", - "protocol_version": "0.3.0", + "protocol_version": "0.3", "deployment": { "id": "single_tenant_shared_workspace", "single_tenant": True, diff --git a/tests/server/test_a2a_client_manager.py b/tests/server/test_a2a_client_manager.py index 7ba8cb7..89b953c 100644 --- a/tests/server/test_a2a_client_manager.py +++ b/tests/server/test_a2a_client_manager.py @@ -14,6 +14,8 @@ def _make_settings(**overrides: object) -> SimpleNamespace: "a2a_client_use_client_preference": False, "a2a_client_bearer_token": None, "a2a_client_basic_auth": None, + "a2a_client_protocol_version": None, + "a2a_protocol_version": "0.3", "a2a_client_supported_transports": ("JSONRPC", "HTTP+JSON"), "a2a_client_cache_ttl_seconds": 60.0, "a2a_client_cache_maxsize": 2, @@ -210,3 +212,9 @@ def test_client_manager_loads_basic_auth_into_client_settings() -> None: ) assert manager.client_settings.basic_auth == "user:pass" + + +def test_client_manager_defaults_protocol_version_from_runtime_setting() -> None: + manager = client_manager_module.A2AClientManager(_make_settings(a2a_protocol_version="1.0")) + + assert manager.client_settings.protocol_version == "1.0" diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index 78c8f73..0b149a5 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -3,6 +3,7 @@ from opencode_a2a.contracts.extensions import ( SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_MAX_LIMIT, + build_protocol_compatibility_params, build_service_behavior_contract_params, ) from opencode_a2a.jsonrpc.application import SESSION_CONTEXT_PREFIX @@ -35,6 +36,7 @@ def test_agent_card_description_reflects_actual_transport_capabilities() -> None assert "Single-tenant deployment" in card.description assert card.capabilities.streaming is True assert card.supports_authenticated_extended_card is True + assert card.protocol_version == "0.3" assert card.default_input_modes == ["text/plain", "application/octet-stream"] assert card.default_output_modes == ["text/plain", "application/json"] assert list(card.security_schemes.keys()) == ["bearerAuth"] @@ -577,6 +579,10 @@ def test_agent_card_injects_profile_into_extensions() -> None: compatibility = ext_by_uri[COMPATIBILITY_PROFILE_EXTENSION_URI] expected_service_behaviors = build_service_behavior_contract_params() + expected_protocol_compatibility = build_protocol_compatibility_params( + supported_protocol_versions=["0.3", "1.0"], + default_protocol_version="0.3", + ) assert compatibility.params["extension_retention"][MODEL_SELECTION_EXTENSION_URI] == { "surface": "core-runtime-metadata", "availability": "always", @@ -621,10 +627,14 @@ def test_agent_card_injects_profile_into_extensions() -> None: "delivery": "single_task_snapshot", "closes_stream": True, } + assert compatibility.params["protocol_compatibility"] == expected_protocol_compatibility assert compatibility.description.endswith("deployment-conditional methods.") wire_contract = ext_by_uri[WIRE_CONTRACT_EXTENSION_URI] assert wire_contract.params["profile"]["profile_id"] == "opencode-a2a-single-tenant-coding-v1" + assert wire_contract.params["default_protocol_version"] == "0.3" + assert wire_contract.params["supported_protocol_versions"] == ["0.3", "1.0"] + assert wire_contract.params["protocol_compatibility"] == expected_protocol_compatibility assert MODEL_SELECTION_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert PROVIDER_DISCOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert WORKSPACE_CONTROL_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] diff --git a/tests/server/test_app_behaviors.py b/tests/server/test_app_behaviors.py index 1ce3852..2f2e0ca 100644 --- a/tests/server/test_app_behaviors.py +++ b/tests/server/test_app_behaviors.py @@ -272,7 +272,7 @@ async def close(self) -> None: "version": settings.a2a_version, "profile": { "profile_id": "opencode-a2a-single-tenant-coding-v1", - "protocol_version": "0.3.0", + "protocol_version": settings.a2a_protocol_version, "deployment": { "id": "single_tenant_shared_workspace", "single_tenant": True, diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index 15c5160..ca4c1e5 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -165,6 +165,83 @@ async def test_agent_card_routes_split_public_and_authenticated_extended_contrac ) +@pytest.mark.asyncio +async def test_rest_endpoints_reject_unsupported_protocol_version() -> None: + app = create_app(make_settings(a2a_bearer_token="test-token")) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/v1/message:send", + headers={ + "Authorization": "Bearer test-token", + "A2A-Version": "2.0", + }, + json={ + "message": { + "messageId": "req-1", + "role": "ROLE_USER", + "content": [{"text": "hello"}], + } + }, + ) + + assert response.status_code == 400 + assert response.json() == { + "error": "Unsupported A2A version", + "type": "VERSION_NOT_SUPPORTED", + "requested_version": "2.0", + "supported_protocol_versions": ["0.3", "1.0"], + "default_protocol_version": "0.3", + } + + +@pytest.mark.asyncio +async def test_rest_endpoints_return_v1_status_body_for_v1_protocol_errors() -> None: + app = create_app(make_settings(a2a_bearer_token="test-token")) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/v1/message:send?A2A-Version=1.1", + headers={"Authorization": "Bearer test-token"}, + json={ + "message": { + "messageId": "req-2", + "role": "ROLE_USER", + "content": [{"text": "hello"}], + } + }, + ) + + assert response.status_code == 400 + assert response.json() == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": "Unsupported A2A version", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "VERSION_NOT_SUPPORTED", + "domain": "a2a-protocol.org", + "metadata": { + "requestedVersion": "1.1", + "supportedProtocolVersions": '["0.3","1.0"]', + "defaultProtocolVersion": "0.3", + }, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "requestedVersion": "1.1", + "supportedProtocolVersions": ["0.3", "1.0"], + "defaultProtocolVersion": "0.3", + }, + ], + } + } + + @pytest.mark.asyncio async def test_global_http_gzip_applies_to_eligible_non_streaming_responses(monkeypatch) -> None: import opencode_a2a.server.application as app_module @@ -358,6 +435,36 @@ async def test_dual_stack_send_rejects_cross_transport_payload_shapes(monkeypatc assert rest_envelope_resp.status_code == 400 assert "Invalid HTTP+JSON payload" in rest_envelope_resp.text + v1_rest_resp = await client.post( + "/v1/message:send", + headers={**headers, "A2A-Version": "1.0"}, + json=rest_with_jsonrpc_shape, + ) + assert v1_rest_resp.status_code == 400 + assert v1_rest_resp.json() == { + "error": { + "code": 400, + "status": "INVALID_ARGUMENT", + "message": ( + "Invalid HTTP+JSON payload for REST endpoint. " + "Use message.content with ROLE_* role values, or call " + "POST / with method=message/send or method=message/stream." + ), + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "INVALID_HTTP_JSON_PAYLOAD", + "domain": "a2a-protocol.org", + "metadata": {"path": "/v1/message:send"}, + }, + { + "@type": "type.googleapis.com/opencode_a2a.HttpErrorContext", + "path": "/v1/message:send", + }, + ], + } + } + rpc_resp = await client.post("/", headers=headers, json=rpc_with_rest_shape) assert rpc_resp.status_code == 200 payload = rpc_resp.json()