Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ With the default `database` backend, the service persists:
- session binding / ownership state
- interrupt request bindings and tombstones

The runtime automatically applies lightweight schema migrations for its custom state tables and records the applied version in `a2a_schema_version`. This built-in path currently targets the local SQLite deployment profile and does not require Alembic.

The A2A SDK task table remains managed by the SDK's own `DatabaseTaskStore` initialization path. The internal migration runner only owns the additional `opencode-a2a` state tables listed above.

In-flight asyncio locks, outbound A2A client caches, and stream-local aggregation buffers remain process-local runtime state.

To opt into an ephemeral development profile, set:
Expand Down
1 change: 0 additions & 1 deletion src/opencode_a2a/py.typed
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

118 changes: 118 additions & 0 deletions src/opencode_a2a/server/migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

from sqlalchemy import (
Column,
Integer,
MetaData,
String,
Table,
insert,
inspect,
select,
text,
update,
)

if TYPE_CHECKING:
from sqlalchemy.engine import Connection

STATE_STORE_SCHEMA_NAME = "state_store"
CURRENT_STATE_STORE_SCHEMA_VERSION = 1

_SCHEMA_VERSION_METADATA = MetaData()

_SCHEMA_VERSIONS = Table(
"a2a_schema_version",
_SCHEMA_VERSION_METADATA,
Column("name", String, primary_key=True),
Column("version", Integer, nullable=False),
)


def _add_missing_nullable_column(
connection: Connection,
*,
table: Table,
column_name: str,
) -> None:
existing_columns = {column["name"] for column in inspect(connection).get_columns(table.name)}
if column_name in existing_columns:
return
column = table.c[column_name]
if column.primary_key or not column.nullable:
raise RuntimeError(f"Unsupported state-store migration for {table.name}.{column_name}")
preparer = connection.dialect.identifier_preparer
table_name_sql = preparer.quote(table.name)
column_name_sql = preparer.quote(column_name)
column_type_sql = column.type.compile(dialect=connection.dialect)
connection.execute(
text(f"ALTER TABLE {table_name_sql} ADD COLUMN {column_name_sql} {column_type_sql}")
)


def _migration_1_add_interrupt_details_json(
connection: Connection,
*,
interrupt_requests_table: Table,
) -> None:
_add_missing_nullable_column(
connection,
table=interrupt_requests_table,
column_name="details_json",
)


def _read_schema_version(connection: Connection, *, name: str) -> int:
result = connection.execute(
select(_SCHEMA_VERSIONS.c.version).where(_SCHEMA_VERSIONS.c.name == name)
)
version = result.scalar_one_or_none()
return int(version) if version is not None else 0


def _write_schema_version(connection: Connection, *, name: str, version: int) -> None:
exists = connection.execute(
select(_SCHEMA_VERSIONS.c.name).where(_SCHEMA_VERSIONS.c.name == name)
).scalar_one_or_none()
if exists is None:
connection.execute(insert(_SCHEMA_VERSIONS).values(name=name, version=version))
return
connection.execute(
update(_SCHEMA_VERSIONS).where(_SCHEMA_VERSIONS.c.name == name).values(version=version)
)


def migrate_state_store_schema(
connection: Connection,
*,
state_metadata: MetaData,
interrupt_requests_table: Table,
current_version: int = CURRENT_STATE_STORE_SCHEMA_VERSION,
) -> int:
_SCHEMA_VERSION_METADATA.create_all(connection)
state_metadata.create_all(connection)

stored_version = _read_schema_version(connection, name=STATE_STORE_SCHEMA_NAME)
if stored_version > current_version:
raise RuntimeError(
"Database state-store schema version is newer than this application supports"
)

migrations: dict[int, Callable[[Connection], None]] = {
1: lambda conn: _migration_1_add_interrupt_details_json(
conn,
interrupt_requests_table=interrupt_requests_table,
),
}

for next_version in range(stored_version + 1, current_version + 1):
migration = migrations.get(next_version)
if migration is None:
raise RuntimeError(f"Missing state-store migration for version {next_version}")
migration(connection)
_write_schema_version(connection, name=STATE_STORE_SCHEMA_NAME, version=next_version)

return current_version
38 changes: 7 additions & 31 deletions src/opencode_a2a/server/state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
and_,
delete,
insert,
inspect,
select,
text,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from ..config import Settings
from ..execution.stream_state import _TTLCache
from ..runtime_state import InterruptRequestBinding, InterruptRequestTombstone
from .migrations import migrate_state_store_schema

if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine

_STATE_METADATA = MetaData()
Expand Down Expand Up @@ -73,31 +71,11 @@
_MEMORY_SESSION_BINDING_MAXSIZE = 10_000


def _add_missing_sqlite_column(
connection: Connection,
*,
table: Table,
column_name: str,
) -> None:
if connection.dialect.name != "sqlite":
return
existing_columns = {column["name"] for column in inspect(connection).get_columns(table.name)}
if column_name in existing_columns:
return
column = table.c[column_name]
if column.primary_key or not column.nullable:
raise RuntimeError(
f"Unsupported SQLite state-store migration for {table.name}.{column_name}"
)
column_type = column.type.compile(dialect=connection.dialect)
connection.execute(text(f'ALTER TABLE "{table.name}" ADD COLUMN "{column_name}" {column_type}'))


def _ensure_state_store_schema(connection: Connection) -> None:
_add_missing_sqlite_column(
def _initialize_state_store_schema(connection) -> None: # noqa: ANN001
migrate_state_store_schema(
connection,
table=_INTERRUPT_REQUESTS,
column_name="details_json",
state_metadata=_STATE_METADATA,
interrupt_requests_table=_INTERRUPT_REQUESTS,
)


Expand Down Expand Up @@ -250,8 +228,7 @@ async def initialize(self) -> None:
if self._initialized:
return
async with self.engine.begin() as conn:
await conn.run_sync(_STATE_METADATA.create_all)
await conn.run_sync(_ensure_state_store_schema)
await conn.run_sync(_initialize_state_store_schema)
self._initialized = True

async def _ensure_initialized(self) -> None:
Expand Down Expand Up @@ -572,8 +549,7 @@ async def initialize(self) -> None:
if self._initialized:
return
async with self.engine.begin() as conn:
await conn.run_sync(_STATE_METADATA.create_all)
await conn.run_sync(_ensure_state_store_schema)
await conn.run_sync(_initialize_state_store_schema)
self._initialized = True

async def _ensure_initialized(self) -> None:
Expand Down
143 changes: 143 additions & 0 deletions tests/server/test_state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

import pytest
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect

import opencode_a2a.server.migrations as migrations_module
from opencode_a2a.server.migrations import CURRENT_STATE_STORE_SCHEMA_VERSION
from opencode_a2a.server.state_store import (
_INTERRUPT_REQUESTS,
DatabaseSessionStateRepository,
MemorySessionStateRepository,
build_interrupt_request_repository,
Expand All @@ -16,6 +20,48 @@
from tests.support.helpers import make_settings


async def _read_state_store_schema_version(engine) -> int | None: # noqa: ANN001
async with engine.begin() as conn:
result = await conn.execute(
text("SELECT version FROM a2a_schema_version WHERE name = 'state_store'")
)
value = result.scalar_one_or_none()
return int(value) if value is not None else None


async def _read_state_store_schema_row_count(engine) -> int: # noqa: ANN001
async with engine.begin() as conn:
result = await conn.execute(
text("SELECT COUNT(*) FROM a2a_schema_version WHERE name = 'state_store'")
)
return int(result.scalar_one())


def test_add_missing_nullable_column_supports_non_sqlite_dialects(monkeypatch) -> None:
executed: list[str] = []

class _FakeInspector:
def get_columns(self, _table_name: str) -> list[dict[str, str]]:
return []

class _FakeConnection:
def __init__(self) -> None:
self.dialect = postgresql_dialect()

def execute(self, clause) -> None: # noqa: ANN001
executed.append(str(clause))

monkeypatch.setattr(migrations_module, "inspect", lambda _connection: _FakeInspector())

migrations_module._add_missing_nullable_column(
_FakeConnection(),
table=_INTERRUPT_REQUESTS,
column_name="details_json",
)

assert executed == ["ALTER TABLE a2a_interrupt_requests ADD COLUMN details_json VARCHAR"]


@pytest.mark.asyncio
async def test_database_session_state_repository_persists_bindings(tmp_path: Path) -> None:
database_url = f"sqlite+aiosqlite:///{tmp_path / 'state.db'}"
Expand All @@ -39,6 +85,7 @@ async def test_database_session_state_repository_persists_bindings(tmp_path: Pat
assert await reader.get_session(identity="user-1", context_id="ctx-1") == "ses-1"
assert await reader.get_owner(session_id="ses-1") == "user-1"
assert await reader.get_pending_claim(session_id="ses-2") == "user-2"
assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION

await engine.dispose()

Expand Down Expand Up @@ -292,5 +339,101 @@ async def test_database_interrupt_request_repository_upgrades_legacy_interrupt_t
assert binding.details is None
assert [item.request_id for item in pending] == ["perm-legacy"]
assert pending[0].details is None
assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION

await engine.dispose()


@pytest.mark.asyncio
async def test_database_state_store_records_schema_version_for_existing_current_schema(
tmp_path: Path,
) -> None:
database_url = f"sqlite+aiosqlite:///{tmp_path / 'current-schema.db'}"
settings = make_settings(
a2a_bearer_token="test-token",
a2a_task_store_database_url=database_url,
)
engine = build_database_engine(settings)

async with engine.begin() as conn:
await conn.execute(
text(
"""
CREATE TABLE a2a_session_bindings (
identity VARCHAR NOT NULL,
context_id VARCHAR NOT NULL,
session_id VARCHAR NOT NULL,
PRIMARY KEY (identity, context_id)
)
"""
)
)
await conn.execute(
text(
"""
CREATE TABLE a2a_session_owners (
session_id VARCHAR NOT NULL PRIMARY KEY,
identity VARCHAR NOT NULL
)
"""
)
)
await conn.execute(
text(
"""
CREATE TABLE a2a_pending_session_claims (
session_id VARCHAR NOT NULL PRIMARY KEY,
identity VARCHAR NOT NULL,
updated_at FLOAT NOT NULL
)
"""
)
)
await conn.execute(
text(
"""
CREATE TABLE a2a_interrupt_requests (
request_id VARCHAR NOT NULL PRIMARY KEY,
session_id VARCHAR,
interrupt_type VARCHAR,
identity VARCHAR,
task_id VARCHAR,
context_id VARCHAR,
details_json VARCHAR,
expires_at FLOAT,
tombstone_expires_at FLOAT
)
"""
)
)

session_repository = build_session_state_repository(settings, engine=engine)
await initialize_state_repository(session_repository)

assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION

await engine.dispose()


@pytest.mark.asyncio
async def test_database_state_store_initialization_is_idempotent_across_repositories(
tmp_path: Path,
) -> None:
database_url = f"sqlite+aiosqlite:///{tmp_path / 'idempotent-state.db'}"
settings = make_settings(
a2a_bearer_token="test-token",
a2a_task_store_database_url=database_url,
)
engine = build_database_engine(settings)

session_repository = build_session_state_repository(settings, engine=engine)
interrupt_repository = build_interrupt_request_repository(settings, engine=engine)

await initialize_state_repository(session_repository)
await initialize_state_repository(interrupt_repository)
await initialize_state_repository(session_repository)

assert await _read_state_store_schema_version(engine) == CURRENT_STATE_STORE_SCHEMA_VERSION
assert await _read_state_store_schema_row_count(engine) == 1

await engine.dispose()