From 13fe4b2f0bf381230ac1223ec9f30d8e2e01e035 Mon Sep 17 00:00:00 2001 From: Asher Fink Date: Thu, 21 May 2026 15:05:41 -0400 Subject: [PATCH] feat(AGX1-274): dual-write tasks to spark-authz --- ..._task_creator_and_zedtoken_a1f73ada66c5.py | 48 ++++ .../database/migrations/migration_history.txt | 4 +- agentex/src/adapters/orm.py | 3 + agentex/src/domain/entities/tasks.py | 12 + agentex/src/domain/services/task_service.py | 96 ++++++- .../domain/use_cases/agents_acp_use_case.py | 11 + .../src/domain/use_cases/tasks_use_case.py | 8 + agentex/src/utils/feature_flags.py | 29 +++ agentex/tests/fixtures/services.py | 17 +- .../fixtures/integration_client.py | 11 +- .../tests/integration/services/__init__.py | 0 .../services/test_task_service_dual_write.py | 242 ++++++++++++++++++ agentex/tests/integration/test_task_stream.py | 14 + .../tests/unit/services/test_task_service.py | 9 +- ...p_type_backwards_compatibility_use_case.py | 17 +- .../use_cases/test_agents_acp_use_case.py | 9 +- 16 files changed, 521 insertions(+), 9 deletions(-) create mode 100644 agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_and_zedtoken_a1f73ada66c5.py create mode 100644 agentex/src/utils/feature_flags.py create mode 100644 agentex/tests/integration/services/__init__.py create mode 100644 agentex/tests/integration/services/test_task_service_dual_write.py diff --git a/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_and_zedtoken_a1f73ada66c5.py b/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_and_zedtoken_a1f73ada66c5.py new file mode 100644 index 00000000..d278b32a --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_and_zedtoken_a1f73ada66c5.py @@ -0,0 +1,48 @@ +"""add_task_creator_and_zedtoken + +Revision ID: a1f73ada66c5 +Revises: a9959ebcbe98 +Create Date: 2026-05-21 15:08:51.441535 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a1f73ada66c5' +down_revision: Union[str, None] = 'a9959ebcbe98' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('tasks', sa.Column('creator_user_id', sa.String(), nullable=True)) + op.add_column('tasks', sa.Column('creator_service_account_id', sa.String(), nullable=True)) + op.add_column('tasks', sa.Column('spark_authz_zedtoken', sa.Text(), nullable=True)) + with op.get_context().autocommit_block(): + op.execute( + "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_tasks_creator_user_id " + "ON tasks (creator_user_id)" + ) + op.execute( + "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_tasks_creator_service_account_id " + "ON tasks (creator_service_account_id)" + ) + op.create_check_constraint( + 'ck_tasks_one_creator', + 'tasks', + '(creator_user_id IS NULL) OR (creator_service_account_id IS NULL)', + ) + + +def downgrade() -> None: + op.drop_constraint('ck_tasks_one_creator', 'tasks', type_='check') + with op.get_context().autocommit_block(): + op.execute("DROP INDEX CONCURRENTLY IF EXISTS ix_tasks_creator_service_account_id") + op.execute("DROP INDEX CONCURRENTLY IF EXISTS ix_tasks_creator_user_id") + op.drop_column('tasks', 'spark_authz_zedtoken') + op.drop_column('tasks', 'creator_service_account_id') + op.drop_column('tasks', 'creator_user_id') diff --git a/agentex/database/migrations/migration_history.txt b/agentex/database/migrations/migration_history.txt index 25e97ddb..08a93189 100644 --- a/agentex/database/migrations/migration_history.txt +++ b/agentex/database/migrations/migration_history.txt @@ -1,4 +1,6 @@ -9ff3ee32c81b -> e9c4ff9e6542 (head), add_tasks_metadata_gin_index +a9959ebcbe98 -> a1f73ada66c5 (head), add_task_creator_and_zedtoken +e9c4ff9e6542 -> a9959ebcbe98, finalize_spans_task_id +9ff3ee32c81b -> e9c4ff9e6542, add_tasks_metadata_gin_index 57c5ed4f59ae -> 9ff3ee32c81b, uppercase deployment status enum labels 4a9b7787ccd7 -> 57c5ed4f59ae, add_task_id_to_spans d1a6cde41b3f -> 4a9b7787ccd7, deployments diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index ac5ee39a..29223654 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -74,6 +74,9 @@ class TaskORM(BaseORM): ) params = Column(JSONB, nullable=True) task_metadata = Column(JSONB, nullable=True) + creator_user_id = Column(String, nullable=True, index=True) + creator_service_account_id = Column(String, nullable=True, index=True) + spark_authz_zedtoken = Column(Text, nullable=True) # Many-to-Many relationship with agents agents = relationship("AgentORM", secondary="task_agents", back_populates="tasks") diff --git a/agentex/src/domain/entities/tasks.py b/agentex/src/domain/entities/tasks.py index bdccc23e..328eec5e 100644 --- a/agentex/src/domain/entities/tasks.py +++ b/agentex/src/domain/entities/tasks.py @@ -58,6 +58,18 @@ class TaskEntity(BaseModel): None, title="Task metadata", ) + creator_user_id: str | None = Field( + None, + title="Identity ID of the user who created this task (granted as FGAC owner)", + ) + creator_service_account_id: str | None = Field( + None, + title="Service identity ID of the service account that created this task", + ) + spark_authz_zedtoken: str | None = Field( + None, + title="ZedToken from the Spark AuthZ grant for new-write isolation", + ) # allow extra fields for agents relationships model_config = ConfigDict(extra="allow") diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 013c6903..ddcac586 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -4,6 +4,7 @@ from fastapi import Depends from src.adapters.streams.adapter_redis import DRedisStreamRepository +from src.api.schemas.authorization_types import AgentexResource from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.events import EventEntity from src.domain.entities.task_message_updates import TaskMessageUpdateEntity @@ -14,6 +15,8 @@ from src.domain.repositories.task_repository import DTaskRepository from src.domain.repositories.task_state_repository import DTaskStateRepository from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.authorization_service import DAuthorizationService +from src.utils.feature_flags import DFeatureFlagProvider, FeatureFlagName from src.utils.ids import orm_id from src.utils.logging import make_logger from src.utils.stream_topics import get_task_event_stream_topic @@ -33,12 +36,16 @@ def __init__( task_repository: DTaskRepository, event_repository: DEventRepository, stream_repository: DRedisStreamRepository, + authorization_service: DAuthorizationService, + feature_flags: DFeatureFlagProvider, ): self.acp_client = acp_client self.task_state_repository = task_state_repository self.task_repository = task_repository self.event_repository = event_repository self.stream_repository = stream_repository + self.authorization_service = authorization_service + self.feature_flags = feature_flags async def create_task( self, @@ -46,6 +53,7 @@ async def create_task( task_name: str | None = None, task_params: dict[str, Any] | None = None, task_metadata: dict[str, Any] | None = None, + account_id: str | None = None, ) -> TaskEntity: """ Create a new task record in the repository with single agent (maintains existing interface). @@ -56,28 +64,107 @@ async def create_task( task_params: The parameters for the task task_metadata: Caller-provided metadata to persist on the task row. Not forwarded to the agent. + account_id: Caller-resolved account scope. When provided and the + FGAC_TASKS_DUAL_WRITE flag is enabled for it, the task is also + registered in Spark AuthZ. Returns: Task containing the created task info """ + principal_context = self.authorization_service.principal_context + creator_user_id = getattr(principal_context, "user_id", None) + creator_service_account_id = getattr( + principal_context, "service_account_id", None + ) + + task_id = orm_id() + zedtoken: str | None = None + + if self.feature_flags.is_enabled( + FeatureFlagName.FGAC_TASKS_DUAL_WRITE, account_id + ): + zedtoken = await self._register_task_in_spark_authz( + task_id=task_id, + account_id=account_id, + creator_user_id=creator_user_id, + creator_service_account_id=creator_service_account_id, + ) task_entity = await self.task_repository.create( agent_id=agent.id, task=TaskEntity( - id=orm_id(), + id=task_id, name=task_name, status=TaskStatus.RUNNING, status_reason="Task created, forwarding to ACP server", params=task_params, task_metadata=task_metadata, + creator_user_id=creator_user_id, + creator_service_account_id=creator_service_account_id, + spark_authz_zedtoken=zedtoken, ), ) return task_entity + async def _register_task_in_spark_authz( + self, + *, + task_id: str, + account_id: str | None, + creator_user_id: str | None, + creator_service_account_id: str | None, + ) -> str | None: + """Register a new task in Spark AuthZ with creator as owner. + + Called BEFORE the Postgres write — a failure raises and prevents the + row from being persisted, so there is no compensating action to take. + Mirrors the KB FGAC pattern at + ``packages/egp-api-backend/.../knowledge_base_v2_use_case.py:374-388``. + + The current ``Provider.spark`` adapter returns ``{}`` from ``grant``; + no ZedToken is surfaced today, so we always return ``None`` for the + new-write-isolation column. A follow-up will plumb the token through + once the adapter exposes it. + """ + if creator_user_id is None and creator_service_account_id is None: + logger.warning( + "Skipping Spark AuthZ task registration: no creator resolvable", + extra={"task_id": task_id, "account_id": account_id}, + ) + return None + await self.authorization_service.grant( + resource=AgentexResource.task(task_id), + ) + return None + + async def deregister_task_from_spark_authz( + self, *, task_id: str, account_id: str | None + ) -> None: + """Best-effort revocation of a task's Spark AuthZ tuples on delete. + + Only invoked when the FGAC_TASKS_DUAL_WRITE flag is enabled for the + caller's account. Failures are logged but do not block the delete. + """ + if not self.feature_flags.is_enabled( + FeatureFlagName.FGAC_TASKS_DUAL_WRITE, account_id + ): + return + try: + await self.authorization_service.revoke( + resource=AgentexResource.task(task_id), + ) + except Exception: + logger.warning( + "Spark AuthZ revoke failed for task", + extra={"task_id": task_id, "account_id": account_id}, + exc_info=True, + ) + async def create_task_and_forward_to_acp( self, agent: AgentEntity, task_name: str | None = None, task_params: dict[str, Any] | None = None, + account_id: str | None = None, ) -> TaskEntity: """ Create a new task record in the repository with single agent (maintains existing interface). @@ -86,12 +173,17 @@ async def create_task_and_forward_to_acp( Args: agent: The agent to create the task for task_params: The parameters for the task to be sent to the ACP server + account_id: Caller-resolved account scope; threaded through to + :meth:`create_task` for FGAC dual-write gating. Returns: Task containing the created task info """ task_entity = await self.create_task( - agent=agent, task_name=task_name, task_params=task_params + agent=agent, + task_name=task_name, + task_params=task_params, + account_id=account_id, ) if agent.acp_type == ACPType.SYNC: diff --git a/agentex/src/domain/use_cases/agents_acp_use_case.py b/agentex/src/domain/use_cases/agents_acp_use_case.py index fc727d88..d77cc43e 100644 --- a/agentex/src/domain/use_cases/agents_acp_use_case.py +++ b/agentex/src/domain/use_cases/agents_acp_use_case.py @@ -268,6 +268,7 @@ async def _get_or_create_task( task_name: str | None = None, task_params: dict[str, Any] | None = None, task_metadata: dict[str, Any] | None = None, + account_id: str | None = None, ) -> TaskEntity: """Return the existing task if *task_id* is provided, otherwise create a new one. @@ -308,6 +309,7 @@ async def _get_or_create_task( task_name=task_name, task_params=task_params, task_metadata=task_metadata, + account_id=account_id, ) logger.info(f"[agent_id={agent.id}] Created task {task.id}") await self.grant_with_retry(task) @@ -419,6 +421,9 @@ async def _handle_task_create( task_name=params.name, task_params=params.params, task_metadata=params.task_metadata, + account_id=getattr( + self.authorization_service.principal_context, "account_id", None + ), ) if agent.acp_type in [ACPType.AGENTIC, ACPType.ASYNC]: @@ -457,6 +462,9 @@ async def _handle_message_send_sync( task_id=params.task_id, task_name=params.task_name, task_params=params.task_params, + account_id=getattr( + self.authorization_service.principal_context, "account_id", None + ), ) # Step 1: Insert the message in the messages table @@ -642,6 +650,9 @@ async def flush_aggregated_deltas(task_message_index: int) -> TaskMessageEntity: task_id=params.task_id, task_name=params.task_name, task_params=params.task_params, + account_id=getattr( + self.authorization_service.principal_context, "account_id", None + ), ) # Append the input client message diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index 4ad1a61a..8d0407ed 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -60,6 +60,14 @@ async def delete_task(self, id: str | None = None, name: str | None = None) -> N task.status = TaskStatus.DELETED task.status_reason = "Task deleted successfully" await self.task_service.update_task(task=task) + account_id = getattr( + self.task_service.authorization_service.principal_context, + "account_id", + None, + ) + await self.task_service.deregister_task_from_spark_authz( + task_id=task.id, account_id=account_id + ) async def list_tasks( self, diff --git a/agentex/src/utils/feature_flags.py b/agentex/src/utils/feature_flags.py new file mode 100644 index 00000000..b7cf66b2 --- /dev/null +++ b/agentex/src/utils/feature_flags.py @@ -0,0 +1,29 @@ +import os +from enum import StrEnum +from typing import Annotated + +from fastapi import Depends + + +class FeatureFlagName(StrEnum): + FGAC_TASKS = "fgac-tasks" + FGAC_TASKS_DUAL_WRITE = "fgac-tasks-dual-write" + + +class FeatureFlagProvider: + """Per-account feature flag provider. + + v1: env-var allowlist (per-account, comma-separated). The env var name is + derived from the flag name, e.g. ``FGAC_TASKS_DUAL_WRITE_ACCOUNTS``. A + follow-up will swap this for LaunchDarkly with an account_id context. + """ + + def is_enabled(self, name: FeatureFlagName, account_id: str | None) -> bool: + if not account_id: + return False + env_key = f"{name.value.upper().replace('-', '_')}_ACCOUNTS" + allowed = os.environ.get(env_key, "") + return account_id in {a.strip() for a in allowed.split(",") if a.strip()} + + +DFeatureFlagProvider = Annotated[FeatureFlagProvider, Depends(FeatureFlagProvider)] diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index c30c06c8..7bd44041 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -3,7 +3,7 @@ Provides factory functions and specific fixtures for creating services with test repositories. """ -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest @@ -52,9 +52,20 @@ def create_task_service( event_repository, agent_acp_service, redis_stream_repository, + authorization_service=None, + feature_flags=None, ): - """Factory function to create AgentTaskService with given repositories and services""" + """Factory function to create AgentTaskService with given repositories and services.""" from src.domain.services.task_service import AgentTaskService + from src.utils.feature_flags import FeatureFlagProvider + + if authorization_service is None: + authorization_service = Mock() + authorization_service.principal_context = None + authorization_service.grant = AsyncMock(return_value=None) + authorization_service.revoke = AsyncMock(return_value=None) + if feature_flags is None: + feature_flags = FeatureFlagProvider() return AgentTaskService( task_repository=task_repository, @@ -62,6 +73,8 @@ def create_task_service( event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=authorization_service, + feature_flags=feature_flags, ) diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index b715d223..0b148ba6 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -6,7 +6,7 @@ import asyncio import os import uuid -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pymongo import pytest @@ -441,12 +441,21 @@ async def send_event(self, *args, **kwargs): async def send_message(self, *args, **kwargs): pass + from src.utils.feature_flags import FeatureFlagProvider + + noop_authorization_service = Mock() + noop_authorization_service.principal_context = None + noop_authorization_service.grant = AsyncMock(return_value={}) + noop_authorization_service.revoke = AsyncMock(return_value=None) + task_service = AgentTaskService( acp_client=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=noop_authorization_service, + feature_flags=FeatureFlagProvider(), ) return TasksUseCase(task_service=task_service) diff --git a/agentex/tests/integration/services/__init__.py b/agentex/tests/integration/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentex/tests/integration/services/test_task_service_dual_write.py b/agentex/tests/integration/services/test_task_service_dual_write.py new file mode 100644 index 00000000..60f2d695 --- /dev/null +++ b/agentex/tests/integration/services/test_task_service_dual_write.py @@ -0,0 +1,242 @@ +"""Integration tests for AgentTaskService dual-write to Spark AuthZ. + +These cover the AGX1-274 dual-write path: + +- Flag OFF: ``authorization_service.grant`` is NOT called and the task is + written to the repository with creator metadata populated from the + principal context. +- Flag ON: ``grant`` is called with ``AgentexResource.task()`` and the + task row is written. +- Delete deregisters: ``revoke`` is called when ``delete_task`` runs under + the flag. +- Idempotency: when the spark adapter returns success on the second + invocation (sgp-authz uses TOUCH semantics under the hood), the row is + still inserted. +- Spark failure prevents row: when ``grant`` raises, the task is NOT + persisted. + +The tests intentionally mock the repository, authorization service, and +ACP client. The behaviour under test is the call sequencing inside +``AgentTaskService`` — not Postgres or Spark itself. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest +from src.api.schemas.authorization_types import AgentexResource, AgentexResourceType +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.entities.tasks import TaskEntity, TaskStatus +from src.domain.services.task_service import AgentTaskService +from src.domain.use_cases.tasks_use_case import TasksUseCase +from src.utils.feature_flags import FeatureFlagProvider +from src.utils.ids import orm_id + + +def _principal(user_id: str | None, account_id: str | None) -> SimpleNamespace: + """Minimal stand-in for AgentexAuthPrincipalContext. + + Uses ``SimpleNamespace`` so the ``getattr(..., "user_id", None)`` + pattern in the service works identically whether AGX1-240's typed + union has landed or the alias is still ``Any``. + """ + return SimpleNamespace( + user_id=user_id, service_account_id=None, account_id=account_id + ) + + +def _agent() -> AgentEntity: + agent_id = orm_id() + return AgentEntity( + id=agent_id, + name=f"agent-{agent_id[:8]}", + description="dual-write test agent", + status=AgentStatus.READY, + acp_type=ACPType.SYNC, + acp_url="http://test-acp", + ) + + +def _build_service( + *, + flag_accounts: str, + principal: SimpleNamespace | None, + grant: AsyncMock | None = None, + revoke: AsyncMock | None = None, + create_raises: Exception | None = None, + monkeypatch: pytest.MonkeyPatch, +) -> tuple[AgentTaskService, Mock, AsyncMock, AsyncMock]: + monkeypatch.setenv("FGAC_TASKS_DUAL_WRITE_ACCOUNTS", flag_accounts) + + task_repository = Mock() + if create_raises is None: + task_repository.create = AsyncMock(side_effect=lambda agent_id, task: task) + else: + task_repository.create = AsyncMock(side_effect=create_raises) + task_repository.get = AsyncMock() + task_repository.update = AsyncMock(side_effect=lambda task: task) + + authorization_service = Mock() + authorization_service.principal_context = principal + authorization_service.grant = grant or AsyncMock(return_value={}) + authorization_service.revoke = revoke or AsyncMock(return_value=None) + + feature_flags = FeatureFlagProvider() + + service = AgentTaskService( + acp_client=Mock(), + task_state_repository=Mock(), + task_repository=task_repository, + event_repository=Mock(), + stream_repository=Mock(), + authorization_service=authorization_service, + feature_flags=feature_flags, + ) + return ( + service, + task_repository, + authorization_service.grant, + authorization_service.revoke, + ) + + +@pytest.mark.asyncio +async def test_create_task_skips_grant_when_flag_off( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service, task_repository, grant, _ = _build_service( + flag_accounts="", + principal=_principal(user_id="user-A", account_id="acct-1"), + monkeypatch=monkeypatch, + ) + + task = await service.create_task( + agent=_agent(), task_name="t", account_id="acct-1" + ) + + grant.assert_not_called() + task_repository.create.assert_awaited_once() + assert task.creator_user_id == "user-A" + assert task.creator_service_account_id is None + assert task.spark_authz_zedtoken is None + + +@pytest.mark.asyncio +async def test_create_task_calls_grant_when_flag_on( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service, task_repository, grant, _ = _build_service( + flag_accounts="acct-1", + principal=_principal(user_id="user-A", account_id="acct-1"), + monkeypatch=monkeypatch, + ) + + task = await service.create_task( + agent=_agent(), task_name="t", account_id="acct-1" + ) + + grant.assert_awaited_once() + granted_resource: AgentexResource = grant.await_args.kwargs["resource"] + assert granted_resource.type == AgentexResourceType.task + assert granted_resource.selector == task.id + task_repository.create.assert_awaited_once() + assert task.creator_user_id == "user-A" + # Provider.spark.grant returns {} today — no zedtoken yet. + assert task.spark_authz_zedtoken is None + + +@pytest.mark.asyncio +async def test_delete_task_calls_revoke_when_flag_on( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service, task_repository, _, revoke = _build_service( + flag_accounts="acct-1", + principal=_principal(user_id="user-A", account_id="acct-1"), + monkeypatch=monkeypatch, + ) + + existing = TaskEntity( + id=orm_id(), + name="t", + status=TaskStatus.RUNNING, + creator_user_id="user-A", + ) + task_repository.get = AsyncMock(return_value=existing) + + use_case = TasksUseCase(task_service=service) + await use_case.delete_task(id=existing.id) + + revoke.assert_awaited_once() + revoked_resource: AgentexResource = revoke.await_args.kwargs["resource"] + assert revoked_resource.type == AgentexResourceType.task + assert revoked_resource.selector == existing.id + + +@pytest.mark.asyncio +async def test_create_task_idempotent_grant_returns_empty_dict( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """sgp-authz uses TOUCH semantics — a second grant on the same tuple is a no-op + and the adapter returns ``{}``. The service must persist the row either way.""" + grant = AsyncMock(side_effect=[{}, {}]) + service, task_repository, grant_ref, _ = _build_service( + flag_accounts="acct-1", + principal=_principal(user_id="user-A", account_id="acct-1"), + grant=grant, + monkeypatch=monkeypatch, + ) + + first = await service.create_task( + agent=_agent(), task_name="t1", account_id="acct-1" + ) + second = await service.create_task( + agent=_agent(), task_name="t2", account_id="acct-1" + ) + + assert grant_ref.await_count == 2 + assert task_repository.create.await_count == 2 + assert first.id != second.id + + +@pytest.mark.asyncio +async def test_create_task_grant_failure_prevents_db_row( + monkeypatch: pytest.MonkeyPatch, +) -> None: + grant = AsyncMock(side_effect=RuntimeError("spark unavailable")) + service, task_repository, _, _ = _build_service( + flag_accounts="acct-1", + principal=_principal(user_id="user-A", account_id="acct-1"), + grant=grant, + monkeypatch=monkeypatch, + ) + + with pytest.raises(RuntimeError, match="spark unavailable"): + await service.create_task( + agent=_agent(), task_name="t", account_id="acct-1" + ) + + task_repository.create.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_create_task_skips_grant_when_no_creator_resolvable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """If neither user_id nor service_account_id is available on the principal, + the dual-write is a no-op (logged) and the row still lands without a tuple.""" + service, task_repository, grant, _ = _build_service( + flag_accounts="acct-1", + principal=_principal(user_id=None, account_id="acct-1"), + monkeypatch=monkeypatch, + ) + + task = await service.create_task( + agent=_agent(), task_name="t", account_id="acct-1" + ) + + grant.assert_not_called() + task_repository.create.assert_awaited_once() + assert task.creator_user_id is None + assert task.creator_service_account_id is None diff --git a/agentex/tests/integration/test_task_stream.py b/agentex/tests/integration/test_task_stream.py index 289010ee..9b8a50ea 100644 --- a/agentex/tests/integration/test_task_stream.py +++ b/agentex/tests/integration/test_task_stream.py @@ -1,13 +1,23 @@ import asyncio +from unittest.mock import AsyncMock, Mock import pytest from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus from src.domain.entities.tasks import TaskEntity, TaskStatus from src.domain.use_cases.streams_use_case import StreamsUseCase from src.domain.use_cases.tasks_use_case import TasksUseCase +from src.utils.feature_flags import FeatureFlagProvider from src.utils.ids import orm_id +def _make_noop_authorization_service() -> Mock: + svc = Mock() + svc.principal_context = None + svc.grant = AsyncMock(return_value=None) + svc.revoke = AsyncMock(return_value=None) + return svc + + @pytest.mark.asyncio @pytest.mark.integration class TestTaskEventStream: @@ -76,6 +86,8 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=_make_noop_authorization_service(), + feature_flags=FeatureFlagProvider(), ) return TasksUseCase(task_service=task_service) @@ -103,6 +115,8 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=_make_noop_authorization_service(), + feature_flags=FeatureFlagProvider(), ) environment_variables = EnvironmentVariables.refresh() diff --git a/agentex/tests/unit/services/test_task_service.py b/agentex/tests/unit/services/test_task_service.py index eb096eb1..70ad64d2 100644 --- a/agentex/tests/unit/services/test_task_service.py +++ b/agentex/tests/unit/services/test_task_service.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock from uuid import uuid4 import pytest @@ -18,6 +18,7 @@ from src.domain.repositories.task_repository import TaskRepository from src.domain.repositories.task_state_repository import TaskStateRepository from src.domain.services.task_service import AgentTaskService +from src.utils.feature_flags import FeatureFlagProvider async def create_or_get_agent(agent_repository, agent): @@ -78,12 +79,18 @@ def task_service( redis_stream_repository, ): """Create TaskService instance with real repositories and mocked ACP client""" + authorization_service = Mock() + authorization_service.principal_context = None + authorization_service.grant = AsyncMock(return_value={}) + authorization_service.revoke = AsyncMock(return_value=None) return AgentTaskService( acp_client=mock_acp_client, task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, stream_repository=redis_stream_repository, + authorization_service=authorization_service, + feature_flags=FeatureFlagProvider(), ) diff --git a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py index 60914adc..7fe7c2d7 100644 --- a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py +++ b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py @@ -3,7 +3,7 @@ Ensures legacy "agentic" agents continue to work alongside new "async" agents. """ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock from uuid import uuid4 import pytest @@ -23,6 +23,15 @@ from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase +from src.utils.feature_flags import FeatureFlagProvider + + +def _noop_authorization_service() -> Mock: + svc = Mock() + svc.principal_context = None + svc.grant = AsyncMock(return_value={}) + svc.revoke = AsyncMock(return_value=None) + return svc @pytest.mark.unit @@ -95,6 +104,8 @@ async def test_agentic_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=_noop_authorization_service(), + feature_flags=FeatureFlagProvider(), ) # Create AGENTIC agent @@ -148,6 +159,8 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=_noop_authorization_service(), + feature_flags=FeatureFlagProvider(), ) # Create SYNC agent @@ -195,6 +208,8 @@ async def test_async_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=_noop_authorization_service(), + feature_flags=FeatureFlagProvider(), ) # Create ASYNC agent diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index b48751a4..73c1314d 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock from uuid import uuid4 from zoneinfo import ZoneInfo @@ -36,6 +36,7 @@ from src.domain.services.task_message_service import TaskMessageService from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase +from src.utils.feature_flags import FeatureFlagProvider # UTC timezone constant UTC = ZoneInfo("UTC") @@ -129,12 +130,18 @@ def task_service( redis_stream_repository, ): """Real AgentTaskService instance""" + authorization_service = Mock() + authorization_service.principal_context = None + authorization_service.grant = AsyncMock(return_value={}) + authorization_service.revoke = AsyncMock(return_value=None) return AgentTaskService( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=authorization_service, + feature_flags=FeatureFlagProvider(), )