diff --git a/docs/guide.md b/docs/guide.md index 85268a0..0a2361e 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -145,6 +145,10 @@ With the default `database` backend, the service persists: - session binding / ownership state - interrupt request bindings and tombstones +The runtime automatically applies lightweight schema migrations for its custom state tables and records the applied version in `a2a_schema_version`. This built-in path currently targets the local SQLite deployment profile and does not require Alembic. + +The A2A SDK task table remains managed by the SDK's own `DatabaseTaskStore` initialization path. The internal migration runner only owns the additional `opencode-a2a` state tables listed above. + In-flight asyncio locks, outbound A2A client caches, and stream-local aggregation buffers remain process-local runtime state. To opt into an ephemeral development profile, set: diff --git a/src/opencode_a2a/py.typed b/src/opencode_a2a/py.typed index 8b13789..e69de29 100644 --- a/src/opencode_a2a/py.typed +++ b/src/opencode_a2a/py.typed @@ -1 +0,0 @@ - diff --git a/src/opencode_a2a/server/migrations.py b/src/opencode_a2a/server/migrations.py new file mode 100644 index 0000000..fbad028 --- /dev/null +++ b/src/opencode_a2a/server/migrations.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING + +from sqlalchemy import ( + Column, + Integer, + MetaData, + String, + Table, + insert, + inspect, + select, + text, + update, +) + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection + +STATE_STORE_SCHEMA_NAME = "state_store" +CURRENT_STATE_STORE_SCHEMA_VERSION = 1 + +_SCHEMA_VERSION_METADATA = MetaData() + +_SCHEMA_VERSIONS = Table( + "a2a_schema_version", + _SCHEMA_VERSION_METADATA, + Column("name", String, primary_key=True), + Column("version", Integer, nullable=False), +) + + +def _add_missing_nullable_column( + connection: Connection, + *, + table: Table, + column_name: str, +) -> None: + existing_columns = {column["name"] for column in inspect(connection).get_columns(table.name)} + if column_name in existing_columns: + return + column = table.c[column_name] + if column.primary_key or not column.nullable: + raise RuntimeError(f"Unsupported state-store migration for {table.name}.{column_name}") + preparer = connection.dialect.identifier_preparer + table_name_sql = preparer.quote(table.name) + column_name_sql = preparer.quote(column_name) + column_type_sql = column.type.compile(dialect=connection.dialect) + connection.execute( + text(f"ALTER TABLE {table_name_sql} ADD COLUMN {column_name_sql} {column_type_sql}") + ) + + +def _migration_1_add_interrupt_details_json( + connection: Connection, + *, + interrupt_requests_table: Table, +) -> None: + _add_missing_nullable_column( + connection, + table=interrupt_requests_table, + column_name="details_json", + ) + + +def _read_schema_version(connection: Connection, *, name: str) -> int: + result = connection.execute( + select(_SCHEMA_VERSIONS.c.version).where(_SCHEMA_VERSIONS.c.name == name) + ) + version = result.scalar_one_or_none() + return int(version) if version is not None else 0 + + +def _write_schema_version(connection: Connection, *, name: str, version: int) -> None: + exists = connection.execute( + select(_SCHEMA_VERSIONS.c.name).where(_SCHEMA_VERSIONS.c.name == name) + ).scalar_one_or_none() + if exists is None: + connection.execute(insert(_SCHEMA_VERSIONS).values(name=name, version=version)) + return + connection.execute( + update(_SCHEMA_VERSIONS).where(_SCHEMA_VERSIONS.c.name == name).values(version=version) + ) + + +def migrate_state_store_schema( + connection: Connection, + *, + state_metadata: MetaData, + interrupt_requests_table: Table, + current_version: int = CURRENT_STATE_STORE_SCHEMA_VERSION, +) -> int: + _SCHEMA_VERSION_METADATA.create_all(connection) + state_metadata.create_all(connection) + + stored_version = _read_schema_version(connection, name=STATE_STORE_SCHEMA_NAME) + if stored_version > current_version: + raise RuntimeError( + "Database state-store schema version is newer than this application supports" + ) + + migrations: dict[int, Callable[[Connection], None]] = { + 1: lambda conn: _migration_1_add_interrupt_details_json( + conn, + interrupt_requests_table=interrupt_requests_table, + ), + } + + for next_version in range(stored_version + 1, current_version + 1): + migration = migrations.get(next_version) + if migration is None: + raise RuntimeError(f"Missing state-store migration for version {next_version}") + migration(connection) + _write_schema_version(connection, name=STATE_STORE_SCHEMA_NAME, version=next_version) + + return current_version diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index 9f9f01a..1208015 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -15,9 +15,7 @@ and_, delete, insert, - inspect, select, - text, update, ) from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -25,9 +23,9 @@ from ..config import Settings from ..execution.stream_state import _TTLCache from ..runtime_state import InterruptRequestBinding, InterruptRequestTombstone +from .migrations import migrate_state_store_schema if TYPE_CHECKING: - from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncEngine _STATE_METADATA = MetaData() @@ -73,31 +71,11 @@ _MEMORY_SESSION_BINDING_MAXSIZE = 10_000 -def _add_missing_sqlite_column( - connection: Connection, - *, - table: Table, - column_name: str, -) -> None: - if connection.dialect.name != "sqlite": - return - existing_columns = {column["name"] for column in inspect(connection).get_columns(table.name)} - if column_name in existing_columns: - return - column = table.c[column_name] - if column.primary_key or not column.nullable: - raise RuntimeError( - f"Unsupported SQLite state-store migration for {table.name}.{column_name}" - ) - column_type = column.type.compile(dialect=connection.dialect) - connection.execute(text(f'ALTER TABLE "{table.name}" ADD COLUMN "{column_name}" {column_type}')) - - -def _ensure_state_store_schema(connection: Connection) -> None: - _add_missing_sqlite_column( +def _initialize_state_store_schema(connection) -> None: # noqa: ANN001 + migrate_state_store_schema( connection, - table=_INTERRUPT_REQUESTS, - column_name="details_json", + state_metadata=_STATE_METADATA, + interrupt_requests_table=_INTERRUPT_REQUESTS, ) @@ -250,8 +228,7 @@ async def initialize(self) -> None: if self._initialized: return async with self.engine.begin() as conn: - await conn.run_sync(_STATE_METADATA.create_all) - await conn.run_sync(_ensure_state_store_schema) + await conn.run_sync(_initialize_state_store_schema) self._initialized = True async def _ensure_initialized(self) -> None: @@ -572,8 +549,7 @@ async def initialize(self) -> None: if self._initialized: return async with self.engine.begin() as conn: - await conn.run_sync(_STATE_METADATA.create_all) - await conn.run_sync(_ensure_state_store_schema) + await conn.run_sync(_initialize_state_store_schema) self._initialized = True async def _ensure_initialized(self) -> None: diff --git a/tests/server/test_state_store.py b/tests/server/test_state_store.py index be4bf9d..a98223a 100644 --- a/tests/server/test_state_store.py +++ b/tests/server/test_state_store.py @@ -4,8 +4,12 @@ import pytest from sqlalchemy import text +from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect +import opencode_a2a.server.migrations as migrations_module +from opencode_a2a.server.migrations import CURRENT_STATE_STORE_SCHEMA_VERSION from opencode_a2a.server.state_store import ( + _INTERRUPT_REQUESTS, DatabaseSessionStateRepository, MemorySessionStateRepository, build_interrupt_request_repository, @@ -16,6 +20,48 @@ from tests.support.helpers import make_settings +async def _read_state_store_schema_version(engine) -> int | None: # noqa: ANN001 + async with engine.begin() as conn: + result = await conn.execute( + text("SELECT version FROM a2a_schema_version WHERE name = 'state_store'") + ) + value = result.scalar_one_or_none() + return int(value) if value is not None else None + + +async def _read_state_store_schema_row_count(engine) -> int: # noqa: ANN001 + async with engine.begin() as conn: + result = await conn.execute( + text("SELECT COUNT(*) FROM a2a_schema_version WHERE name = 'state_store'") + ) + return int(result.scalar_one()) + + +def test_add_missing_nullable_column_supports_non_sqlite_dialects(monkeypatch) -> None: + executed: list[str] = [] + + class _FakeInspector: + def get_columns(self, _table_name: str) -> list[dict[str, str]]: + return [] + + class _FakeConnection: + def __init__(self) -> None: + self.dialect = postgresql_dialect() + + def execute(self, clause) -> None: # noqa: ANN001 + executed.append(str(clause)) + + monkeypatch.setattr(migrations_module, "inspect", lambda _connection: _FakeInspector()) + + migrations_module._add_missing_nullable_column( + _FakeConnection(), + table=_INTERRUPT_REQUESTS, + column_name="details_json", + ) + + assert executed == ["ALTER TABLE a2a_interrupt_requests ADD COLUMN details_json VARCHAR"] + + @pytest.mark.asyncio async def test_database_session_state_repository_persists_bindings(tmp_path: Path) -> None: database_url = f"sqlite+aiosqlite:///{tmp_path / 'state.db'}" @@ -39,6 +85,7 @@ async def test_database_session_state_repository_persists_bindings(tmp_path: Pat assert await reader.get_session(identity="user-1", context_id="ctx-1") == "ses-1" assert await reader.get_owner(session_id="ses-1") == "user-1" assert await reader.get_pending_claim(session_id="ses-2") == "user-2" + assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION await engine.dispose() @@ -292,5 +339,101 @@ async def test_database_interrupt_request_repository_upgrades_legacy_interrupt_t assert binding.details is None assert [item.request_id for item in pending] == ["perm-legacy"] assert pending[0].details is None + assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_database_state_store_records_schema_version_for_existing_current_schema( + tmp_path: Path, +) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'current-schema.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + + async with engine.begin() as conn: + await conn.execute( + text( + """ + CREATE TABLE a2a_session_bindings ( + identity VARCHAR NOT NULL, + context_id VARCHAR NOT NULL, + session_id VARCHAR NOT NULL, + PRIMARY KEY (identity, context_id) + ) + """ + ) + ) + await conn.execute( + text( + """ + CREATE TABLE a2a_session_owners ( + session_id VARCHAR NOT NULL PRIMARY KEY, + identity VARCHAR NOT NULL + ) + """ + ) + ) + await conn.execute( + text( + """ + CREATE TABLE a2a_pending_session_claims ( + session_id VARCHAR NOT NULL PRIMARY KEY, + identity VARCHAR NOT NULL, + updated_at FLOAT NOT NULL + ) + """ + ) + ) + await conn.execute( + text( + """ + CREATE TABLE a2a_interrupt_requests ( + request_id VARCHAR NOT NULL PRIMARY KEY, + session_id VARCHAR, + interrupt_type VARCHAR, + identity VARCHAR, + task_id VARCHAR, + context_id VARCHAR, + details_json VARCHAR, + expires_at FLOAT, + tombstone_expires_at FLOAT + ) + """ + ) + ) + + session_repository = build_session_state_repository(settings, engine=engine) + await initialize_state_repository(session_repository) + + assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_database_state_store_initialization_is_idempotent_across_repositories( + tmp_path: Path, +) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'idempotent-state.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + + session_repository = build_session_state_repository(settings, engine=engine) + interrupt_repository = build_interrupt_request_repository(settings, engine=engine) + + await initialize_state_repository(session_repository) + await initialize_state_repository(interrupt_repository) + await initialize_state_repository(session_repository) + + assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION + assert await _read_state_store_schema_row_count(engine) == 1 await engine.dispose()