diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 5972550..59e9662 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -34,6 +34,7 @@ UpstreamConcurrencyLimitError, UpstreamContractError, ) +from ..output_modes import accepts_output_mode, normalize_accepted_output_modes from ..parts.mapping import ( UnsupportedA2AInputError, extract_text_from_a2a_parts, @@ -90,6 +91,8 @@ ) logger = logging.getLogger(__name__) +_TEXT_PLAIN_MEDIA_TYPE = "text/plain" +_APPLICATION_JSON_MEDIA_TYPE = "application/json" __all__ = [ "_build_output_metadata", @@ -148,6 +151,7 @@ class _PreparedExecution: directory: str | None workspace_id: str | None session_binding_context_id: str + allow_structured_output: bool def _build_session_binding_context_id( @@ -368,6 +372,7 @@ async def _bind_session(self) -> None: directory=self._prepared.directory, workspace_id=self._prepared.workspace_id, terminal_signal=self._stream_terminal_signal, + allow_structured_output=self._prepared.allow_structured_output, ) ) @@ -775,6 +780,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non identity = (call_context.state.get("identity") if call_context else None) or "anonymous" streaming_request = self._should_stream(context) + accepted_output_modes = normalize_accepted_output_modes(context.configuration) message_parts = ( getattr(context.message, "parts", None) if context.message is not None else None ) @@ -853,6 +859,22 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non ) return + if not accepts_output_mode(accepted_output_modes, _TEXT_PLAIN_MEDIA_TYPE): + await self._emit_error( + event_queue, + task_id=task_id, + context_id=context_id, + message="acceptedOutputModes must include text/plain for OpenCode chat responses.", + state=TaskState.failed, + streaming_request=streaming_request, + ) + return + + allow_structured_output = accepts_output_mode( + 
accepted_output_modes, + _APPLICATION_JSON_MEDIA_TYPE, + ) + logger.debug( ( "Received message identity=%s task_id=%s context_id=%s " @@ -877,6 +899,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non directory=directory, workspace_id=workspace_id, session_binding_context_id=session_binding_context_id, + allow_structured_output=allow_structured_output, ) coordinator = _ExecutionCoordinator( self, @@ -1097,6 +1120,7 @@ async def _consume_opencode_stream( terminal_signal: asyncio.Future[_StreamTerminalSignal], directory: str | None = None, workspace_id: str | None = None, + allow_structured_output: bool = True, ) -> None: await self._stream_runtime.consume( session_id=session_id, @@ -1110,6 +1134,7 @@ async def _consume_opencode_stream( terminal_signal=terminal_signal, directory=directory, workspace_id=workspace_id, + allow_structured_output=allow_structured_output, ) diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 63ca1a1..721b2cf 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -77,6 +77,7 @@ async def consume( terminal_signal: asyncio.Future[_StreamTerminalSignal], directory: str | None = None, workspace_id: str | None = None, + allow_structured_output: bool = True, ) -> None: part_states: dict[str, _StreamPartState] = {} pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) @@ -85,6 +86,8 @@ async def consume( async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: for chunk in chunks: + if not allow_structured_output and getattr(chunk.part.root, "kind", None) == "data": + continue resolved_message_id = stream_state.resolve_message_id(chunk.message_id) chunk_text = getattr(chunk.part.root, "text", "") if stream_state.should_drop_initial_user_echo( diff --git a/src/opencode_a2a/output_modes.py b/src/opencode_a2a/output_modes.py new file mode 100644 index 0000000..323cd01 
def normalize_accepted_output_modes(source: Any) -> tuple[str, ...] | None:
    """Return the accepted output media types declared on *source*.

    The modes are read from ``source.accepted_output_modes`` or, when that is
    missing or falsy, from ``source.acceptedOutputModes``.  Each string entry
    is stripped and lowercased; non-string entries, blank entries, and
    duplicates are dropped while first-seen order is preserved.

    Args:
        source: Any object (or ``None``) that may carry the attribute.

    Returns:
        A non-empty tuple of normalized media types, or ``None`` when no
        usable restriction is declared.  Callers in this module treat ``None``
        as "no restriction".
    """
    accepted = getattr(source, "accepted_output_modes", None) or getattr(
        source, "acceptedOutputModes", None
    )
    # Accept tuples as well as lists so an immutable sequence of modes is not
    # silently turned into "no restriction".  A bare string is deliberately
    # excluded: iterating it would yield single characters.
    if not isinstance(accepted, (list, tuple)):
        return None

    cleaned = (value.strip().lower() for value in accepted if isinstance(value, str))
    # dict.fromkeys de-duplicates in linear time and keeps insertion order.
    normalized = tuple(dict.fromkeys(mode for mode in cleaned if mode))
    return normalized or None


def accepts_output_mode(
    accepted_output_modes: Collection[str] | None,
    media_type: str,
) -> bool:
    """Report whether *media_type* is permitted by *accepted_output_modes*.

    ``None`` means the client declared no restriction, so every media type is
    accepted.  Otherwise the comparison is an exact match after stripping
    whitespace and lowercasing both sides; wildcards and media-type
    parameters are not interpreted.

    Args:
        accepted_output_modes: Accepted media types, or ``None``.
        media_type: The media type the server intends to emit.

    Returns:
        ``True`` when the media type may be emitted, ``False`` otherwise.
    """
    if accepted_output_modes is None:
        return True

    wanted = media_type.strip().lower()
    # Normalize both sides so the result does not depend on whether the
    # caller passed an already-normalized collection.
    return any(
        isinstance(mode, str) and mode.strip().lower() == wanted
        for mode in accepted_output_modes
    )
), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "sessions", "history", "provider-private"], ), AgentSkill( @@ -399,6 +407,8 @@ def _build_agent_skills( "Discover available upstream providers and models through provider-private " "JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "providers", "models", "provider-private"], ), AgentSkill( @@ -408,6 +418,8 @@ def _build_agent_skills( "Manage OpenCode projects, workspaces, and worktrees through " "provider-private JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "project", "workspace", "worktree", "provider-private"], ), AgentSkill( @@ -417,6 +429,8 @@ def _build_agent_skills( "Recover pending permission and question interrupts through " "provider-private JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["interrupt", "permission", "question", "provider-private"], ), AgentSkill( @@ -426,6 +440,8 @@ def _build_agent_skills( "Reply to streaming permission and question interrupts through shared " "JSON-RPC callbacks." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["interrupt", "permission", "question", "shared"], ), ] @@ -439,6 +455,8 @@ def _build_agent_skills( "TextPart and FilePart inputs to OpenCode sessions with shared session " "binding and optional request-scoped model selection." ), + input_modes=list(_CHAT_INPUT_MODES), + output_modes=list(_CHAT_OUTPUT_MODES), tags=["assistant", "coding", "opencode", "core-a2a", "portable"], examples=_build_chat_examples(settings.a2a_project), ), @@ -449,6 +467,8 @@ def _build_agent_skills( "provider-private OpenCode session/history and session-control surface " "exposed through JSON-RPC extensions." 
), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "sessions", "history", "provider-private"], examples=_build_session_query_skill_examples( capability_snapshot=capability_snapshot, @@ -461,6 +481,8 @@ def _build_agent_skills( "provider-private OpenCode provider/model discovery surface exposed " "through JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "providers", "models", "provider-private"], examples=[ "List available providers (method opencode.providers.list).", @@ -474,6 +496,8 @@ def _build_agent_skills( "provider-private OpenCode project/workspace/worktree control surface " "exposed through JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["opencode", "project", "workspace", "worktree", "provider-private"], examples=_build_workspace_control_skill_examples(), ), @@ -484,6 +508,8 @@ def _build_agent_skills( "provider-private OpenCode interrupt recovery surface exposed through " "JSON-RPC extensions." ), + input_modes=list(_JSON_RPC_MODES), + output_modes=list(_JSON_RPC_MODES), tags=["interrupt", "permission", "question", "provider-private"], examples=_build_interrupt_recovery_skill_examples(), ), @@ -495,6 +521,8 @@ def _build_agent_skills( "JSON-RPC methods a2a.interrupt.permission.reply, " "a2a.interrupt.question.reply, and a2a.interrupt.question.reject." 
@staticmethod
def _extract_accepted_output_modes(params) -> list[str] | None:  # noqa: ANN001
    """Return the normalized accepted output modes from ``params.configuration``.

    Returns ``None`` when ``normalize_accepted_output_modes`` yields ``None``;
    otherwise the normalized modes as a new list.
    """
    modes = normalize_accepted_output_modes(getattr(params, "configuration", None))
    if modes is None:
        return None
    return list(modes)


@classmethod
def _validate_chat_output_modes(cls, params) -> None:  # noqa: ANN001
    """Reject requests whose accepted output modes cannot carry a chat reply.

    Does nothing when the request declares no accepted output modes.

    Raises:
        ServerError: wrapping ``UnsupportedOperationError`` when
            ``are_modalities_compatible`` deems the requested modes
            incompatible with the chat output modes, or when ``text/plain``
            is not among the requested modes.
    """
    requested = cls._extract_accepted_output_modes(params)
    if not requested:
        return

    compatible = are_modalities_compatible(list(_CHAT_OUTPUT_MODES), requested)
    if not compatible:
        incompatible_error = UnsupportedOperationError(
            message=(
                "Requested acceptedOutputModes are not compatible "
                "with OpenCode chat responses."
            ),
            data={
                "accepted_output_modes": requested,
                "supported_output_modes": list(_CHAT_OUTPUT_MODES),
            },
        )
        raise ServerError(error=incompatible_error)

    if "text/plain" in requested:
        return

    # Overlap exists (e.g. application/json only) but the mandatory text
    # channel is missing.
    missing_text_error = UnsupportedOperationError(
        message="OpenCode chat responses require text/plain in acceptedOutputModes.",
        data={
            "accepted_output_modes": requested,
            "required_output_modes": ["text/plain"],
            "supported_output_modes": list(_CHAT_OUTPUT_MODES),
        },
    )
    raise ServerError(error=missing_text_error)
@pytest.mark.asyncio
async def test_execute_rejects_output_modes_without_text_plain() -> None:
    """A JSON-only acceptedOutputModes request fails before any session is created."""
    upstream = MagicMock()
    configure_mock_client_runtime(upstream)
    agent_executor = OpencodeAgentExecutor(upstream, streaming_enabled=False)
    request_context = make_request_context(
        task_id="task-1",
        context_id="ctx-1",
        text="hello",
        accepted_output_modes=["application/json"],
    )
    queue = AsyncMock(spec=EventQueue)

    await agent_executor.execute(request_context, queue)

    positional_args = queue.enqueue_event.call_args[0]
    emitted = positional_args[0]
    assert isinstance(emitted, Task)
    assert emitted.status.state.name == "failed"
    error_text = emitted.status.message.parts[0].root.text
    assert "acceptedOutputModes must include text/plain" in error_text
    assert not upstream.create_session.called


@pytest.mark.asyncio
async def test_streaming_suppresses_structured_tool_updates_when_json_output_not_accepted() -> None:
    """Tool-call artifact updates are dropped when only text/plain is accepted."""
    running_tool_event = _event(
        session_id="ses-1",
        role="assistant",
        part_type="tool",
        delta="",
        part_id="prt-tool-1",
        part_overrides={
            "callID": "call-1",
            "tool": "bash",
            "state": {"status": "running"},
        },
    )
    streaming_client = DummyStreamingClient(
        stream_events_payload=[running_tool_event],
        response_text="done",
    )
    agent_executor = OpencodeAgentExecutor(streaming_client, streaming_enabled=True)
    agent_executor._should_stream = lambda context: True  # type: ignore[method-assign]
    event_queue = DummyEventQueue()
    request_context = make_request_context(
        task_id="task-tool-text-only",
        context_id="ctx-tool-text-only",
        text="go",
        accepted_output_modes=["text/plain"],
    )

    await agent_executor.execute(request_context, event_queue)

    artifact_events = _artifact_updates(event_queue)
    tool_call_events = [
        update
        for update in artifact_events
        if _artifact_stream_meta(update)["block_type"] == "tool_call"
    ]
    assert tool_call_events == []
    assert any(
        _artifact_stream_meta(update)["block_type"] == "text" for update in artifact_events
    )
@pytest.mark.asyncio
async def test_on_message_send_rejects_output_modes_without_text_plain() -> None:
    """on_message_send raises before execution setup when text/plain is not accepted."""

    class _SetupTrippingHandler(OpencodeRequestHandler):
        """Handler whose execution setup records the call and then fails."""

        def __init__(self) -> None:
            super().__init__(agent_executor=MagicMock(), task_store=MagicMock())
            self.setup_called = False

        async def _setup_message_execution(self, params, context=None):  # noqa: ANN001
            del params, context
            self.setup_called = True
            raise AssertionError("_setup_message_execution should not be called")

    json_only_configuration = types.SimpleNamespace(accepted_output_modes=["application/json"])
    request_params = types.SimpleNamespace(configuration=json_only_configuration)
    handler = _SetupTrippingHandler()

    with pytest.raises(ServerError) as raised:
        await handler.on_message_send(request_params)

    wrapped_error = raised.value.error
    assert isinstance(wrapped_error, UnsupportedOperationError)
    assert "require text/plain" in wrapped_error.message
    assert handler.setup_called is False
opencode_a2a.opencode_upstream_client import OpencodeMessage, OpencodeMessagePage @@ -87,14 +87,26 @@ def make_request_context( text: str, metadata: dict[str, Any] | None = None, message_id: str = "req-1", + accepted_output_modes: list[str] | None = None, + call_context: Any = None, ) -> RequestContext: message = Message( message_id=message_id, role=Role.user, parts=[TextPart(text=text)], ) - params = MessageSendParams(message=message, metadata=metadata) - return RequestContext(request=params, task_id=task_id, context_id=context_id) + configuration = ( + MessageSendConfiguration(acceptedOutputModes=accepted_output_modes) + if accepted_output_modes is not None + else None + ) + params = MessageSendParams(message=message, metadata=metadata, configuration=configuration) + return RequestContext( + request=params, + task_id=task_id, + context_id=context_id, + call_context=call_context, + ) def make_request_context_with_parts( @@ -105,13 +117,19 @@ def make_request_context_with_parts( metadata: dict[str, Any] | None = None, message_id: str = "req-1", call_context: Any = None, + accepted_output_modes: list[str] | None = None, ) -> RequestContext: message = Message( message_id=message_id, role=Role.user, parts=parts, ) - params = MessageSendParams(message=message, metadata=metadata) + configuration = ( + MessageSendConfiguration(acceptedOutputModes=accepted_output_modes) + if accepted_output_modes is not None + else None + ) + params = MessageSendParams(message=message, metadata=metadata, configuration=configuration) return RequestContext( request=params, task_id=task_id,