diff --git a/src/opencode_a2a/execution/coordinator.py b/src/opencode_a2a/execution/coordinator.py new file mode 100644 index 0000000..ecfd5cc --- /dev/null +++ b/src/opencode_a2a/execution/coordinator.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import asyncio +import logging +import uuid +from collections.abc import Mapping +from contextlib import suppress +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import httpx +from a2a.server.agent_execution import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import ( + Artifact, + Message, + Part, + Role, + Task, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + +from ..invocation import call_with_supported_kwargs +from ..opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError +from .event_helpers import _enqueue_artifact_update +from .stream_events import _extract_token_usage, _extract_upstream_error_from_response +from .stream_state import ( + _build_output_metadata, + _build_stream_artifact_metadata, + _merge_token_usage, + _StreamOutputState, +) +from .tool_orchestration import maybe_handle_tools +from .upstream_error_translator import ( + _await_stream_terminal_signal, + _format_upstream_error, + _StreamTerminalSignal, +) + +if TYPE_CHECKING: + from .executor import OpencodeAgentExecutor + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PreparedExecution: + identity: str + streaming_request: bool + request_parts: list[Any] + user_text: str + session_title: str + use_structured_parts: bool + bound_session_id: str | None + model_override: dict[str, str] | None + directory: str | None + workspace_id: str | None + session_binding_context_id: str + allow_structured_output: bool + + +def build_session_binding_context_id( + *, + context_id: str, + directory: str | None, + workspace_id: str | None, + use_directory_binding: bool, +) -> str: + if isinstance(workspace_id, str) and workspace_id.strip(): + return f"{context_id}::workspace:{workspace_id.strip()}" + if use_directory_binding and isinstance(directory, str) and directory.strip(): + return f"{context_id}::directory:{directory.strip()}" + return context_id + + +class ExecutionCoordinator: + def __init__( + self, + executor: OpencodeAgentExecutor, + *, + context: RequestContext, + event_queue: EventQueue, + task_id: str, + context_id: str, + prepared: PreparedExecution, + ) -> None: + self._executor = executor + self._context = context + self._event_queue = event_queue + self._task_id = task_id + self._context_id = context_id + self._prepared = prepared + self._stream_artifact_id = f"{task_id}:stream" + self._stream_state = _StreamOutputState( + user_text=prepared.user_text, + stable_message_id=f"{task_id}:{context_id}:assistant", + event_id_namespace=f"{task_id}:{context_id}:{self._stream_artifact_id}", + ) + self._stream_terminal_signal: asyncio.Future[_StreamTerminalSignal] | None = None + self._stop_event = asyncio.Event() + self._stream_task: asyncio.Task[None] | None = None + self._pending_preferred_claim = False + self._session_lock: asyncio.Lock | None = None + self._session_id = "" + self._execution_key = (task_id, context_id) + + async def run(self) -> None: + current_task = asyncio.current_task() + if current_task is not None: + await self._register_running_request(current_task) + + try: + await self._bind_session() + await self._enqueue_working_status() + + turn_request_parts = list(self._prepared.request_parts) + user_text = self._prepared.user_text + + while True: + send_kwargs: dict[str, Any] = { + "directory": self._prepared.directory, + "workspace_id": self._prepared.workspace_id, + "model_override": self._prepared.model_override, + } + if self._prepared.streaming_request: + send_kwargs["timeout_override"] = self._executor._client.stream_timeout + + if not self._prepared.use_structured_parts and not turn_request_parts: + response = await call_with_supported_kwargs( + self._executor._client.send_message, + self._session_id, + user_text, + **send_kwargs, + ) + else: + response = await call_with_supported_kwargs( + self._executor._client.send_message, + self._session_id, + user_text or None, + parts=turn_request_parts, + **send_kwargs, + ) + + if self._pending_preferred_claim: + await self._executor._session_manager.finalize_preferred_session_binding( + identity=self._prepared.identity, + context_id=self._prepared.session_binding_context_id, + session_id=self._session_id, + ) + self._pending_preferred_claim = False + + tool_results = await maybe_handle_tools( + response.raw, + a2a_client_manager=self._executor._a2a_client_manager, + ) + if tool_results: + user_text = "" + turn_request_parts = [ + { + "type": "tool", + "tool": res["tool"], + "call_id": res["call_id"], + "output": res.get("output"), + "error": res.get("error"), + } + for res in tool_results + ] + continue + + await self._handle_response(response) + break + + except httpx.HTTPStatusError as exc: + logger.exception("OpenCode request failed with HTTP error") + error_type, state, message = _format_upstream_error( + exc, + request="send_message", + ) + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=message, + state=state, + error_type=error_type, + upstream_status=exc.response.status_code, + streaming_request=self._prepared.streaming_request, + ) + except httpx.TimeoutException as exc: + logger.exception("OpenCode request timed out") + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=f"OpenCode request timed out: {exc}", + state=TaskState.failed, + error_type="UPSTREAM_TIMEOUT", + streaming_request=self._prepared.streaming_request, + ) + except UpstreamContractError as exc: + logger.warning("OpenCode request failed with payload mismatch: %s", exc) + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=f"OpenCode payload mismatch: {exc}", + state=TaskState.failed, + error_type="UPSTREAM_PAYLOAD_ERROR", + streaming_request=self._prepared.streaming_request, + ) + except UpstreamConcurrencyLimitError as exc: + logger.warning("OpenCode request rejected by concurrency budget: %s", exc) + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=str(exc), + state=TaskState.failed, + error_type="UPSTREAM_BACKPRESSURE", + streaming_request=self._prepared.streaming_request, + ) + except Exception as exc: + logger.exception("OpenCode request failed") + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=f"OpenCode error: {exc}", + state=TaskState.failed, + streaming_request=self._prepared.streaming_request, + ) + finally: + await self._cleanup() + + async def _register_running_request(self, current_task: asyncio.Task[Any]) -> None: + async with self._executor._lock: + self._executor._running_requests[self._execution_key] = current_task + self._executor._running_stop_events[self._execution_key] = self._stop_event + self._executor._running_identities[self._execution_key] = self._prepared.identity + + async def _bind_session(self) -> None: + ( + self._session_id, + self._pending_preferred_claim, + ) = await self._executor._session_manager.get_or_create_session( + self._prepared.identity, + self._prepared.session_binding_context_id, + self._prepared.session_title or self._prepared.user_text, + preferred_session_id=self._prepared.bound_session_id, + directory=self._prepared.directory, + workspace_id=self._prepared.workspace_id, + ) + self._session_lock = await self._executor._session_manager.get_session_lock( + self._session_id + ) + await self._session_lock.acquire() + async with self._executor._lock: + self._executor._running_session_ids[self._execution_key] = self._session_id + self._executor._running_directories[self._execution_key] = self._prepared.directory + self._executor._running_workspace_ids[self._execution_key] = self._prepared.workspace_id + self._executor._running_binding_context_ids[self._execution_key] = ( + self._prepared.session_binding_context_id + ) + + if self._prepared.streaming_request: + self._stream_terminal_signal = asyncio.get_running_loop().create_future() + self._stream_task = asyncio.create_task( + self._executor._consume_opencode_stream( + session_id=self._session_id, + identity=self._prepared.identity, + task_id=self._task_id, + context_id=self._context_id, + artifact_id=self._stream_artifact_id, + stream_state=self._stream_state, + event_queue=self._event_queue, + stop_event=self._stop_event, + directory=self._prepared.directory, + workspace_id=self._prepared.workspace_id, + terminal_signal=self._stream_terminal_signal, + allow_structured_output=self._prepared.allow_structured_output, + ) + ) + + async def _enqueue_working_status(self) -> None: + await self._event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=self._task_id, + context_id=self._context_id, + status=TaskStatus(state=TaskState.working), + final=False, + ) + ) + + async def _handle_response(self, response: Any) -> None: + response_text = response.text or "" + resolved_message_id = self._stream_state.resolve_message_id(response.message_id) + response_error = _extract_upstream_error_from_response(response.raw) + resolved_token_usage = _merge_token_usage( + _extract_token_usage(response.raw), + self._stream_state.token_usage, + ) + + logger.debug( + "OpenCode response task_id=%s session_id=%s message_id=%s text=%s", + self._task_id, + response.session_id, + resolved_message_id, + response_text, + ) + + if response_error is not None: + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=response_error.message, + state=response_error.state, + error_type=response_error.error_type, + upstream_status=response_error.upstream_status, + streaming_request=self._prepared.streaming_request, + ) + return + + if self._prepared.streaming_request: + await self._handle_streaming_response( + response=response, + response_text=response_text, + resolved_message_id=resolved_message_id, + resolved_token_usage=resolved_token_usage, + ) + return + + await self._handle_non_streaming_response( + response=response, + response_text=response_text, + resolved_message_id=resolved_message_id, + resolved_token_usage=resolved_token_usage, + ) + + async def _handle_streaming_response( + self, + *, + response: Any, + response_text: str, + resolved_message_id: str, + resolved_token_usage: Mapping[str, Any] | None, + ) -> None: + from .stream_events import BlockType + + del response + if self._stream_terminal_signal is None: + raise RuntimeError("Streaming terminal signal was not initialized") + + terminal_signal = await _await_stream_terminal_signal( + stream_task=self._stream_task, + terminal_signal=self._stream_terminal_signal, + session_id=self._session_id, + ) + if terminal_signal.state != TaskState.completed: + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=terminal_signal.message or "OpenCode execution failed.", + state=terminal_signal.state, + error_type=terminal_signal.error_type, + upstream_status=terminal_signal.upstream_status, + streaming_request=True, + ) + return + + if self._stream_state.upstream_error is not None: + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=self._stream_state.upstream_error.message, + state=self._stream_state.upstream_error.state, + error_type=self._stream_state.upstream_error.error_type, + upstream_status=self._stream_state.upstream_error.upstream_status, + streaming_request=True, + ) + return + + if self._stream_state.should_emit_final_snapshot(response_text): + sequence = self._stream_state.next_sequence() + await _enqueue_artifact_update( + event_queue=self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + artifact_id=self._stream_artifact_id, + part=Part(root=TextPart(text=response_text)), + append=self._stream_state.emitted_stream_chunk, + last_chunk=True, + artifact_metadata=_build_stream_artifact_metadata( + block_type=BlockType.TEXT, + shared_source="final_snapshot", + message_id=resolved_message_id, + event_id=self._stream_state.build_event_id(sequence), + sequence=sequence, + ), + ) + + await self._event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=self._task_id, + context_id=self._context_id, + status=TaskStatus(state=TaskState.completed), + final=True, + metadata=_build_output_metadata( + session_id=self._session_id, + usage=resolved_token_usage, + stream={ + "message_id": resolved_message_id, + "event_id": f"{self._stream_state.event_id_namespace}:status", + "source": "status", + }, + ), + ) + ) + + async def _handle_non_streaming_response( + self, + *, + response: Any, + response_text: str, + resolved_message_id: str, + resolved_token_usage: Mapping[str, Any] | None, + ) -> None: + response_text = response_text or "(No text content returned by OpenCode.)" + assistant_message = build_assistant_message( + task_id=self._task_id, + context_id=self._context_id, + text=response_text, + message_id=resolved_message_id, + ) + artifact = Artifact( + artifact_id=str(uuid.uuid4()), + name="response", + parts=[Part(root=TextPart(text=response_text))], + ) + from .request_context import _build_history + + history = _build_history(self._context) + task = Task( + id=self._task_id, + context_id=self._context_id, + status=TaskStatus(state=TaskState.completed), + history=history, + artifacts=[artifact], + metadata=_build_output_metadata( + session_id=response.session_id, + usage=resolved_token_usage, + ), + ) + task.status.message = assistant_message + await self._event_queue.enqueue_event(task) + + async def _cleanup(self) -> None: + if self._pending_preferred_claim and self._session_id: + with suppress(Exception): + await self._executor._session_manager.release_preferred_session_claim( + identity=self._prepared.identity, + session_id=self._session_id, + ) + self._stop_event.set() + if self._stream_task: + self._stream_task.cancel() + with suppress(asyncio.CancelledError): + await self._stream_task + if self._session_lock and self._session_lock.locked(): + self._session_lock.release() + async with self._executor._lock: + self._executor._running_requests.pop(self._execution_key, None) + self._executor._running_stop_events.pop(self._execution_key, None) + self._executor._running_identities.pop(self._execution_key, None) + self._executor._running_session_ids.pop(self._execution_key, None) + self._executor._running_directories.pop(self._execution_key, None) + self._executor._running_workspace_ids.pop(self._execution_key, None) + self._executor._running_binding_context_ids.pop(self._execution_key, None) + + +def build_assistant_message( + task_id: str, + context_id: str, + text: str, + *, + message_id: str | None = None, +) -> Message: + return Message( + message_id=message_id or str(uuid.uuid4()), + role=Role.agent, + parts=[Part(root=TextPart(text=text))], + task_id=task_id, + context_id=context_id, + ) diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 59e9662..00ace18 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -6,18 +6,16 @@ import uuid from collections.abc import Mapping from contextlib import suppress -from dataclasses import dataclass from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from ..server.application import A2AClientManager + from ..server.client_manager import A2AClientManager from ..server.state_store import SessionStateRepository import httpx from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events.event_queue import EventQueue from a2a.types import ( - Artifact, Message, Part, Role, @@ -29,11 +27,7 @@ ) from ..invocation import call_with_supported_kwargs -from ..opencode_upstream_client import ( - OpencodeUpstreamClient, - UpstreamConcurrencyLimitError, - UpstreamContractError, -) +from ..opencode_upstream_client import OpencodeUpstreamClient from ..output_modes import accepts_output_mode, normalize_accepted_output_modes from ..parts.mapping import ( UnsupportedA2AInputError, @@ -42,9 +36,10 @@ summarize_a2a_parts, ) from ..sandbox_policy import SandboxPolicy +from .coordinator import ExecutionCoordinator, PreparedExecution, build_session_binding_context_id from .event_helpers import _enqueue_artifact_update +from .metrics import emit_metric from .request_context import ( - _build_history, _extract_opencode_directory, _extract_opencode_workspace_id, _extract_shared_model, @@ -74,12 +69,11 @@ from .stream_runtime import StreamRuntime from .stream_state import ( _build_output_metadata, - _build_stream_artifact_metadata, _merge_token_usage, _StreamOutputState, _TTLCache, ) -from .tool_error_mapping import build_tool_error, map_a2a_tool_exception +from .tool_orchestration import handle_a2a_call_tool, maybe_handle_tools, merge_streamed_tool_output from .upstream_error_translator import ( _await_stream_terminal_signal, _extract_upstream_error_detail, @@ -95,6 +89,7 @@ _APPLICATION_JSON_MEDIA_TYPE = "application/json" __all__ = [ + "BlockType", "_build_output_metadata", "_build_progress_identity", "_coerce_number", @@ -122,453 +117,15 @@ "_TTLCache", ] +_EXPORTED_COMPAT_SYMBOLS = (BlockType, _await_stream_terminal_signal) + def _emit_metric( name: str, value: float = 1.0, **labels: str | int | float | bool, ) -> None: - if labels: - labels_text = ",".join( - f"{key}={str(label).lower() if isinstance(label, bool) else label}" - for key, label in sorted(labels.items()) - ) - logger.debug("metric=%s value=%s labels=%s", name, value, labels_text) - return - logger.debug("metric=%s value=%s", name, value) - - -@dataclass(frozen=True) -class _PreparedExecution: - identity: str - streaming_request: bool - request_parts: list[Any] - user_text: str - session_title: str - use_structured_parts: bool - bound_session_id: str | None - model_override: dict[str, str] | None - directory: str | None - workspace_id: str | None - session_binding_context_id: str - allow_structured_output: bool - - -def _build_session_binding_context_id( - *, - context_id: str, - directory: str | None, - workspace_id: str | None, - use_directory_binding: bool, -) -> str: - if isinstance(workspace_id, str) and workspace_id.strip(): - return f"{context_id}::workspace:{workspace_id.strip()}" - if use_directory_binding and isinstance(directory, str) and directory.strip(): - return f"{context_id}::directory:{directory.strip()}" - return context_id - - -class _ExecutionCoordinator: - def __init__( - self, - executor: OpencodeAgentExecutor, - *, - context: RequestContext, - event_queue: EventQueue, - task_id: str, - context_id: str, - prepared: _PreparedExecution, - ) -> None: - self._executor = executor - self._context = context - self._event_queue = event_queue - self._task_id = task_id - self._context_id = context_id - self._prepared = prepared - self._stream_artifact_id = f"{task_id}:stream" - self._stream_state = _StreamOutputState( - user_text=prepared.user_text, - stable_message_id=f"{task_id}:{context_id}:assistant", - event_id_namespace=f"{task_id}:{context_id}:{self._stream_artifact_id}", - ) - self._stream_terminal_signal: asyncio.Future[_StreamTerminalSignal] | None = None - self._stop_event = asyncio.Event() - self._stream_task: asyncio.Task[None] | None = None - self._pending_preferred_claim = False - self._session_lock: asyncio.Lock | None = None - self._session_id = "" - self._execution_key = (task_id, context_id) - - async def run(self) -> None: - current_task = asyncio.current_task() - if current_task is not None: - await self._register_running_request(current_task) - - try: - await self._bind_session() - await self._enqueue_working_status() - - turn_request_parts = list(self._prepared.request_parts) - user_text = self._prepared.user_text - - while True: - send_kwargs: dict[str, Any] = { - "directory": self._prepared.directory, - "workspace_id": self._prepared.workspace_id, - "model_override": self._prepared.model_override, - } - if self._prepared.streaming_request: - send_kwargs["timeout_override"] = self._executor._client.stream_timeout - - if not self._prepared.use_structured_parts and not turn_request_parts: - response = await call_with_supported_kwargs( - self._executor._client.send_message, - self._session_id, - user_text, - **send_kwargs, - ) - else: - response = await call_with_supported_kwargs( - self._executor._client.send_message, - self._session_id, - user_text or None, - parts=turn_request_parts, - **send_kwargs, - ) - - if self._pending_preferred_claim: - await self._executor._session_manager.finalize_preferred_session_binding( - identity=self._prepared.identity, - context_id=self._prepared.session_binding_context_id, - session_id=self._session_id, - ) - self._pending_preferred_claim = False - - # Check for tool calls that we should handle - tool_results = await self._executor._maybe_handle_tools(response.raw) - if tool_results: - # Clear user_text/parts for the next turn with tool results. - user_text = "" - turn_request_parts = [ - { - "type": "tool", - "tool": res["tool"], - "call_id": res["call_id"], - "output": res.get("output"), - "error": res.get("error"), - } - for res in tool_results - ] - # Loop back to send tool results - continue - - await self._handle_response(response) - break - - except httpx.HTTPStatusError as exc: - logger.exception("OpenCode request failed with HTTP error") - error_type, state, message = _format_upstream_error( - exc, - request="send_message", - ) - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=message, - state=state, - error_type=error_type, - upstream_status=exc.response.status_code, - streaming_request=self._prepared.streaming_request, - ) - except httpx.TimeoutException as exc: - logger.exception("OpenCode request timed out") - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=f"OpenCode request timed out: {exc}", - state=TaskState.failed, - error_type="UPSTREAM_TIMEOUT", - streaming_request=self._prepared.streaming_request, - ) - except UpstreamContractError as exc: - logger.warning("OpenCode request failed with payload mismatch: %s", exc) - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=f"OpenCode payload mismatch: {exc}", - state=TaskState.failed, - error_type="UPSTREAM_PAYLOAD_ERROR", - streaming_request=self._prepared.streaming_request, - ) - except UpstreamConcurrencyLimitError as exc: - logger.warning("OpenCode request rejected by concurrency budget: %s", exc) - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=str(exc), - state=TaskState.failed, - error_type="UPSTREAM_BACKPRESSURE", - streaming_request=self._prepared.streaming_request, - ) - except Exception as exc: - logger.exception("OpenCode request failed") - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=f"OpenCode error: {exc}", - state=TaskState.failed, - streaming_request=self._prepared.streaming_request, - ) - finally: - await self._cleanup() - - async def _register_running_request(self, current_task: asyncio.Task[Any]) -> None: - async with self._executor._lock: - self._executor._running_requests[self._execution_key] = current_task - self._executor._running_stop_events[self._execution_key] = self._stop_event - self._executor._running_identities[self._execution_key] = self._prepared.identity - - async def _bind_session(self) -> None: - ( - self._session_id, - self._pending_preferred_claim, - ) = await self._executor._session_manager.get_or_create_session( - self._prepared.identity, - self._prepared.session_binding_context_id, - self._prepared.session_title or self._prepared.user_text, - preferred_session_id=self._prepared.bound_session_id, - directory=self._prepared.directory, - workspace_id=self._prepared.workspace_id, - ) - self._session_lock = await self._executor._session_manager.get_session_lock( - self._session_id - ) - await self._session_lock.acquire() - async with self._executor._lock: - self._executor._running_session_ids[self._execution_key] = self._session_id - self._executor._running_directories[self._execution_key] = self._prepared.directory - self._executor._running_workspace_ids[self._execution_key] = self._prepared.workspace_id - self._executor._running_binding_context_ids[self._execution_key] = ( - self._prepared.session_binding_context_id - ) - - if self._prepared.streaming_request: - self._stream_terminal_signal = asyncio.get_running_loop().create_future() - self._stream_task = asyncio.create_task( - self._executor._consume_opencode_stream( - session_id=self._session_id, - identity=self._prepared.identity, - task_id=self._task_id, - context_id=self._context_id, - artifact_id=self._stream_artifact_id, - stream_state=self._stream_state, - event_queue=self._event_queue, - stop_event=self._stop_event, - directory=self._prepared.directory, - workspace_id=self._prepared.workspace_id, - terminal_signal=self._stream_terminal_signal, - allow_structured_output=self._prepared.allow_structured_output, - ) - ) - - async def _enqueue_working_status(self) -> None: - await self._event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=self._task_id, - context_id=self._context_id, - status=TaskStatus(state=TaskState.working), - final=False, - ) - ) - - async def _handle_response(self, response: Any) -> None: - response_text = response.text or "" - resolved_message_id = self._stream_state.resolve_message_id(response.message_id) - response_error = _extract_upstream_error_from_response(response.raw) - resolved_token_usage = _merge_token_usage( - _extract_token_usage(response.raw), - self._stream_state.token_usage, - ) - - logger.debug( - "OpenCode response task_id=%s session_id=%s message_id=%s text=%s", - self._task_id, - response.session_id, - resolved_message_id, - response_text, - ) - - if response_error is not None: - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=response_error.message, - state=response_error.state, - error_type=response_error.error_type, - upstream_status=response_error.upstream_status, - streaming_request=self._prepared.streaming_request, - ) - return - - if self._prepared.streaming_request: - await self._handle_streaming_response( - response=response, - response_text=response_text, - resolved_message_id=resolved_message_id, - resolved_token_usage=resolved_token_usage, - ) - return - - await self._handle_non_streaming_response( - response=response, - response_text=response_text, - resolved_message_id=resolved_message_id, - resolved_token_usage=resolved_token_usage, - ) - - async def _handle_streaming_response( - self, - *, - response: Any, - response_text: str, - resolved_message_id: str, - resolved_token_usage: Mapping[str, Any] | None, - ) -> None: - del response - if self._stream_terminal_signal is None: - raise RuntimeError("Streaming terminal signal was not initialized") - - terminal_signal = await _await_stream_terminal_signal( - stream_task=self._stream_task, - terminal_signal=self._stream_terminal_signal, - session_id=self._session_id, - ) - if terminal_signal.state != TaskState.completed: - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=terminal_signal.message or "OpenCode execution failed.", - state=terminal_signal.state, - error_type=terminal_signal.error_type, - upstream_status=terminal_signal.upstream_status, - streaming_request=True, - ) - return - - if self._stream_state.upstream_error is not None: - await self._executor._emit_error( - self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - message=self._stream_state.upstream_error.message, - state=self._stream_state.upstream_error.state, - error_type=self._stream_state.upstream_error.error_type, - upstream_status=self._stream_state.upstream_error.upstream_status, - streaming_request=True, - ) - return - - if self._stream_state.should_emit_final_snapshot(response_text): - sequence = self._stream_state.next_sequence() - await _enqueue_artifact_update( - event_queue=self._event_queue, - task_id=self._task_id, - context_id=self._context_id, - artifact_id=self._stream_artifact_id, - part=Part(root=TextPart(text=response_text)), - append=self._stream_state.emitted_stream_chunk, - last_chunk=True, - artifact_metadata=_build_stream_artifact_metadata( - block_type=BlockType.TEXT, - shared_source="final_snapshot", - message_id=resolved_message_id, - event_id=self._stream_state.build_event_id(sequence), - sequence=sequence, - ), - ) - - await self._event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=self._task_id, - context_id=self._context_id, - status=TaskStatus(state=TaskState.completed), - final=True, - metadata=_build_output_metadata( - session_id=self._session_id, - usage=resolved_token_usage, - stream={ - "message_id": resolved_message_id, - "event_id": f"{self._stream_state.event_id_namespace}:status", - "source": "status", - }, - ), - ) - ) - - async def _handle_non_streaming_response( - self, - *, - response: Any, - response_text: str, - resolved_message_id: str, - resolved_token_usage: Mapping[str, Any] | None, - ) -> None: - response_text = response_text or "(No text content returned by OpenCode.)" - assistant_message = _build_assistant_message( - task_id=self._task_id, - context_id=self._context_id, - text=response_text, - message_id=resolved_message_id, - ) - artifact = Artifact( - artifact_id=str(uuid.uuid4()), - name="response", - parts=[Part(root=TextPart(text=response_text))], - ) - history = _build_history(self._context) - task = Task( - id=self._task_id, - context_id=self._context_id, - status=TaskStatus(state=TaskState.completed), - history=history, - artifacts=[artifact], - metadata=_build_output_metadata( - session_id=response.session_id, - usage=resolved_token_usage, - ), - ) - task.status.message = assistant_message - await self._event_queue.enqueue_event(task) - - async def _cleanup(self) -> None: - if self._pending_preferred_claim and self._session_id: - with suppress(Exception): - await self._executor._session_manager.release_preferred_session_claim( - identity=self._prepared.identity, - session_id=self._session_id, - ) - self._stop_event.set() - if self._stream_task: - self._stream_task.cancel() - with suppress(asyncio.CancelledError): - await self._stream_task - if self._session_lock and self._session_lock.locked(): - self._session_lock.release() - async with self._executor._lock: - self._executor._running_requests.pop(self._execution_key, None) - self._executor._running_stop_events.pop(self._execution_key, None) - self._executor._running_identities.pop(self._execution_key, None) - self._executor._running_session_ids.pop(self._execution_key, None) - self._executor._running_directories.pop(self._execution_key, None) - self._executor._running_workspace_ids.pop(self._execution_key, None) - self._executor._running_binding_context_ids.pop(self._execution_key, None) + emit_metric(name, value, **labels) class OpencodeAgentExecutor(AgentExecutor): @@ -624,143 +181,20 @@ def _emit_metric( async def _maybe_handle_tools( self, raw_response: dict[str, Any] ) -> list[dict[str, Any]] | None: - """Heuristically detect and execute A2A tool calls from upstream OpenCode.""" - parts = raw_response.get("parts", []) - if not isinstance(parts, list): - return None - - results: list[dict[str, Any]] = [] - for part in parts: - if not isinstance(part, dict) or part.get("type") != "tool": - continue - - state = part.get("state") - if not isinstance(state, dict) or state.get("status") != "calling": - continue - - tool_name = part.get("tool") - if tool_name == "a2a_call": - result = await self._handle_a2a_call_tool(part) - if result: - results.append(result) - - return results if results else None + return await maybe_handle_tools( + raw_response, + a2a_client_manager=self._a2a_client_manager, + ) async def _handle_a2a_call_tool(self, part: dict[str, Any]) -> dict[str, Any]: - call_id = part.get("callID") or str(uuid.uuid4()) - tool_name = part.get("tool") or "a2a_call" - state = part.get("state", {}) - inputs = state.get("input", {}) - - if not isinstance(inputs, dict): - return { - "call_id": call_id, - "tool": tool_name, - **build_tool_error( - error_code="a2a_invalid_input", - error="Invalid a2a_call input payload", - ), - } - - agent_url = inputs.get("url") - message = inputs.get("message") - if not agent_url or not message: - return { - "call_id": call_id, - "tool": tool_name, - **build_tool_error( - error_code="a2a_missing_required_input", - error="Missing required a2a_call url or message", - ), - } - - mgr = self._a2a_client_manager - if mgr is None: - return { - "call_id": call_id, - "tool": tool_name, - **build_tool_error( - error_code="a2a_client_manager_unavailable", - error="A2A client manager is not available", - ), - } - - try: - event = None - result_text = "" - async with mgr.borrow_client(agent_url) as client: - async for current_event in client.send_message(message): - event = current_event - extracted = client.extract_text(current_event) - if extracted: - result_text = self._merge_streamed_tool_output(result_text, extracted) - - from a2a.types import Task - - if result_text: - return { - "call_id": call_id, - "tool": tool_name, - "output": result_text, - } - - if isinstance(event, Task): - result_text = "" - # Extract text from Task's assistant message if available - if event.status and event.status.message: - for part_obj in event.status.message.parts: - # Use dict-style access if available or getattr to satisfy type checkers - root = getattr(part_obj, "root", part_obj) - text_val = getattr(root, "text", "") - if text_val: - result_text += str(text_val) - return { - "call_id": call_id, - "tool": tool_name, - "output": result_text or "Task completed.", - } - - # Handle case where event is a tuple (Task, Update) - if isinstance(event, tuple) and len(event) > 0 and isinstance(event[0], Task): - return { - "call_id": call_id, - "tool": tool_name, - "output": "Task completed (streaming).", - } - - return { - "call_id": call_id, - "tool": tool_name, - **build_tool_error( - error_code="a2a_unexpected_response", - error="Remote A2A peer returned an unexpected response type", - error_meta={"response_type": type(event).__name__}, - ), - } - except Exception as exc: - logger.exception("A2A tool call failed") - return { - "call_id": call_id, - "tool": tool_name, - **map_a2a_tool_exception(exc), - } + return await handle_a2a_call_tool( + part, + a2a_client_manager=self._a2a_client_manager, + ) @staticmethod def _merge_streamed_tool_output(current: str, incoming: str) -> str: - if not current: - return incoming - if incoming == current or incoming in current: - return current - if incoming.startswith(current): - return incoming - if current.startswith(incoming): - return current - separator = ( - "" - if current.endswith(("\n", " ", "\t")) or incoming.startswith(("\n", " ", "\t")) - else "\n" - ) - return f"{current}{separator}{incoming}" + return merge_streamed_tool_output(current, incoming) async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: task_id = context.task_id @@ -841,7 +275,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non ) return - session_binding_context_id = _build_session_binding_context_id( + session_binding_context_id = build_session_binding_context_id( context_id=context_id, directory=directory, workspace_id=workspace_id, @@ -887,7 +321,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non user_text, len(request_parts), ) - prepared = _PreparedExecution( + prepared = PreparedExecution( identity=identity, streaming_request=streaming_request, request_parts=request_parts, @@ -901,7 +335,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non session_binding_context_id=session_binding_context_id, allow_structured_output=allow_structured_output, ) - coordinator = _ExecutionCoordinator( + coordinator = ExecutionCoordinator( self, context=context, event_queue=event_queue, @@ -1136,19 +570,3 @@ async def _consume_opencode_stream( workspace_id=workspace_id, allow_structured_output=allow_structured_output, ) - - -def _build_assistant_message( - task_id: str, - context_id: str, - text: str, - *, - message_id: str | None = None, -) -> Message: - return Message( - message_id=message_id or str(uuid.uuid4()), - role=Role.agent, - parts=[Part(root=TextPart(text=text))], - task_id=task_id, - context_id=context_id, - ) diff --git a/src/opencode_a2a/execution/metrics.py b/src/opencode_a2a/execution/metrics.py new file mode 100644 index 0000000..095bc71 --- /dev/null +++ b/src/opencode_a2a/execution/metrics.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import logging + +logger = logging.getLogger("opencode_a2a.execution.executor") + + +def emit_metric( + name: str, + value: float = 1.0, + **labels: str | int | float | bool, +) -> None: + if labels: + labels_text = ",".join( + f"{key}={str(label).lower() if isinstance(label, bool) else label}" + for key, label in sorted(labels.items()) + ) + logger.debug("metric=%s value=%s labels=%s", name, value, labels_text) + return + logger.debug("metric=%s value=%s", name, value) diff --git a/src/opencode_a2a/execution/tool_orchestration.py b/src/opencode_a2a/execution/tool_orchestration.py new file mode 100644 index 0000000..7fe3f61 --- /dev/null +++ b/src/opencode_a2a/execution/tool_orchestration.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from .tool_error_mapping import build_tool_error, map_a2a_tool_exception + +logger = logging.getLogger(__name__) + + +async def maybe_handle_tools( + raw_response: dict[str, Any], + *, + a2a_client_manager, +) -> list[dict[str, Any]] | None: + parts = raw_response.get("parts", []) + if not isinstance(parts, list): + return None + + results: list[dict[str, Any]] = [] + for part in parts: + if not isinstance(part, dict) or part.get("type") != "tool": + continue + + state = part.get("state") + if not isinstance(state, dict) or state.get("status") != "calling": + continue + + tool_name = part.get("tool") + if tool_name == "a2a_call": + result = await handle_a2a_call_tool(part, a2a_client_manager=a2a_client_manager) + if result: + results.append(result) + + return results if results else None + + +async def handle_a2a_call_tool( + part: dict[str, Any], + *, + a2a_client_manager, +) -> dict[str, Any]: + call_id = part.get("callID") or str(uuid.uuid4()) + tool_name = part.get("tool") or "a2a_call" + state = part.get("state", {}) + inputs = state.get("input", {}) + + if not isinstance(inputs, dict): + return { + "call_id": call_id, + "tool": tool_name, + **build_tool_error( + error_code="a2a_invalid_input", + error="Invalid a2a_call input payload", + ), + } + + agent_url = inputs.get("url") + message = inputs.get("message") + if not agent_url or not message: + return { + "call_id": call_id, + "tool": tool_name, + **build_tool_error( + error_code="a2a_missing_required_input", + error="Missing required a2a_call url or message", + ), + } + + if a2a_client_manager is None: + return { + "call_id": call_id, + "tool": tool_name, + **build_tool_error( + error_code="a2a_client_manager_unavailable", + error="A2A client manager is not available", + ), + } + + try: + event = None + result_text = "" + async with a2a_client_manager.borrow_client(agent_url) as client: + async for current_event in client.send_message(message): + event = current_event + extracted = client.extract_text(current_event) + if extracted: + result_text = merge_streamed_tool_output(result_text, extracted) + + from a2a.types import Task + + if result_text: + return { + "call_id": call_id, + "tool": tool_name, + "output": result_text, + } + + if isinstance(event, Task): + result_text = "" + if event.status and event.status.message: + for part_obj in event.status.message.parts: + root = getattr(part_obj, "root", part_obj) + text_val = getattr(root, "text", "") + if text_val: + result_text += str(text_val) + return { + "call_id": call_id, + "tool": tool_name, + "output": result_text or "Task completed.", + } + + if isinstance(event, tuple) and len(event) > 0 and isinstance(event[0], Task): + return { + "call_id": call_id, + "tool": tool_name, + "output": "Task completed (streaming).", + } + + return { + "call_id": call_id, + "tool": tool_name, + **build_tool_error( + error_code="a2a_unexpected_response", + error="Remote A2A peer returned an unexpected response type", + error_meta={"response_type": type(event).__name__}, + ), + } + except Exception as exc: + logger.exception("A2A tool call failed") + return { + "call_id": call_id, + "tool": tool_name, + **map_a2a_tool_exception(exc), + } + + +def merge_streamed_tool_output(current: str, incoming: str) -> str: + if not current: + return incoming + if incoming == current or incoming in current: + return current + if incoming.startswith(current): + return incoming + if current.startswith(incoming): + return current + separator = ( + "" + if current.endswith(("\n", " ", "\t")) or incoming.startswith(("\n", " ", "\t")) + else "\n" + ) + return f"{current}{separator}{incoming}" diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index bd8f1bf..9460594 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -1,12 +1,7 @@ from __future__ import annotations import asyncio -import hashlib -import json import logging -import secrets -from contextlib import asynccontextmanager -from contextvars import ContextVar, Token from functools import partial from typing import TYPE_CHECKING, cast @@ -38,19 +33,11 @@ UnsupportedOperationError, ) from a2a.utils import are_modalities_compatible -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - EXTENDED_AGENT_CARD_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, -) from a2a.utils.errors import ServerError from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response from pydantic_settings import BaseSettings from starlette.middleware.gzip import GZipMiddleware -from starlette.responses import StreamingResponse -from ..client import A2AClient from ..config import Settings from ..contracts.extensions import ( COMPATIBILITY_PROFILE_EXTENSION_URI, @@ -71,7 +58,7 @@ WORKSPACE_CONTROL_METHODS, build_capability_snapshot, ) -from ..execution.executor import OpencodeAgentExecutor, _emit_metric +from ..execution.executor import OpencodeAgentExecutor from ..invocation import call_with_supported_kwargs from ..jsonrpc.application import ( OpencodeSessionQueryJSONRPCApplication, @@ -87,6 +74,15 @@ build_agent_card, build_authenticated_extended_agent_card, ) +from .client_manager import A2AClientManager +from .lifespan import build_lifespan +from .middleware import ( + AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL, + PUBLIC_AGENT_CARD_CACHE_CONTROL, + build_agent_card_etag, + emit_stream_request_metrics, + install_runtime_middlewares, +) from .openapi import ( _build_jsonrpc_extension_openapi_description, _build_jsonrpc_extension_openapi_examples, @@ -108,19 +104,15 @@ from .state_store import ( build_interrupt_request_repository, build_session_state_repository, - initialize_state_repository, ) from .task_store import ( TaskStoreOperationError, build_database_engine, build_task_store, - initialize_task_store, ) logger = logging.getLogger(__name__) TASK_STORE_ERROR_TYPE = "TASK_STORE_UNAVAILABLE" -PUBLIC_AGENT_CARD_CACHE_CONTROL = "public, max-age=300" -AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL = "private, max-age=300" __all__ = [ "_RequestBodyTooLargeError", @@ -130,6 +122,8 @@ "INTERRUPT_RECOVERY_EXTENSION_URI", "INTERRUPT_RECOVERY_METHODS", "MODEL_SELECTION_EXTENSION_URI", + "PUBLIC_AGENT_CARD_CACHE_CONTROL", + "AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL", "PROVIDER_DISCOVERY_EXTENSION_URI", "PROVIDER_DISCOVERY_METHODS", "SESSION_BINDING_EXTENSION_URI", @@ -161,11 +155,6 @@ "build_agent_card", ] -_REQUEST_BODY_BYTES: ContextVar[bytes | None] = ContextVar( - "_REQUEST_BODY_BYTES", - default=None, -) - if TYPE_CHECKING: from a2a.server.context import ServerCallContext @@ -370,8 +359,8 @@ async def on_message_send_stream(self, params, context=None): result_aggregator, producer_task, ) = await self._setup_message_execution(params, context) - _emit_metric("a2a_stream_requests_total") - _emit_metric("a2a_stream_active") + emit_stream_request_metrics() + emit_stream_request_metrics(active_delta=1.0) consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) stream_completed = False @@ -401,7 +390,7 @@ async def on_message_send_stream(self, params, context=None): await queue.close(immediate=True) raise finally: - _emit_metric("a2a_stream_active", -1) + emit_stream_request_metrics(active_delta=-1.0) logger.debug( "A2A stream request closed task_id=%s completed=%s", task_id, @@ -519,221 +508,6 @@ def build(self, request: Request) -> ServerCallContext: return context -def add_auth_middleware(app: FastAPI, settings: Settings) -> None: - token = settings.a2a_bearer_token - - def _unauthorized_response() -> JSONResponse: - return JSONResponse( - {"error": "Unauthorized"}, - status_code=401, - headers={"WWW-Authenticate": "Bearer"}, - ) - - @app.middleware("http") - async def bearer_auth(request: Request, call_next): - if request.method == "OPTIONS" or request.url.path in { - AGENT_CARD_WELL_KNOWN_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, - }: - return await call_next(request) - - auth_header = request.headers.get("authorization", "") - if not auth_header.lower().startswith("bearer "): - return _unauthorized_response() - provided = auth_header.split(" ", 1)[1].strip() - if not secrets.compare_digest(provided, token): - return _unauthorized_response() - request.state.user_identity = f"bearer:{hashlib.sha256(provided.encode()).hexdigest()[:12]}" - - return await call_next(request) - - -def _agent_card_response_bytes(card) -> bytes: - payload = card.model_dump(mode="json", by_alias=True, exclude_none=True) - return json.dumps( - payload, - ensure_ascii=False, - separators=(",", ":"), - sort_keys=True, - ).encode("utf-8") - - -def _build_agent_card_etag(card) -> str: - return f'W/"{hashlib.sha256(_agent_card_response_bytes(card)).hexdigest()}"' - - -def _etag_matches(if_none_match: str | None, etag: str) -> bool: - if not if_none_match: - return False - candidates = {item.strip() for item in if_none_match.split(",") if item.strip()} - return "*" in candidates or etag in candidates - - -def _merge_vary(*values: str) -> str: - ordered: list[str] = [] - seen: set[str] = set() - for value in values: - for item in value.split(","): - normalized = item.strip() - if not normalized: - continue - key = normalized.lower() - if key in seen: - continue - seen.add(key) - ordered.append(normalized) - return ", ".join(ordered) - - -class A2AClientManager: - def __init__(self, settings: Settings) -> None: - import time - - from ..client.config import load_settings as load_client_settings - - self.client_settings = load_client_settings( - { - "A2A_CLIENT_TIMEOUT_SECONDS": settings.a2a_client_timeout_seconds, - "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS": ( - settings.a2a_client_card_fetch_timeout_seconds - ), - "A2A_CLIENT_USE_CLIENT_PREFERENCE": settings.a2a_client_use_client_preference, - "A2A_CLIENT_BEARER_TOKEN": settings.a2a_client_bearer_token, - "A2A_CLIENT_BASIC_AUTH": settings.a2a_client_basic_auth, - "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, - } - ) - self._cache_ttl_seconds = float(settings.a2a_client_cache_ttl_seconds) - self._cache_maxsize = int(settings.a2a_client_cache_maxsize) - self._now = time.monotonic - self.clients: dict[str, _ClientCacheEntry] = {} - self._lock = asyncio.Lock() - - @asynccontextmanager - async def borrow_client(self, agent_url: str): - url = agent_url.rstrip("/") - if self._cache_maxsize <= 0: - client = A2AClient(url, settings=self.client_settings) - try: - yield client - finally: - await client.close() - return - - to_close: list[A2AClient] = [] - async with self._lock: - now = self._now() - entry = self.clients.get(url) - if entry is not None and entry.expires_at is not None and entry.expires_at <= now: - if entry.borrow_count > 0 or entry.client.is_busy(): - entry.pending_eviction = True - else: - self.clients.pop(url, None) - to_close.append(entry.client) - entry = None - to_close.extend(self._evict_locked(now=now, protected_keys={url})) - if entry is None: - entry = _ClientCacheEntry( - client=A2AClient(url, settings=self.client_settings), - last_used=now, - expires_at=None - if self._cache_ttl_seconds <= 0 - else now + self._cache_ttl_seconds, - ) - self.clients[url] = entry - else: - entry.last_used = now - entry.expires_at = ( - None if self._cache_ttl_seconds <= 0 else now + self._cache_ttl_seconds - ) - entry.pending_eviction = False - entry.borrow_count += 1 - to_close.extend(self._evict_locked(now=now, protected_keys={url})) - await self._close_clients(to_close) - - try: - yield entry.client - finally: - async with self._lock: - now = self._now() - current = self.clients.get(url) - if current is entry: - if current.borrow_count > 0: - current.borrow_count -= 1 - current.last_used = now - current.expires_at = ( - None if self._cache_ttl_seconds <= 0 else now + self._cache_ttl_seconds - ) - to_close = self._evict_locked(now=now) - await self._close_clients(to_close) - - async def close_all(self) -> None: - async with self._lock: - clients = [entry.client for entry in self.clients.values()] - self.clients.clear() - for client in clients: - await client.close() - - def _evict_locked( - self, - *, - now: float, - protected_keys: set[str] | None = None, - ) -> list[A2AClient]: - protected = protected_keys or set() - to_close: list[A2AClient] = [] - - for key, entry in list(self.clients.items()): - expired = entry.expires_at is not None and entry.expires_at <= now - if not expired and not entry.pending_eviction: - continue - if key in protected or entry.borrow_count > 0 or entry.client.is_busy(): - entry.pending_eviction = True - continue - self.clients.pop(key, None) - to_close.append(entry.client) - - if self._cache_maxsize <= 0 or len(self.clients) <= self._cache_maxsize: - return to_close - - if any(entry.pending_eviction for entry in self.clients.values()): - return to_close - - for key, entry in sorted(self.clients.items(), key=lambda item: item[1].last_used): - if len(self.clients) <= self._cache_maxsize: - break - if key in protected: - continue - if entry.borrow_count > 0 or entry.client.is_busy(): - entry.pending_eviction = True - continue - self.clients.pop(key, None) - to_close.append(entry.client) - - return to_close - - async def _close_clients(self, clients: list[A2AClient]) -> None: - for client in clients: - await client.close() - - -class _ClientCacheEntry: - def __init__( - self, - *, - client: A2AClient, - last_used: float, - expires_at: float | None, - borrow_count: int = 0, - pending_eviction: bool = False, - ) -> None: - self.client = client - self.last_used = last_used - self.expires_at = expires_at - self.borrow_count = borrow_count - self.pending_eviction = pending_eviction - - def create_app(settings: Settings) -> FastAPI: database_engine = ( build_database_engine(settings) if settings.a2a_task_store_backend == "database" else None @@ -813,19 +587,16 @@ def create_app(settings: Settings) -> FastAPI: http_handler=handler, context_builder=context_builder, ) - public_card_etag = _build_agent_card_etag(agent_card) - extended_card_etag = _build_agent_card_etag(extended_agent_card) - - @asynccontextmanager - async def lifespan(_app: FastAPI): - await initialize_task_store(task_store) - await initialize_state_repository(session_state_repository) - await initialize_state_repository(interrupt_request_repository) - yield - if database_engine is not None: - await database_engine.dispose() - await client_manager.close_all() - await upstream_client.close() + public_card_etag = build_agent_card_etag(agent_card) + extended_card_etag = build_agent_card_etag(extended_agent_card) + lifespan = build_lifespan( + database_engine=database_engine, + task_store=task_store, + session_state_repository=session_state_repository, + interrupt_request_repository=interrupt_request_repository, + client_manager=client_manager, + upstream_client=upstream_client, + ) app = A2AFastAPI( title=settings.a2a_title, @@ -842,6 +613,12 @@ async def lifespan(_app: FastAPI): app.state.upstream_client = upstream_client app.state.a2a_client_manager = client_manager _patch_jsonrpc_openapi_contract(app, settings, runtime_profile=runtime_profile) + install_runtime_middlewares( + app, + settings, + public_card_etag=public_card_etag, + extended_card_etag=extended_card_etag, + ) @app.get("/health") async def health_check(): @@ -851,250 +628,6 @@ async def health_check(): protocol_version=settings.a2a_protocol_version, ) - async def _get_request_body(request: Request) -> tuple[bytes, Token | None]: - cached = _REQUEST_BODY_BYTES.get() - if cached is not None: - return cached, None - - limit = settings.a2a_max_request_body_bytes - content_length = _parse_content_length(request.headers.get("content-length")) - if limit > 0 and content_length is not None and content_length > limit: - raise _RequestBodyTooLargeError(limit=limit, actual_size=content_length) - - if hasattr(request, "_body"): - body = request._body - if limit > 0 and len(body) > limit: - raise _RequestBodyTooLargeError(limit=limit, actual_size=len(body)) - elif limit <= 0: - body = await request.body() - else: - total = 0 - chunks: list[bytes] = [] - async for chunk in request.stream(): - if not chunk: - continue - total += len(chunk) - if total > limit: - raise _RequestBodyTooLargeError(limit=limit, actual_size=total) - chunks.append(chunk) - body = b"".join(chunks) - request._body = body - - token = _REQUEST_BODY_BYTES.set(body) - return body, token - - @app.middleware("http") - async def cache_agent_card_responses(request: Request, call_next): - if request.method != "GET": - return await call_next(request) - - path = request.url.path - is_public_card = path in { - AGENT_CARD_WELL_KNOWN_PATH, - PREV_AGENT_CARD_WELL_KNOWN_PATH, - } - is_extended_card = path == EXTENDED_AGENT_CARD_PATH - if not is_public_card and not is_extended_card: - return await call_next(request) - - if is_public_card and _etag_matches(request.headers.get("if-none-match"), public_card_etag): - return Response( - status_code=304, - headers={ - "ETag": public_card_etag, - "Cache-Control": PUBLIC_AGENT_CARD_CACHE_CONTROL, - "Vary": "Accept-Encoding", - }, - ) - - response = await call_next(request) - if response.status_code != 200: - return response - - if is_public_card: - response.headers["ETag"] = public_card_etag - response.headers["Cache-Control"] = PUBLIC_AGENT_CARD_CACHE_CONTROL - response.headers["Vary"] = _merge_vary( - response.headers.get("Vary", ""), - "Accept-Encoding", - ) - return response - - response.headers["ETag"] = extended_card_etag - response.headers["Cache-Control"] = AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL - response.headers["Vary"] = _merge_vary( - response.headers.get("Vary", ""), - "Authorization", - "Accept-Encoding", - ) - if _etag_matches(request.headers.get("if-none-match"), extended_card_etag): - return Response(status_code=304, headers=dict(response.headers)) - return response - - @app.middleware("http") - async def enforce_request_body_limit(request: Request, call_next): - token: Token | None = None - limit = settings.a2a_max_request_body_bytes - if limit <= 0 or request.method not in {"POST", "PUT", "PATCH"}: - return await call_next(request) - - try: - _, token = await _get_request_body(request) - return await call_next(request) - except _RequestBodyTooLargeError as error: - return _request_body_too_large_response( - path=request.url.path, - method=request.method, - error=error, - ) - finally: - if token is not None: - _REQUEST_BODY_BYTES.reset(token) - - @app.middleware("http") - async def guard_rest_payload_shape(request: Request, call_next): - token: Token | None = None - if request.method != "POST" or request.url.path not in { - "/v1/message:send", - "/v1/message:stream", - }: - return await call_next(request) - - try: - body, token = await _get_request_body(request) - payload = _parse_json_body(body) - if _looks_like_jsonrpc_envelope(payload) or _looks_like_jsonrpc_message_payload( - payload - ): - return JSONResponse( - { - "error": ( - "Invalid HTTP+JSON payload for REST endpoint. " - "Use message.content with ROLE_* role values, or call " - "POST / with method=message/send or method=message/stream." - ) - }, - status_code=400, - ) - return await call_next(request) - except _RequestBodyTooLargeError as error: - return _request_body_too_large_response( - path=request.url.path, - method=request.method, - error=error, - ) - finally: - if token is not None: - _REQUEST_BODY_BYTES.reset(token) - - @app.middleware("http") - async def log_payloads(request: Request, call_next): - token: Token | None = None - if not settings.a2a_log_payloads: - return await call_next(request) - - try: - path = request.url.path - limit = settings.a2a_log_body_limit - content_type = _normalize_content_type(request.headers.get("content-type")) - content_length = _parse_content_length(request.headers.get("content-length")) - - sensitive_method: str | None = None - request_omit_reason: str | None = None - - if not _is_json_content_type(content_type): - request_omit_reason = f"non-json content-type={content_type or 'unknown'}" - elif limit > 0 and content_length is None: - request_omit_reason = f"missing content-length with limit={limit}" - elif limit > 0 and content_length is not None and content_length > limit: - request_omit_reason = f"content-length={content_length} exceeds limit={limit}" - else: - body, token = await _get_request_body(request) - # Detect session-query JSON-RPC methods regardless of deployment prefixes/root_path. - payload = _parse_json_body(body) - sensitive_method = _detect_sensitive_extension_method(payload) - - if sensitive_method: - logger.debug( - "A2A request %s %s method=%s", - request.method, - path, - sensitive_method, - ) - else: - logger.debug( - "A2A request %s %s body=%s", - request.method, - path, - _decode_payload_preview(body, limit=limit), - ) - - if request_omit_reason: - logger.debug( - "A2A request %s %s body=[omitted %s]", - request.method, - path, - request_omit_reason, - ) - - response = await call_next(request) - if isinstance(response, StreamingResponse): - if sensitive_method: - logger.debug("A2A response %s streaming method=%s", path, sensitive_method) - else: - logger.debug("A2A response %s streaming", path) - return response - - response_body = getattr(response, "body", b"") or b"" - if sensitive_method: - logger.debug( - "A2A response %s status=%s bytes=%s method=%s", - path, - response.status_code, - len(response_body), - sensitive_method, - ) - return response - - if request_omit_reason: - logger.debug( - "A2A response %s status=%s bytes=%s body=[omitted request_%s]", - path, - response.status_code, - len(response_body), - request_omit_reason, - ) - return response - response_content_type = _normalize_content_type(response.headers.get("content-type")) - if not _is_json_content_type(response_content_type): - logger.debug( - "A2A response %s status=%s bytes=%s body=[omitted non-json content-type=%s]", - path, - response.status_code, - len(response_body), - response_content_type or "unknown", - ) - return response - - logger.debug( - "A2A response %s status=%s body=%s", - path, - response.status_code, - _decode_payload_preview(response_body, limit=limit), - ) - return response - except _RequestBodyTooLargeError as error: - return _request_body_too_large_response( - path=request.url.path, - method=request.method, - error=error, - ) - finally: - if token is not None: - _REQUEST_BODY_BYTES.reset(token) - - add_auth_middleware(app, settings) - return app diff --git a/src/opencode_a2a/server/client_manager.py b/src/opencode_a2a/server/client_manager.py new file mode 100644 index 0000000..d496e74 --- /dev/null +++ b/src/opencode_a2a/server/client_manager.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager + +from ..client import A2AClient + + +class A2AClientManager: + def __init__(self, settings) -> None: # noqa: ANN001 + import time + + from ..client.config import load_settings as load_client_settings + + self.client_settings = load_client_settings( + { + "A2A_CLIENT_TIMEOUT_SECONDS": settings.a2a_client_timeout_seconds, + "A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS": ( + settings.a2a_client_card_fetch_timeout_seconds + ), + "A2A_CLIENT_USE_CLIENT_PREFERENCE": settings.a2a_client_use_client_preference, + "A2A_CLIENT_BEARER_TOKEN": settings.a2a_client_bearer_token, + "A2A_CLIENT_BASIC_AUTH": settings.a2a_client_basic_auth, + "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, + } + ) + self._cache_ttl_seconds = float(settings.a2a_client_cache_ttl_seconds) + self._cache_maxsize = int(settings.a2a_client_cache_maxsize) + self._now = time.monotonic + self.clients: dict[str, _ClientCacheEntry] = {} + self._lock = asyncio.Lock() + + @asynccontextmanager + async def borrow_client(self, agent_url: str): + url = agent_url.rstrip("/") + if self._cache_maxsize <= 0: + client = A2AClient(url, settings=self.client_settings) + try: + yield client + finally: + await client.close() + return + + to_close: list[A2AClient] = [] + async with self._lock: + now = self._now() + entry = self.clients.get(url) + if entry is not None and entry.expires_at is not None and entry.expires_at <= now: + if entry.borrow_count > 0 or entry.client.is_busy(): + entry.pending_eviction = True + else: + self.clients.pop(url, None) + to_close.append(entry.client) + entry = None + to_close.extend(self._evict_locked(now=now, protected_keys={url})) + if entry is None: + entry = _ClientCacheEntry( + client=A2AClient(url, settings=self.client_settings), + last_used=now, + expires_at=None + if self._cache_ttl_seconds <= 0 + else now + self._cache_ttl_seconds, + ) + self.clients[url] = entry + else: + entry.last_used = now + entry.expires_at = ( + None if self._cache_ttl_seconds <= 0 else now + self._cache_ttl_seconds + ) + entry.pending_eviction = False + entry.borrow_count += 1 + to_close.extend(self._evict_locked(now=now, protected_keys={url})) + await self._close_clients(to_close) + + try: + yield entry.client + finally: + async with self._lock: + now = self._now() + current = self.clients.get(url) + if current is entry: + if current.borrow_count > 0: + current.borrow_count -= 1 + current.last_used = now + current.expires_at = ( + None if self._cache_ttl_seconds <= 0 else now + self._cache_ttl_seconds + ) + to_close = self._evict_locked(now=now) + await self._close_clients(to_close) + + async def close_all(self) -> None: + async with self._lock: + clients = [entry.client for entry in self.clients.values()] + self.clients.clear() + for client in clients: + await client.close() + + def _evict_locked( + self, + *, + now: float, + protected_keys: set[str] | None = None, + ) -> list[A2AClient]: + protected = protected_keys or set() + to_close: list[A2AClient] = [] + + for key, entry in list(self.clients.items()): + expired = entry.expires_at is not None and entry.expires_at <= now + if not expired and not entry.pending_eviction: + continue + if key in protected or entry.borrow_count > 0 or entry.client.is_busy(): + entry.pending_eviction = True + continue + self.clients.pop(key, None) + to_close.append(entry.client) + + if self._cache_maxsize <= 0 or len(self.clients) <= self._cache_maxsize: + return to_close + + if any(entry.pending_eviction for entry in self.clients.values()): + return to_close + + for key, entry in sorted(self.clients.items(), key=lambda item: item[1].last_used): + if len(self.clients) <= self._cache_maxsize: + break + if key in protected: + continue + if entry.borrow_count > 0 or entry.client.is_busy(): + entry.pending_eviction = True + continue + self.clients.pop(key, None) + to_close.append(entry.client) + + return to_close + + async def _close_clients(self, clients: list[A2AClient]) -> None: + for client in clients: + await client.close() + + +class _ClientCacheEntry: + def __init__( + self, + *, + client: A2AClient, + last_used: float, + expires_at: float | None, + borrow_count: int = 0, + pending_eviction: bool = False, + ) -> None: + self.client = client + self.last_used = last_used + self.expires_at = expires_at + self.borrow_count = borrow_count + self.pending_eviction = pending_eviction diff --git a/src/opencode_a2a/server/lifespan.py b/src/opencode_a2a/server/lifespan.py new file mode 100644 index 0000000..334fcbd --- /dev/null +++ b/src/opencode_a2a/server/lifespan.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager + +from .state_store import initialize_state_repository +from .task_store import initialize_task_store + + +def build_lifespan( + *, + database_engine, + task_store, + session_state_repository, + interrupt_request_repository, + client_manager, + upstream_client, +): + @asynccontextmanager + async def lifespan(_app): + await initialize_task_store(task_store) + await initialize_state_repository(session_state_repository) + await initialize_state_repository(interrupt_request_repository) + yield + if database_engine is not None: + await database_engine.dispose() + await client_manager.close_all() + await upstream_client.close() + + return lifespan diff --git a/src/opencode_a2a/server/middleware.py b/src/opencode_a2a/server/middleware.py new file mode 100644 index 0000000..d3afc4f --- /dev/null +++ b/src/opencode_a2a/server/middleware.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +import hashlib +import json +import logging +import secrets +from contextvars import ContextVar, Token + +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, + EXTENDED_AGENT_CARD_PATH, + PREV_AGENT_CARD_WELL_KNOWN_PATH, +) +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response +from starlette.responses import StreamingResponse + +from ..execution.metrics import emit_metric +from .request_parsing import ( + _decode_payload_preview, + _detect_sensitive_extension_method, + _is_json_content_type, + _looks_like_jsonrpc_envelope, + _looks_like_jsonrpc_message_payload, + _normalize_content_type, + _parse_content_length, + _parse_json_body, + _request_body_too_large_response, + _RequestBodyTooLargeError, +) + +logger = logging.getLogger("opencode_a2a.server.application") +PUBLIC_AGENT_CARD_CACHE_CONTROL = "public, max-age=300" +AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL = "private, max-age=300" +_REQUEST_BODY_BYTES: ContextVar[bytes | None] = ContextVar( + "_REQUEST_BODY_BYTES", + default=None, +) + + +def add_auth_middleware(app: FastAPI, settings) -> None: # noqa: ANN001 + token = settings.a2a_bearer_token + + def _unauthorized_response() -> JSONResponse: + return JSONResponse( + {"error": "Unauthorized"}, + status_code=401, + headers={"WWW-Authenticate": "Bearer"}, + ) + + @app.middleware("http") + async def bearer_auth(request: Request, call_next): + if request.method == "OPTIONS" or request.url.path in { + AGENT_CARD_WELL_KNOWN_PATH, + PREV_AGENT_CARD_WELL_KNOWN_PATH, + }: + return await call_next(request) + + auth_header = request.headers.get("authorization", "") + if not auth_header.lower().startswith("bearer "): + return _unauthorized_response() + provided = auth_header.split(" ", 1)[1].strip() + if not secrets.compare_digest(provided, token): + return _unauthorized_response() + request.state.user_identity = f"bearer:{hashlib.sha256(provided.encode()).hexdigest()[:12]}" + + return await call_next(request) + + +def build_agent_card_etag(card) -> str: # noqa: ANN001 + payload = card.model_dump(mode="json", by_alias=True, exclude_none=True) + content = json.dumps( + payload, + ensure_ascii=False, + separators=(",", ":"), + sort_keys=True, + ).encode("utf-8") + return f'W/"{hashlib.sha256(content).hexdigest()}"' + + +def install_runtime_middlewares( + app: FastAPI, + settings, + *, + public_card_etag: str, + extended_card_etag: str, +) -> None: + async def _get_request_body(request: Request) -> tuple[bytes, Token | None]: + cached = _REQUEST_BODY_BYTES.get() + if cached is not None: + return cached, None + + limit = settings.a2a_max_request_body_bytes + content_length = _parse_content_length(request.headers.get("content-length")) + if limit > 0 and content_length is not None and content_length > limit: + raise _RequestBodyTooLargeError(limit=limit, actual_size=content_length) + + if hasattr(request, "_body"): + body = request._body + if limit > 0 and len(body) > limit: + raise _RequestBodyTooLargeError(limit=limit, actual_size=len(body)) + elif limit <= 0: + body = await request.body() + else: + total = 0 + chunks: list[bytes] = [] + async for chunk in request.stream(): + if not chunk: + continue + total += len(chunk) + if total > limit: + raise _RequestBodyTooLargeError(limit=limit, actual_size=total) + chunks.append(chunk) + body = b"".join(chunks) + request._body = body + + token = _REQUEST_BODY_BYTES.set(body) + return body, token + + def _etag_matches(if_none_match: str | None, etag: str) -> bool: + if not if_none_match: + return False + candidates = {item.strip() for item in if_none_match.split(",") if item.strip()} + return "*" in candidates or etag in candidates + + def _merge_vary(*values: str) -> str: + ordered: list[str] = [] + seen: set[str] = set() + for value in values: + for item in value.split(","): + normalized = item.strip() + if not normalized: + continue + key = normalized.lower() + if key in seen: + continue + seen.add(key) + ordered.append(normalized) + return ", ".join(ordered) + + @app.middleware("http") + async def cache_agent_card_responses(request: Request, call_next): + if request.method != "GET": + return await call_next(request) + + path = request.url.path + is_public_card = path in { + AGENT_CARD_WELL_KNOWN_PATH, + PREV_AGENT_CARD_WELL_KNOWN_PATH, + } + is_extended_card = path == EXTENDED_AGENT_CARD_PATH + if not is_public_card and not is_extended_card: + return await call_next(request) + + if is_public_card and _etag_matches(request.headers.get("if-none-match"), public_card_etag): + return Response( + status_code=304, + headers={ + "ETag": public_card_etag, + "Cache-Control": PUBLIC_AGENT_CARD_CACHE_CONTROL, + "Vary": "Accept-Encoding", + }, + ) + + response = await call_next(request) + if response.status_code != 200: + return response + + if is_public_card: + response.headers["ETag"] = public_card_etag + response.headers["Cache-Control"] = PUBLIC_AGENT_CARD_CACHE_CONTROL + response.headers["Vary"] = _merge_vary( + response.headers.get("Vary", ""), + "Accept-Encoding", + ) + return response + + response.headers["ETag"] = extended_card_etag + response.headers["Cache-Control"] = AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL + response.headers["Vary"] = _merge_vary( + response.headers.get("Vary", ""), + "Authorization", + "Accept-Encoding", + ) + if _etag_matches(request.headers.get("if-none-match"), extended_card_etag): + return Response(status_code=304, headers=dict(response.headers)) + return response + + @app.middleware("http") + async def enforce_request_body_limit(request: Request, call_next): + token: Token | None = None + limit = settings.a2a_max_request_body_bytes + if limit <= 0 or request.method not in {"POST", "PUT", "PATCH"}: + return await call_next(request) + + try: + _, token = await _get_request_body(request) + return await call_next(request) + except _RequestBodyTooLargeError as error: + return _request_body_too_large_response( + path=request.url.path, + method=request.method, + error=error, + ) + finally: + if token is not None: + _REQUEST_BODY_BYTES.reset(token) + + @app.middleware("http") + async def guard_rest_payload_shape(request: Request, call_next): + token: Token | None = None + if request.method != "POST" or request.url.path not in { + "/v1/message:send", + "/v1/message:stream", + }: + return await call_next(request) + + try: + body, token = await _get_request_body(request) + payload = _parse_json_body(body) + if _looks_like_jsonrpc_envelope(payload) or _looks_like_jsonrpc_message_payload( + payload + ): + return JSONResponse( + { + "error": ( + "Invalid HTTP+JSON payload for REST endpoint. " + "Use message.content with ROLE_* role values, or call " + "POST / with method=message/send or method=message/stream." + ) + }, + status_code=400, + ) + return await call_next(request) + except _RequestBodyTooLargeError as error: + return _request_body_too_large_response( + path=request.url.path, + method=request.method, + error=error, + ) + finally: + if token is not None: + _REQUEST_BODY_BYTES.reset(token) + + @app.middleware("http") + async def log_payloads(request: Request, call_next): + token: Token | None = None + if not settings.a2a_log_payloads: + return await call_next(request) + + try: + path = request.url.path + limit = settings.a2a_log_body_limit + content_type = _normalize_content_type(request.headers.get("content-type")) + content_length = _parse_content_length(request.headers.get("content-length")) + + sensitive_method: str | None = None + request_omit_reason: str | None = None + + if not _is_json_content_type(content_type): + request_omit_reason = f"non-json content-type={content_type or 'unknown'}" + elif limit > 0 and content_length is None: + request_omit_reason = f"missing content-length with limit={limit}" + elif limit > 0 and content_length is not None and content_length > limit: + request_omit_reason = f"content-length={content_length} exceeds limit={limit}" + else: + body, token = await _get_request_body(request) + payload = _parse_json_body(body) + sensitive_method = _detect_sensitive_extension_method(payload) + + if sensitive_method: + logger.debug( + "A2A request %s %s method=%s", + request.method, + path, + sensitive_method, + ) + else: + logger.debug( + "A2A request %s %s body=%s", + request.method, + path, + _decode_payload_preview(body, limit=limit), + ) + + if request_omit_reason: + logger.debug( + "A2A request %s %s body=[omitted %s]", + request.method, + path, + request_omit_reason, + ) + + response = await call_next(request) + if isinstance(response, StreamingResponse): + if sensitive_method: + logger.debug("A2A response %s streaming method=%s", path, sensitive_method) + else: + logger.debug("A2A response %s streaming", path) + return response + + response_body = getattr(response, "body", b"") or b"" + if sensitive_method: + logger.debug( + "A2A response %s status=%s bytes=%s method=%s", + path, + response.status_code, + len(response_body), + sensitive_method, + ) + return response + + if request_omit_reason: + logger.debug( + "A2A response %s status=%s bytes=%s body=[omitted request_%s]", + path, + response.status_code, + len(response_body), + request_omit_reason, + ) + return response + response_content_type = _normalize_content_type(response.headers.get("content-type")) + if not _is_json_content_type(response_content_type): + logger.debug( + "A2A response %s status=%s bytes=%s body=[omitted non-json content-type=%s]", + path, + response.status_code, + len(response_body), + response_content_type or "unknown", + ) + return response + + logger.debug( + "A2A response %s status=%s body=%s", + path, + response.status_code, + _decode_payload_preview(response_body, limit=limit), + ) + return response + except _RequestBodyTooLargeError as error: + return _request_body_too_large_response( + path=request.url.path, + method=request.method, + error=error, + ) + finally: + if token is not None: + _REQUEST_BODY_BYTES.reset(token) + + add_auth_middleware(app, settings) + + +def emit_stream_request_metrics(*, active_delta: float | None = None) -> None: + if active_delta is None: + emit_metric("a2a_stream_requests_total") + return + emit_metric("a2a_stream_active", active_delta) diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index 2175157..4d08c79 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -30,7 +30,7 @@ from opencode_a2a.execution.executor import OpencodeAgentExecutor from opencode_a2a.execution.tool_error_mapping import map_a2a_tool_exception from opencode_a2a.opencode_upstream_client import OpencodeMessage -from opencode_a2a.server import application as app_module +from opencode_a2a.server.client_manager import A2AClientManager from tests.support.helpers import ( DummyChatOpencodeUpstreamClient, DummyEventQueue, @@ -575,7 +575,7 @@ async def test_agent_a2a_call_uses_server_side_basic_auth_headers( ) monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_sdk_client)) - manager = app_module.A2AClientManager( + manager = A2AClientManager( SimpleNamespace( a2a_client_timeout_seconds=30.0, a2a_client_card_fetch_timeout_seconds=5.0, diff --git a/tests/server/test_a2a_client_manager.py b/tests/server/test_a2a_client_manager.py index be9938e..7ba8cb7 100644 --- a/tests/server/test_a2a_client_manager.py +++ b/tests/server/test_a2a_client_manager.py @@ -4,7 +4,7 @@ import pytest -from opencode_a2a.server import application as app_module +from opencode_a2a.server import client_manager as client_manager_module def _make_settings(**overrides: object) -> SimpleNamespace: @@ -40,9 +40,9 @@ def is_busy(self) -> bool: async def close(self) -> None: self.closed = True - monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + monkeypatch.setattr(client_manager_module, "A2AClient", _FakeClient) - manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=2)) + manager = client_manager_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=2)) async with manager.borrow_client("http://peer-1"): pass @@ -75,9 +75,9 @@ def is_busy(self) -> bool: async def close(self) -> None: self.closed = True - monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + monkeypatch.setattr(client_manager_module, "A2AClient", _FakeClient) - manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) + manager = client_manager_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) async with manager.borrow_client("http://peer-1") as first_client: first_client.busy = True @@ -111,9 +111,9 @@ def is_busy(self) -> bool: async def close(self) -> None: self.closed = True - monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + monkeypatch.setattr(client_manager_module, "A2AClient", _FakeClient) - manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) + manager = client_manager_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) async with manager.borrow_client("http://peer-1"): async with manager.borrow_client("http://peer-2"): @@ -144,10 +144,12 @@ def is_busy(self) -> bool: async def close(self) -> None: self.closed = True - monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + monkeypatch.setattr(client_manager_module, "A2AClient", _FakeClient) now = 100.0 - manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_ttl_seconds=10.0)) + manager = client_manager_module.A2AClientManager( + _make_settings(a2a_client_cache_ttl_seconds=10.0) + ) manager._now = lambda: now async with manager.borrow_client("http://peer-1"): @@ -182,10 +184,12 @@ def is_busy(self) -> bool: async def close(self) -> None: self.closed = True - monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + monkeypatch.setattr(client_manager_module, "A2AClient", _FakeClient) now = 100.0 - manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_ttl_seconds=10.0)) + manager = client_manager_module.A2AClientManager( + _make_settings(a2a_client_cache_ttl_seconds=10.0) + ) manager._now = lambda: now async with manager.borrow_client("http://peer-1") as first_client: @@ -201,6 +205,8 @@ async def close(self) -> None: def test_client_manager_loads_basic_auth_into_client_settings() -> None: - manager = app_module.A2AClientManager(_make_settings(a2a_client_basic_auth="user:pass")) + manager = client_manager_module.A2AClientManager( + _make_settings(a2a_client_basic_auth="user:pass") + ) assert manager.client_settings.basic_auth == "user:pass"