diff --git a/docs/guide.md b/docs/guide.md index 509c8e1..d16c112 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -51,8 +51,8 @@ Key variables to understand protocol behavior: - `A2A_CLIENT_BEARER_TOKEN`: optional bearer token attached to outbound peer calls made by the embedded A2A client and `a2a_call` tool path. - `A2A_CLIENT_BASIC_AUTH`: optional Basic auth credential attached to outbound peer calls made by the embedded A2A client and `a2a_call` tool path. - `A2A_CLIENT_SUPPORTED_TRANSPORTS`: ordered outbound transport preference list. -- `A2A_TASK_STORE_BACKEND`: runtime state backend. Supported values: `database`, `memory`. Default: `database`. -- `A2A_TASK_STORE_DATABASE_URL`: database URL used by the default durable backend. Default: `sqlite+aiosqlite:///./opencode-a2a.db`. +- `A2A_TASK_STORE_BACKEND`: unified lightweight persistence backend for SDK task rows plus adapter-managed session / interrupt state. Supported values: `database`, `memory`. Default: `database`. +- `A2A_TASK_STORE_DATABASE_URL`: database URL used by the unified durable backend when `A2A_TASK_STORE_BACKEND=database`. Default: `sqlite+aiosqlite:///./opencode-a2a.db`. - Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`. - Runtime authentication also applies to `/health`; the public unauthenticated discovery surface remains `/.well-known/agent-card.json` and `/.well-known/agent.json`. - The authenticated extended card endpoint `/agent/authenticatedExtendedCard` is bearer-token protected. 
@@ -139,15 +139,22 @@ A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \ opencode-a2a ``` -With the default `database` backend, the service persists: +With the default `database` backend, the unified lightweight persistence layer persists: - task records - session binding / ownership state +- pending preferred-session claims - interrupt request bindings and tombstones -The runtime automatically applies lightweight schema migrations for its custom state tables and records the applied version in `a2a_schema_version`. This built-in path currently targets the local SQLite deployment profile and does not require Alembic. +This project is SQLite-first for local single-instance deployments. The runtime configures local durability-oriented SQLite connection settings (`journal_mode=WAL`, `busy_timeout=30000`, `synchronous=NORMAL`) and creates missing parent directories for file-backed database paths. -The A2A SDK task table remains managed by the SDK's own `DatabaseTaskStore` initialization path. The internal migration runner only owns the additional `opencode-a2a` state tables listed above. +The runtime automatically applies lightweight schema migrations for its custom state tables and records the applied version in `a2a_schema_version`. Schema-version writes are idempotent across concurrent first-start races. Pending preferred-session claims now persist absolute `expires_at` timestamps while remaining backward-compatible with legacy rows that only carry `updated_at`. The built-in migration path currently targets the local SQLite deployment profile and does not require Alembic. + +Database-backed task persistence also keeps the existing first-terminal-state-wins contract while tightening the SQLite path with an atomic terminal-write guard instead of relying only on process-local read-before-write checks. Any wider SQLAlchemy dialect compatibility should be treated as incidental implementation latitude rather than a documented deployment target. 
+ +At startup, the runtime logs a concise persistence summary covering the active backend, the redacted database URL when applicable, the shared persistence scope, and whether the SQLite local durability profile is active. + +The A2A SDK task table remains managed by the SDK's own `DatabaseTaskStore` initialization path. The internal migration runner only owns the additional `opencode-a2a` state tables listed above, but both layers still share the same configured lightweight persistence backend. In-flight asyncio locks, outbound A2A client caches, and stream-local aggregation buffers remain process-local runtime state. diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index e0157fa..c512446 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -109,6 +109,7 @@ TaskStoreOperationError, build_database_engine, build_task_store, + describe_lightweight_persistence_backend, ) logger = logging.getLogger(__name__) @@ -595,6 +596,7 @@ def create_app(settings: Settings) -> FastAPI: ) public_card_etag = build_agent_card_etag(agent_card) extended_card_etag = build_agent_card_etag(extended_agent_card) + persistence_summary = describe_lightweight_persistence_backend(settings) lifespan = build_lifespan( database_engine=database_engine, task_store=task_store, @@ -602,6 +604,7 @@ def create_app(settings: Settings) -> FastAPI: interrupt_request_repository=interrupt_request_repository, client_manager=client_manager, upstream_client=upstream_client, + persistence_summary=persistence_summary, ) app = A2AFastAPI( @@ -615,6 +618,7 @@ def create_app(settings: Settings) -> FastAPI: app.add_api_route(route[0], callback, methods=[route[1]]) app.state._jsonrpc_app = jsonrpc_app app.state.task_store = task_store + app.state.persistence_summary = persistence_summary app.state.agent_executor = executor app.state.upstream_client = upstream_client app.state.a2a_client_manager = client_manager diff --git 
a/src/opencode_a2a/server/lifespan.py b/src/opencode_a2a/server/lifespan.py index 334fcbd..7cb78ca 100644 --- a/src/opencode_a2a/server/lifespan.py +++ b/src/opencode_a2a/server/lifespan.py @@ -1,10 +1,14 @@ from __future__ import annotations +import logging +from collections.abc import Mapping from contextlib import asynccontextmanager from .state_store import initialize_state_repository from .task_store import initialize_task_store +logger = logging.getLogger(__name__) + def build_lifespan( *, @@ -14,9 +18,19 @@ def build_lifespan( interrupt_request_repository, client_manager, upstream_client, + persistence_summary: Mapping[str, object] | None = None, ): @asynccontextmanager async def lifespan(_app): + if persistence_summary is not None: + logger.info( + "Lightweight persistence configured backend=%s scope=%s " + "database_url=%s sqlite_tuning=%s", + persistence_summary.get("backend", "unknown"), + persistence_summary.get("scope", "unknown"), + persistence_summary.get("database_url", "n/a"), + persistence_summary.get("sqlite_tuning", "not_applicable"), + ) await initialize_task_store(task_store) await initialize_state_repository(session_state_repository) await initialize_state_repository(interrupt_request_repository) diff --git a/src/opencode_a2a/server/migrations.py b/src/opencode_a2a/server/migrations.py index fbad028..07c5bf6 100644 --- a/src/opencode_a2a/server/migrations.py +++ b/src/opencode_a2a/server/migrations.py @@ -1,10 +1,11 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import TYPE_CHECKING from sqlalchemy import ( Column, + Index, Integer, MetaData, String, @@ -15,12 +16,13 @@ text, update, ) +from sqlalchemy.exc import IntegrityError if TYPE_CHECKING: from sqlalchemy.engine import Connection STATE_STORE_SCHEMA_NAME = "state_store" -CURRENT_STATE_STORE_SCHEMA_VERSION = 1 +CURRENT_STATE_STORE_SCHEMA_VERSION = 3 _SCHEMA_VERSION_METADATA = MetaData() @@ -65,54 
+67,158 @@ def _migration_1_add_interrupt_details_json( ) -def _read_schema_version(connection: Connection, *, name: str) -> int: +def _migration_2_add_pending_claim_expires_at( + connection: Connection, + *, + pending_session_claims_table: Table, +) -> None: + _add_missing_nullable_column( + connection, + table=pending_session_claims_table, + column_name="expires_at", + ) + + +def _create_missing_index( + connection: Connection, + *, + index: Index, +) -> None: + table = index.table + if table is None: + raise RuntimeError("State-store index is missing table metadata") + existing_indexes = { + existing_index["name"] for existing_index in inspect(connection).get_indexes(table.name) + } + if index.name in existing_indexes: + return + index.create(connection) + + +def _migration_3_add_lightweight_state_indexes( + connection: Connection, + *, + pending_session_claims_table: Table, + interrupt_requests_table: Table, +) -> None: + indexes = sorted( + [ + *pending_session_claims_table.indexes, + *interrupt_requests_table.indexes, + ], + key=lambda index: index.name or "", + ) + for index in indexes: + _create_missing_index(connection, index=index) + + +def _read_schema_version( + connection: Connection, + *, + version_table: Table, + scope: str, +) -> int | None: result = connection.execute( - select(_SCHEMA_VERSIONS.c.version).where(_SCHEMA_VERSIONS.c.name == name) + select(version_table.c.version).where(version_table.c.name == scope) ) version = result.scalar_one_or_none() - return int(version) if version is not None else 0 + return int(version) if version is not None else None -def _write_schema_version(connection: Connection, *, name: str, version: int) -> None: - exists = connection.execute( - select(_SCHEMA_VERSIONS.c.name).where(_SCHEMA_VERSIONS.c.name == name) - ).scalar_one_or_none() - if exists is None: - connection.execute(insert(_SCHEMA_VERSIONS).values(name=name, version=version)) +def _write_schema_version( + connection: Connection, + *, + version_table: 
Table, + scope: str, + version: int, +) -> None: + existing_version = _read_schema_version( + connection, + version_table=version_table, + scope=scope, + ) + if existing_version is not None: + connection.execute( + update(version_table).where(version_table.c.name == scope).values(version=version) + ) return - connection.execute( - update(_SCHEMA_VERSIONS).where(_SCHEMA_VERSIONS.c.name == name).values(version=version) + try: + connection.execute(insert(version_table).values(name=scope, version=version)) + except IntegrityError: + connection.execute( + update(version_table).where(version_table.c.name == scope).values(version=version) + ) + + +def _apply_schema_migrations( + connection: Connection, + *, + version_table: Table, + scope: str, + current_version: int, + migrations: Mapping[int, Callable[[Connection], None]], +) -> int: + if current_version < 0: + raise ValueError("current_version must be non-negative") + + stored_version = _read_schema_version( + connection, + version_table=version_table, + scope=scope, ) + if stored_version is not None and stored_version > current_version: + raise RuntimeError( + f"Database schema scope {scope!r} is newer than this application supports" + ) + + starting_version = stored_version or 0 + for next_version in range(starting_version + 1, current_version + 1): + migration = migrations.get(next_version) + if migration is None: + raise RuntimeError( + f"Missing migration for schema scope {scope!r} version {next_version}" + ) + migration(connection) + _write_schema_version( + connection, + version_table=version_table, + scope=scope, + version=next_version, + ) + + return current_version def migrate_state_store_schema( connection: Connection, *, state_metadata: MetaData, + pending_session_claims_table: Table, interrupt_requests_table: Table, current_version: int = CURRENT_STATE_STORE_SCHEMA_VERSION, ) -> int: _SCHEMA_VERSION_METADATA.create_all(connection) state_metadata.create_all(connection) - stored_version = 
_read_schema_version(connection, name=STATE_STORE_SCHEMA_NAME) - if stored_version > current_version: - raise RuntimeError( - "Database state-store schema version is newer than this application supports" - ) - migrations: dict[int, Callable[[Connection], None]] = { 1: lambda conn: _migration_1_add_interrupt_details_json( conn, interrupt_requests_table=interrupt_requests_table, ), + 2: lambda conn: _migration_2_add_pending_claim_expires_at( + conn, + pending_session_claims_table=pending_session_claims_table, + ), + 3: lambda conn: _migration_3_add_lightweight_state_indexes( + conn, + pending_session_claims_table=pending_session_claims_table, + interrupt_requests_table=interrupt_requests_table, + ), } - - for next_version in range(stored_version + 1, current_version + 1): - migration = migrations.get(next_version) - if migration is None: - raise RuntimeError(f"Missing state-store migration for version {next_version}") - migration(connection) - _write_schema_version(connection, name=STATE_STORE_SCHEMA_NAME, version=next_version) - - return current_version + return _apply_schema_migrations( + connection, + version_table=_SCHEMA_VERSIONS, + scope=STATE_STORE_SCHEMA_NAME, + current_version=current_version, + migrations=migrations, + ) diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index 1208015..a26168f 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -3,21 +3,24 @@ import json import time from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, cast from sqlalchemy import ( Column, Float, + Index, MetaData, String, Table, and_, delete, insert, + or_, select, update, ) +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from ..config import Settings @@ -50,7 +53,8 @@ _STATE_METADATA, Column("session_id", String, 
primary_key=True), Column("identity", String, nullable=False), - Column("updated_at", Float, nullable=False), + Column("updated_at", Float, nullable=True), + Column("expires_at", Float, nullable=True), ) _INTERRUPT_REQUESTS = Table( @@ -67,18 +71,70 @@ Column("tombstone_expires_at", Float, nullable=True), ) +Index( + "ix_a2a_pending_session_claims_expires_at", + _PENDING_SESSION_CLAIMS.c.expires_at, +) +Index( + "ix_a2a_interrupt_requests_identity_expires_at", + _INTERRUPT_REQUESTS.c.identity, + _INTERRUPT_REQUESTS.c.expires_at, +) +Index( + "ix_a2a_interrupt_requests_identity_type_expires_at", + _INTERRUPT_REQUESTS.c.identity, + _INTERRUPT_REQUESTS.c.interrupt_type, + _INTERRUPT_REQUESTS.c.expires_at, +) +Index( + "ix_a2a_interrupt_requests_tombstone_expires_at", + _INTERRUPT_REQUESTS.c.tombstone_expires_at, +) + _MEMORY_SESSION_BINDING_TTL_SECONDS = 3600 _MEMORY_SESSION_BINDING_MAXSIZE = 10_000 +async def _insert_then_update_on_conflict( + session: AsyncSession, + *, + table: Table, + key_values: Mapping[str, Any], + update_values: Mapping[str, Any], +) -> None: + values = {**key_values, **update_values} + try: + await session.execute(insert(table).values(**values)) + except IntegrityError: + stmt = update(table) + for key, value in key_values.items(): + stmt = stmt.where(table.c[key] == value) + await session.execute(stmt.values(**update_values)) + + def _initialize_state_store_schema(connection) -> None: # noqa: ANN001 migrate_state_store_schema( connection, state_metadata=_STATE_METADATA, + pending_session_claims_table=_PENDING_SESSION_CLAIMS, interrupt_requests_table=_INTERRUPT_REQUESTS, ) +def _pending_claim_expires_at( + row: Mapping[str, Any], + *, + legacy_ttl_seconds: float, +) -> float | None: + expires_at = row.get("expires_at") + if expires_at is not None: + return float(expires_at) + updated_at = row.get("updated_at") + if updated_at is None: + return None + return float(updated_at) + max(0.0, legacy_ttl_seconds) + + class 
SessionStateRepository(ABC): @abstractmethod async def get_session(self, *, identity: str, context_id: str) -> str | None: ... @@ -244,10 +300,24 @@ async def _prune_expired_pending_claims( if self._pending_claim_ttl_seconds <= 0: await session.execute(delete(_PENDING_SESSION_CLAIMS)) return - expires_before = now - self._pending_claim_ttl_seconds await session.execute( delete(_PENDING_SESSION_CLAIMS).where( - _PENDING_SESSION_CLAIMS.c.updated_at <= expires_before + or_( + and_( + _PENDING_SESSION_CLAIMS.c.expires_at.is_not(None), + _PENDING_SESSION_CLAIMS.c.expires_at <= now, + ), + and_( + _PENDING_SESSION_CLAIMS.c.expires_at.is_(None), + _PENDING_SESSION_CLAIMS.c.updated_at.is_not(None), + _PENDING_SESSION_CLAIMS.c.updated_at + <= now - self._pending_claim_ttl_seconds, + ), + and_( + _PENDING_SESSION_CLAIMS.c.expires_at.is_(None), + _PENDING_SESSION_CLAIMS.c.updated_at.is_(None), + ), + ) ) ) @@ -267,33 +337,15 @@ async def get_session(self, *, identity: str, context_id: str) -> str | None: async def set_session(self, *, identity: str, context_id: str, session_id: str) -> None: await self._ensure_initialized() async with self._session_maker.begin() as session: - exists = await session.execute( - select(_SESSION_BINDINGS.c.session_id).where( - and_( - _SESSION_BINDINGS.c.identity == identity, - _SESSION_BINDINGS.c.context_id == context_id, - ) - ) + await _insert_then_update_on_conflict( + session, + table=_SESSION_BINDINGS, + key_values={ + "identity": identity, + "context_id": context_id, + }, + update_values={"session_id": session_id}, ) - if exists.scalar_one_or_none() is None: - await session.execute( - insert(_SESSION_BINDINGS).values( - identity=identity, - context_id=context_id, - session_id=session_id, - ) - ) - else: - await session.execute( - update(_SESSION_BINDINGS) - .where( - and_( - _SESSION_BINDINGS.c.identity == identity, - _SESSION_BINDINGS.c.context_id == context_id, - ) - ) - .values(session_id=session_id) - ) async def pop_session(self, *, 
identity: str, context_id: str) -> None: await self._ensure_initialized() @@ -318,24 +370,12 @@ async def get_owner(self, *, session_id: str) -> str | None: async def set_owner(self, *, session_id: str, identity: str) -> None: await self._ensure_initialized() async with self._session_maker.begin() as session: - exists = await session.execute( - select(_SESSION_OWNERS.c.session_id).where( - _SESSION_OWNERS.c.session_id == session_id - ) + await _insert_then_update_on_conflict( + session, + table=_SESSION_OWNERS, + key_values={"session_id": session_id}, + update_values={"identity": identity}, ) - if exists.scalar_one_or_none() is None: - await session.execute( - insert(_SESSION_OWNERS).values( - session_id=session_id, - identity=identity, - ) - ) - else: - await session.execute( - update(_SESSION_OWNERS) - .where(_SESSION_OWNERS.c.session_id == session_id) - .values(identity=identity) - ) async def get_pending_claim(self, *, session_id: str) -> str | None: await self._ensure_initialized() @@ -343,11 +383,25 @@ async def get_pending_claim(self, *, session_id: str) -> str | None: async with self._session_maker.begin() as session: await self._prune_expired_pending_claims(session, now=now) result = await session.execute( - select(_PENDING_SESSION_CLAIMS.c.identity).where( + select(_PENDING_SESSION_CLAIMS).where( _PENDING_SESSION_CLAIMS.c.session_id == session_id ) ) - return cast("str | None", result.scalar_one_or_none()) + row = cast("Mapping[str, Any] | None", result.mappings().one_or_none()) + if row is None: + return None + expires_at = _pending_claim_expires_at( + row, + legacy_ttl_seconds=self._pending_claim_ttl_seconds, + ) + if expires_at is None or expires_at <= now: + await session.execute( + delete(_PENDING_SESSION_CLAIMS).where( + _PENDING_SESSION_CLAIMS.c.session_id == session_id + ) + ) + return None + return cast("str", row["identity"]) async def set_pending_claim(self, *, session_id: str, identity: str) -> None: await self._ensure_initialized() @@ -361,22 
+415,16 @@ async def set_pending_claim(self, *, session_id: str, identity: str) -> None: ) ) return - exists = await session.execute( - select(_PENDING_SESSION_CLAIMS.c.session_id).where( - _PENDING_SESSION_CLAIMS.c.session_id == session_id - ) + await _insert_then_update_on_conflict( + session, + table=_PENDING_SESSION_CLAIMS, + key_values={"session_id": session_id}, + update_values={ + "identity": identity, + "updated_at": now, + "expires_at": now + self._pending_claim_ttl_seconds, + }, ) - values = {"identity": identity, "updated_at": now} - if exists.scalar_one_or_none() is None: - await session.execute( - insert(_PENDING_SESSION_CLAIMS).values(session_id=session_id, **values) - ) - else: - await session.execute( - update(_PENDING_SESSION_CLAIMS) - .where(_PENDING_SESSION_CLAIMS.c.session_id == session_id) - .values(**values) - ) async def clear_pending_claim( self, @@ -603,31 +651,21 @@ async def remember( expires_at = now + max(0.0, float(ttl)) async with self._session_maker.begin() as session: await self._prune_tombstones(session, now=now) - exists = await session.execute( - select(_INTERRUPT_REQUESTS.c.request_id).where( - _INTERRUPT_REQUESTS.c.request_id == request_id - ) + await _insert_then_update_on_conflict( + session, + table=_INTERRUPT_REQUESTS, + key_values={"request_id": request_id}, + update_values={ + "session_id": session_id, + "interrupt_type": interrupt_type, + "identity": identity, + "task_id": task_id, + "context_id": context_id, + "details_json": self._encode_details(details), + "expires_at": expires_at, + "tombstone_expires_at": None, + }, ) - values = { - "session_id": session_id, - "interrupt_type": interrupt_type, - "identity": identity, - "task_id": task_id, - "context_id": context_id, - "details_json": self._encode_details(details), - "expires_at": expires_at, - "tombstone_expires_at": None, - } - if exists.scalar_one_or_none() is None: - await session.execute( - insert(_INTERRUPT_REQUESTS).values(request_id=request_id, **values) - ) 
- else: - await session.execute( - update(_INTERRUPT_REQUESTS) - .where(_INTERRUPT_REQUESTS.c.request_id == request_id) - .values(**values) - ) async def resolve( self, diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py index 1713495..8d00f56 100644 --- a/src/opencode_a2a/server/task_store.py +++ b/src/opencode_a2a/server/task_store.py @@ -1,13 +1,20 @@ from __future__ import annotations +import asyncio import logging from abc import ABC, abstractmethod from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any, cast +from a2a.server.tasks.database_task_store import DatabaseTaskStore from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import Task, TaskState +from sqlalchemy import event, or_, select +from sqlalchemy.dialects.postgresql import insert as postgresql_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from sqlalchemy.engine import make_url from ..config import Settings @@ -25,6 +32,11 @@ TaskState.rejected, } ) +_TERMINAL_TASK_STATE_VALUES = tuple(state.value for state in _TERMINAL_TASK_STATES) +_ATOMIC_TERMINAL_GUARD_DIALECTS = frozenset({"postgresql", "sqlite"}) +_SQLITE_JOURNAL_MODE = "WAL" +_SQLITE_BUSY_TIMEOUT_MS = 30_000 +_SQLITE_SYNCHRONOUS_MODE = "NORMAL" class TaskStoreOperationError(RuntimeError): @@ -110,6 +122,8 @@ async def save( ) -> None: try: await self._inner.save(task, context) + except TaskStoreOperationError: + raise except Exception as exc: raise TaskStoreOperationError("save", task.id) from exc @@ -120,6 +134,8 @@ async def get( ) -> Task | None: try: return await self._inner.get(task_id, context) + except TaskStoreOperationError: + raise except Exception as exc: raise TaskStoreOperationError("get", task_id) from exc @@ -130,6 +146,8 @@ async def delete( ) -> None: try: await self._inner.delete(task_id, context) + except TaskStoreOperationError: + raise 
except Exception as exc: raise TaskStoreOperationError("delete", task_id) from exc @@ -143,27 +161,128 @@ def __init__( ) -> None: super().__init__(inner) self._write_policy = write_policy or FirstTerminalStateWinsPolicy() + self._save_lock = asyncio.Lock() + self._atomic_guard_fallback_logged = False async def save( self, task: Task, context: ServerCallContext | None = None, ) -> None: - existing = await self._inner.get(task.id, context) - decision = self._write_policy.evaluate(existing=existing, incoming=task) - if existing is not None and existing.status.state in _TERMINAL_TASK_STATES: - logger.warning( - "Received task persistence after terminal state task_id=%s existing_state=%s " - "incoming_state=%s persist=%s reason=%s", - task.id, - existing.status.state, - task.status.state, - decision.persist, - decision.reason or "accepted_duplicate", + raw_task_store = unwrap_task_store(self._inner) + if isinstance(raw_task_store, DatabaseTaskStore): + await self._save_database_task(raw_task_store, task, context) + return + await self._save_with_read_before_write(task, context) + + async def _save_with_read_before_write( + self, + task: Task, + context: ServerCallContext | None = None, + ) -> None: + async with self._save_lock: + existing = await self._inner.get(task.id, context) + decision = self._write_policy.evaluate(existing=existing, incoming=task) + self._log_terminal_persistence_decision( + existing=existing, + incoming=task, + decision=decision, ) - if not decision.persist: + if not decision.persist: + return + await self._inner.save(task, context) + + async def _save_database_task( + self, + task_store: DatabaseTaskStore, + task: Task, + context: ServerCallContext | None = None, + ) -> None: + dialect_name = task_store.engine.dialect.name + if dialect_name not in _ATOMIC_TERMINAL_GUARD_DIALECTS: + if not self._atomic_guard_fallback_logged: + logger.warning( + "Database-backed task store dialect does not support atomic terminal guard; " + "falling back to 
read-before-write policy dialect=%s", + dialect_name, + ) + self._atomic_guard_fallback_logged = True + await self._save_with_read_before_write(task, context) return - await self._inner.save(task, context) + + try: + if await self._persist_with_atomic_terminal_guard(task_store, task): + return + existing = await self._load_task_from_database(task_store, task.id) + decision = self._write_policy.evaluate(existing=existing, incoming=task) + self._log_terminal_persistence_decision( + existing=existing, + incoming=task, + decision=decision, + ) + if not decision.persist: + return + if ( + existing is not None + and existing.status.state in _TERMINAL_TASK_STATES + and existing.model_dump(mode="json") == task.model_dump(mode="json") + ): + return + raise RuntimeError( + "Atomic task persistence was skipped without an authoritative terminal task." + ) + except TaskStoreOperationError: + raise + except Exception as exc: + raise TaskStoreOperationError("save", task.id) from exc + + async def _persist_with_atomic_terminal_guard( + self, + task_store: DatabaseTaskStore, + task: Task, + ) -> bool: + await task_store._ensure_initialized() + statement = _build_atomic_task_save_statement( + task=task, + task_table=task_store.task_model.__table__, + dialect_name=task_store.engine.dialect.name, + ) + async with task_store.async_session_maker.begin() as session: + result = await session.execute(statement) + return result.scalar_one_or_none() is not None + + async def _load_task_from_database( + self, + task_store: DatabaseTaskStore, + task_id: str, + ) -> Task | None: + await task_store._ensure_initialized() + async with task_store.async_session_maker() as session: + stmt = select(task_store.task_model).where(task_store.task_model.id == task_id) + result = await session.execute(stmt) + task_model = result.scalar_one_or_none() + if task_model is None: + return None + return task_store._from_orm(task_model) + + def _log_terminal_persistence_decision( + self, + *, + existing: Task | 
None, + incoming: Task, + decision: TaskPersistenceDecision, + ) -> None: + if existing is None or existing.status.state not in _TERMINAL_TASK_STATES: + return + logger.warning( + "Received task persistence after terminal state task_id=%s existing_state=%s " + "incoming_state=%s persist=%s reason=%s", + incoming.id, + existing.status.state, + incoming.status.state, + decision.persist, + decision.reason or "accepted_duplicate", + ) class GuardedTaskStore(PolicyAwareTaskStore): @@ -184,8 +303,6 @@ def build_task_store( *, engine: AsyncEngine | None = None, ) -> TaskStore: - from a2a.server.tasks.database_task_store import DatabaseTaskStore - if settings.a2a_task_store_backend == "memory": return GuardedTaskStore(InMemoryTaskStore()) @@ -197,11 +314,104 @@ def build_task_store( ) +def describe_lightweight_persistence_backend(settings: Settings) -> dict[str, str]: + summary = { + "backend": settings.a2a_task_store_backend, + "scope": "sdk_tasks_and_adapter_state", + } + if settings.a2a_task_store_backend != "database": + return summary + url = make_url(cast(str, settings.a2a_task_store_database_url)) + summary["database_url"] = url.render_as_string(hide_password=True) + summary["sqlite_tuning"] = ( + "local_durability_defaults" if url.drivername.startswith("sqlite") else "not_applicable" + ) + return summary + + def build_database_engine(settings: Settings) -> AsyncEngine: from sqlalchemy.ext.asyncio import create_async_engine database_url = cast(str, settings.a2a_task_store_database_url) - return create_async_engine(database_url) + url = make_url(database_url) + if url.drivername.startswith("sqlite"): + database_path = url.database + if database_path and database_path != ":memory:" and not database_path.startswith("file:"): + path = Path(database_path) + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + path.parent.mkdir(parents=True, exist_ok=True) + + engine = create_async_engine( + database_url, + pool_pre_ping=not 
url.drivername.startswith("sqlite"), + ) + if url.drivername.startswith("sqlite"): + event.listen(engine.sync_engine, "connect", _configure_sqlite_connection) + return engine + + +def unwrap_task_store(task_store: TaskStore) -> TaskStore: + inner = getattr(task_store, "_inner", None) + if isinstance(inner, TaskStore): + return unwrap_task_store(inner) + return task_store + + +def _configure_sqlite_connection(dbapi_connection: Any, _connection_record: Any) -> None: + cursor = dbapi_connection.cursor() + try: + cursor.execute(f"PRAGMA journal_mode={_SQLITE_JOURNAL_MODE}") + cursor.execute(f"PRAGMA busy_timeout={_SQLITE_BUSY_TIMEOUT_MS}") + cursor.execute(f"PRAGMA synchronous={_SQLITE_SYNCHRONOUS_MODE}") + finally: + cursor.close() + + +def _build_atomic_task_save_statement( + *, + task: Task, + task_table: Any, + dialect_name: str, +): + insert = _resolve_atomic_insert_factory(dialect_name) + values = _task_row_values(task) + status_state = task_table.c.status["state"].as_string() + persist_guard = or_( + task_table.c.status.is_(None), + status_state.is_(None), + status_state.not_in(_TERMINAL_TASK_STATE_VALUES), + ) + return ( + insert(task_table) + .values(**values) + .on_conflict_do_update( + index_elements=[task_table.c.id], + set_={key: value for key, value in values.items() if key != "id"}, + where=persist_guard, + ) + .returning(task_table.c.id) + ) + + +def _resolve_atomic_insert_factory(dialect_name: str): + if dialect_name == "sqlite": + return sqlite_insert + if dialect_name == "postgresql": + return postgresql_insert + raise ValueError(f"Unsupported atomic task persistence dialect: {dialect_name}") + + +def _task_row_values(task: Task) -> dict[str, Any]: + return { + "id": task.id, + "context_id": task.context_id, + "kind": task.kind, + "status": task.status, + "artifacts": task.artifacts, + "history": task.history, + "metadata": task.metadata, + } async def initialize_task_store(task_store: TaskStore) -> None: diff --git 
a/tests/server/test_app_behaviors.py b/tests/server/test_app_behaviors.py index 2f2e0ca..20a1b99 100644 --- a/tests/server/test_app_behaviors.py +++ b/tests/server/test_app_behaviors.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import types from unittest.mock import AsyncMock, MagicMock @@ -238,7 +239,7 @@ def test_agent_card_helper_builders_cover_optional_branches() -> None: @pytest.mark.asyncio -async def test_auth_health_lifespan_and_openapi_cache(monkeypatch) -> None: +async def test_auth_health_lifespan_and_openapi_cache(monkeypatch, caplog) -> None: class _ClosableClient(DummyChatOpencodeUpstreamClient): def __init__(self, settings=None) -> None: super().__init__(settings) @@ -327,9 +328,16 @@ async def close(self) -> None: }, } - async with app.router.lifespan_context(app): - pass + with caplog.at_level(logging.INFO, logger="opencode_a2a.server.lifespan"): + async with app.router.lifespan_context(app): + pass assert closable.closed is True + assert any( + "Lightweight persistence configured" in record.message + and "backend=database" in record.message + and "scope=sdk_tasks_and_adapter_state" in record.message + for record in caplog.records + ) openapi_first = app.openapi() openapi_second = app.openapi() diff --git a/tests/server/test_state_store.py b/tests/server/test_state_store.py index a98223a..cfe4b16 100644 --- a/tests/server/test_state_store.py +++ b/tests/server/test_state_store.py @@ -3,10 +3,12 @@ from pathlib import Path import pytest -from sqlalchemy import text +from sqlalchemy import inspect, text from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect +from sqlalchemy.exc import IntegrityError import opencode_a2a.server.migrations as migrations_module +import opencode_a2a.server.state_store as state_store_module from opencode_a2a.server.migrations import CURRENT_STATE_STORE_SCHEMA_VERSION from opencode_a2a.server.state_store import ( _INTERRUPT_REQUESTS, @@ -37,6 +39,31 @@ async def 
_read_state_store_schema_row_count(engine) -> int: # noqa: ANN001 return int(result.scalar_one()) +async def _read_table_index_names(engine, table_name: str) -> set[str]: # noqa: ANN001 + async with engine.begin() as conn: + return await conn.run_sync( + lambda sync_conn: { + index["name"] for index in inspect(sync_conn).get_indexes(table_name) + } + ) + + +async def _read_pending_claim_row(engine, session_id: str) -> dict[str, object] | None: # noqa: ANN001 + async with engine.begin() as conn: + result = await conn.execute( + text( + """ + SELECT session_id, identity, updated_at, expires_at + FROM a2a_pending_session_claims + WHERE session_id = :session_id + """ + ), + {"session_id": session_id}, + ) + row = result.mappings().one_or_none() + return None if row is None else dict(row) + + def test_add_missing_nullable_column_supports_non_sqlite_dialects(monkeypatch) -> None: executed: list[str] = [] @@ -62,6 +89,69 @@ def execute(self, clause) -> None: # noqa: ANN001 assert executed == ["ALTER TABLE a2a_interrupt_requests ADD COLUMN details_json VARCHAR"] +def test_write_schema_version_recovers_from_concurrent_first_insert_race() -> None: + executed: list[str] = [] + + class _FakeResult: + def __init__(self, value: int | None) -> None: + self._value = value + + def scalar_one_or_none(self) -> int | None: + return self._value + + class _FakeConnection: + def execute(self, clause): # noqa: ANN001 + executed.append(clause.__visit_name__) + if clause.__visit_name__ == "select": + return _FakeResult(None) + if clause.__visit_name__ == "insert": + raise IntegrityError("insert", {}, Exception("duplicate key")) + if clause.__visit_name__ == "update": + return None + raise AssertionError(f"Unexpected clause type: {clause.__visit_name__}") + + migrations_module._write_schema_version( + _FakeConnection(), + version_table=migrations_module._SCHEMA_VERSIONS, + scope=migrations_module.STATE_STORE_SCHEMA_NAME, + version=1, + ) + + assert executed == ["select", "insert", "update"] 
+ + +@pytest.mark.asyncio +async def test_state_store_write_helper_recovers_from_concurrent_first_insert_race() -> None: + executed: list[str] = [] + + class _FakeSession: + async def execute(self, clause): # noqa: ANN001 + executed.append(clause.__visit_name__) + if clause.__visit_name__ == "insert": + raise IntegrityError("insert", {}, Exception("duplicate key")) + if clause.__visit_name__ == "update": + return None + raise AssertionError(f"Unexpected clause type: {clause.__visit_name__}") + + await state_store_module._insert_then_update_on_conflict( + _FakeSession(), + table=_INTERRUPT_REQUESTS, + key_values={"request_id": "perm-1"}, + update_values={ + "session_id": "ses-1", + "interrupt_type": "permission", + "identity": "user-1", + "task_id": "task-1", + "context_id": "ctx-1", + "details_json": None, + "expires_at": 1.0, + "tombstone_expires_at": None, + }, + ) + + assert executed == ["insert", "update"] + + @pytest.mark.asyncio async def test_database_session_state_repository_persists_bindings(tmp_path: Path) -> None: database_url = f"sqlite+aiosqlite:///{tmp_path / 'state.db'}" @@ -140,6 +230,49 @@ def _now() -> float: await engine.dispose() +@pytest.mark.asyncio +async def test_database_pending_session_claim_keeps_absolute_expiry_when_runtime_ttl_changes( + tmp_path: Path, +) -> None: + now = 100.0 + + def _now() -> float: + return now + + database_url = f"sqlite+aiosqlite:///{tmp_path / 'pending-claim-expires-at.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + writer = DatabaseSessionStateRepository( + engine=engine, + pending_claim_ttl_seconds=5.0, + clock=_now, + ) + await initialize_state_repository(writer) + await writer.set_pending_claim(session_id="ses-1", identity="user-1") + + stored_row = await _read_pending_claim_row(engine, "ses-1") + assert stored_row is not None + assert stored_row["expires_at"] == pytest.approx(105.0) + + reader = 
DatabaseSessionStateRepository( + engine=engine, + pending_claim_ttl_seconds=1.0, + clock=_now, + ) + await initialize_state_repository(reader) + + now = 104.0 + assert await reader.get_pending_claim(session_id="ses-1") == "user-1" + + now = 106.0 + assert await reader.get_pending_claim(session_id="ses-1") is None + + await engine.dispose() + + @pytest.mark.asyncio async def test_database_session_binding_and_owner_do_not_expire_with_time(tmp_path: Path) -> None: now = 100.0 @@ -411,6 +544,14 @@ async def test_database_state_store_records_schema_version_for_existing_current_ await initialize_state_repository(session_repository) assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION + assert await _read_table_index_names(engine, "a2a_pending_session_claims") == { + "ix_a2a_pending_session_claims_expires_at" + } + assert await _read_table_index_names(engine, "a2a_interrupt_requests") == { + "ix_a2a_interrupt_requests_identity_expires_at", + "ix_a2a_interrupt_requests_identity_type_expires_at", + "ix_a2a_interrupt_requests_tombstone_expires_at", + } await engine.dispose() diff --git a/tests/server/test_task_store_factory.py b/tests/server/test_task_store_factory.py index f34770c..d71c67b 100644 --- a/tests/server/test_task_store_factory.py +++ b/tests/server/test_task_store_factory.py @@ -3,8 +3,10 @@ import logging import warnings from pathlib import Path +from unittest.mock import AsyncMock import pytest +from a2a.server.tasks.database_task_store import DatabaseTaskStore from a2a.types import Task, TaskState, TaskStatus from opencode_a2a.server.task_store import ( @@ -15,8 +17,11 @@ TaskStoreOperationError, TaskStoreOperationWrappingDecorator, TaskWritePolicy, + build_database_engine, build_task_store, + describe_lightweight_persistence_backend, initialize_task_store, + unwrap_task_store, ) from tests.support.helpers import make_settings @@ -52,6 +57,32 @@ def test_build_task_store_allows_explicit_memory_backend() -> None: assert 
isinstance(store._inner._inner, InMemoryTaskStore) +def test_describe_lightweight_persistence_backend_marks_sqlite_first_scope() -> None: + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url="sqlite+aiosqlite:///./opencode-a2a.db", + ) + + assert describe_lightweight_persistence_backend(settings) == { + "backend": "database", + "scope": "sdk_tasks_and_adapter_state", + "database_url": "sqlite+aiosqlite:///./opencode-a2a.db", + "sqlite_tuning": "local_durability_defaults", + } + + +def test_describe_lightweight_persistence_backend_supports_memory_backend() -> None: + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="memory", + ) + + assert describe_lightweight_persistence_backend(settings) == { + "backend": "memory", + "scope": "sdk_tasks_and_adapter_state", + } + + @pytest.mark.asyncio async def test_database_task_store_persists_tasks_across_rebuilds(tmp_path: Path) -> None: database_path = tmp_path / "tasks.db" @@ -103,6 +134,50 @@ async def test_database_task_store_can_build_multiple_instances_without_warnings await second.engine.dispose() +@pytest.mark.asyncio +async def test_build_database_engine_configures_sqlite_pragmas_and_parent_dir( + tmp_path: Path, +) -> None: + database_path = tmp_path / "nested" / "runtime.db" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=f"sqlite+aiosqlite:///{database_path}", + ) + engine = build_database_engine(settings) + + try: + async with engine.connect() as conn: + journal_mode = (await conn.exec_driver_sql("PRAGMA journal_mode")).scalar_one() + busy_timeout = (await conn.exec_driver_sql("PRAGMA busy_timeout")).scalar_one() + synchronous = (await conn.exec_driver_sql("PRAGMA synchronous")).scalar_one() + finally: + await engine.dispose() + + assert database_path.parent.exists() + assert str(journal_mode).lower() == "wal" + assert int(busy_timeout) == 30_000 + assert int(synchronous) == 1 + + 
+@pytest.mark.asyncio +async def test_build_task_store_does_not_dispose_shared_engine( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=f"sqlite+aiosqlite:///{tmp_path / 'shared-engine.db'}", + ) + engine = build_database_engine(settings) + dispose_spy = AsyncMock() + monkeypatch.setattr(type(engine), "dispose", dispose_spy) + + store = build_task_store(settings, engine=engine) + await initialize_task_store(store) + + dispose_spy.assert_not_awaited() + + @pytest.mark.asyncio @pytest.mark.parametrize("backend", ["memory", "database"]) async def test_task_store_preserves_first_terminal_state( @@ -134,6 +209,40 @@ async def test_task_store_preserves_first_terminal_state( await engine.dispose() +@pytest.mark.asyncio +async def test_database_task_store_keeps_first_terminal_state_across_independent_instances( + tmp_path: Path, +) -> None: + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=f"sqlite+aiosqlite:///{tmp_path / 'terminal-guard.db'}", + ) + first = build_task_store(settings) + second = build_task_store(settings) + await initialize_task_store(first) + await initialize_task_store(second) + + try: + working = _task("task-1") + await first.save(working) + + completed = _task("task-1") + completed.status = TaskStatus(state=TaskState.completed) + await first.save(completed) + + late_failed = _task("task-1") + late_failed.status = TaskStatus(state=TaskState.failed) + await second.save(late_failed) + + restored = await first.get("task-1") + finally: + await first.engine.dispose() + await second.engine.dispose() + + assert restored is not None + assert restored.status.state == TaskState.completed + + @pytest.mark.asyncio @pytest.mark.parametrize("backend", ["memory", "database"]) async def test_task_store_rejects_late_mutation_after_terminal_state( @@ -167,6 +276,56 @@ async def 
test_task_store_rejects_late_mutation_after_terminal_state( await engine.dispose() +@pytest.mark.asyncio +async def test_database_task_store_atomic_guard_does_not_depend_on_stale_read( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=f"sqlite+aiosqlite:///{tmp_path / 'stale-read.db'}", + ) + first = build_task_store(settings) + second = build_task_store(settings) + await initialize_task_store(first) + await initialize_task_store(second) + + try: + working = _task("task-1") + await first.save(working) + + completed = _task("task-1") + completed.status = TaskStatus(state=TaskState.completed) + await first.save(completed) + + late_completed = _task("task-1") + late_completed.status = TaskStatus(state=TaskState.completed) + late_completed.metadata = {"opencode": {"late_mutation": True}} + + raw_second = unwrap_task_store(second) + assert isinstance(raw_second, DatabaseTaskStore) + original_get = DatabaseTaskStore.get.__get__(raw_second, DatabaseTaskStore) + + async def _stale_get(task_id: str, context=None) -> Task | None: # noqa: ANN001 + del context + if task_id == "task-1": + return working + return None + + monkeypatch.setattr(raw_second, "get", _stale_get) + await second.save(late_completed) + monkeypatch.setattr(raw_second, "get", original_get) + + restored = await first.get("task-1") + finally: + await first.engine.dispose() + await second.engine.dispose() + + assert restored is not None + assert restored.status.state == TaskState.completed + assert restored.metadata is None + + @pytest.mark.asyncio async def test_task_store_wraps_backend_failures() -> None: class _BrokenGetStore: