From 06f2da19d1053843503c56cf51a89e1f6d7a5952 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:14:39 +0530 Subject: [PATCH 01/12] feat(server): add runtime auth namespace cutover Add explicit none, api_key, and jwt runtime auth modes, including a generic no-auth provider. Move controls, bindings, policies, agents, and evaluation storage lookups onto principal namespace scoping. Cover auth mode selection and principal namespace isolation with server tests. --- .../auth_framework/__init__.py | 7 +- .../auth_framework/config.py | 120 +++++++++++---- .../auth_framework/core.py | 16 +- .../auth_framework/providers/__init__.py | 2 + .../auth_framework/providers/header.py | 39 +++-- .../auth_framework/providers/http_upstream.py | 6 +- .../auth_framework/providers/local_jwt.py | 2 +- .../auth_framework/providers/no_auth.py | 29 ++++ .../agent_control_server/endpoints/agents.py | 92 +++++++----- .../agent_control_server/endpoints/auth.py | 11 +- .../endpoints/control_bindings.py | 47 +++--- .../endpoints/controls.py | 87 +++++++---- .../endpoints/evaluation.py | 26 +++- .../endpoints/policies.py | 77 +++++++--- server/src/agent_control_server/main.py | 19 ++- .../agent_control_server/services/controls.py | 140 +++++++++++++---- server/tests/test_auth_framework.py | 96 +++++++++++- server/tests/test_controls_additional.py | 15 +- server/tests/test_controls_auth.py | 28 ++-- server/tests/test_principal_namespace_flow.py | 141 ++++++++++++++++++ server/tests/test_target_merged_contract.py | 6 +- 21 files changed, 753 insertions(+), 253 deletions(-) create mode 100644 server/src/agent_control_server/auth_framework/providers/no_auth.py create mode 100644 server/tests/test_principal_namespace_flow.py diff --git a/server/src/agent_control_server/auth_framework/__init__.py b/server/src/agent_control_server/auth_framework/__init__.py index 57368d57..0333f2cc 100644 --- a/server/src/agent_control_server/auth_framework/__init__.py +++ b/server/src/agent_control_server/auth_framework/__init__.py @@ -2,10 +2,9 @@ Endpoints declare an :class:`Operation` they need; an installed :class:`RequestAuthorizer` decides whether the request is allowed and -returns the resulting :class:`Principal`. Two providers ship in-tree: -:class:`HeaderAuthProvider` (uses local credential checks) and -:class:`HttpUpstreamAuthProvider` (delegates to a configurable -upstream HTTP service). +returns the resulting :class:`Principal`. Providers ship in-tree for +disabled auth, local credential checks, upstream HTTP authorization, +and local runtime-JWT verification. """ from .core import ( diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 92107b0e..c8f428dc 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -8,15 +8,19 @@ - **Default flow** (everything except runtime). One authorizer handles every operation that does not have a specific override: - :class:`HeaderAuthProvider` (local credentials) or + :class:`NoAuthProvider` (no credentials), + :class:`HeaderAuthProvider` (local API keys), or :class:`HttpUpstreamAuthProvider` (forwards to a configurable URL). -- **Runtime flow.** When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is - configured, :class:`LocalJwtVerifyProvider` is registered as the - override for :data:`Operation.RUNTIME_USE`; the - ``runtime.token_exchange`` operation continues to flow through the - default authorizer because the exchange itself is shaped like a - management call (forward credential, get grant). Without the secret, - no runtime override is installed. +- **Runtime flow.** ``AGENT_CONTROL_RUNTIME_AUTH_MODE`` selects the + override for :data:`Operation.RUNTIME_USE`: ``none`` uses + :class:`NoAuthProvider`, ``api_key`` uses + :class:`HeaderAuthProvider`, and ``jwt`` uses + :class:`LocalJwtVerifyProvider`. When the mode is unset, startup + preserves historical behavior by selecting ``jwt`` if + ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + The ``runtime.token_exchange`` operation continues to flow through + the default authorizer because the exchange itself is shaped like a + management call (forward credential, get grant). """ from __future__ import annotations @@ -30,6 +34,7 @@ HeaderAuthProvider, HttpUpstreamAuthProvider, LocalJwtVerifyProvider, + NoAuthProvider, ) from .providers.http_upstream import HttpUpstreamConfig @@ -43,6 +48,7 @@ _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" # Runtime flow. +_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" _RUNTIME_TOKEN_SECRET_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_SECRET" _RUNTIME_TOKEN_TTL_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS" _DEFAULT_RUNTIME_TOKEN_TTL_SECONDS = 300 @@ -80,15 +86,19 @@ def configure_auth_from_env() -> None: Default flow: - - ``AGENT_CONTROL_AUTH_MODE=header`` (default): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. Runtime flow: - - When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, register - :class:`LocalJwtVerifyProvider` as an override for - :data:`Operation.RUNTIME_USE`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime + token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token + secret is configured): :class:`LocalJwtVerifyProvider`. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -101,27 +111,27 @@ def configure_auth_from_env() -> None: global _runtime_auth_config clear_authorizers() _active_providers.clear() - _runtime_auth_config = _load_runtime_auth_config() + runtime_mode = _resolve_runtime_mode() + _runtime_auth_config = ( + _load_runtime_auth_config(require_secret=True) if runtime_mode == "jwt" else None + ) default = _build_default_provider() set_authorizer(default) _active_providers.append(default) - if _runtime_auth_config is not None: - runtime_provider = LocalJwtVerifyProvider(secret=_runtime_auth_config.secret) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": _logger.info( - "Runtime auth enabled: LocalJwtVerifyProvider override installed for %s", + "Runtime auth provider: jwt override installed for %s", Operation.RUNTIME_USE.value, ) else: - _logger.warning( - "Runtime auth disabled (%s not set); %s falls through to the " - "default authorizer, which may grant any authenticated credential. " - "Set the runtime token secret to bind runtime calls to a " - "short-lived target-scoped JWT.", - _RUNTIME_TOKEN_SECRET_ENV, + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, Operation.RUNTIME_USE.value, ) @@ -172,9 +182,12 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "header").strip().lower() - if mode == "header": - _logger.info("Default auth provider: header (local credentials)") + mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + if mode in {"none", "no_auth"}: + _logger.info("Default auth provider: none") + return NoAuthProvider() + if mode in {"api_key", "header"}: + _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": url = os.environ.get(_UPSTREAM_URL_ENV) @@ -192,19 +205,60 @@ def _build_default_provider() -> RequestAuthorizer: service_token_header=token_header, ) ) - raise RuntimeError(f"Unknown {_MODE_ENV}={mode!r}; expected 'header' or 'http_upstream'.") + raise RuntimeError( + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + ) + + +def _resolve_runtime_mode() -> str: + raw = os.environ.get(_RUNTIME_MODE_ENV) + if raw is None or not raw.strip(): + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "jwt": + return mode + raise RuntimeError( + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + ) + + +def _build_runtime_provider( + mode: str, + config: RuntimeAuthConfig | None, +) -> RequestAuthorizer: + if mode == "none": + return NoAuthProvider() + if mode == "api_key": + return HeaderAuthProvider() + if mode == "jwt": + if config is None: + raise RuntimeError(f"{_RUNTIME_MODE_ENV}=jwt but runtime auth config is missing.") + return LocalJwtVerifyProvider(secret=config.secret) + raise RuntimeError( + f"Unknown runtime auth mode {mode!r}; expected 'none', 'api_key', or 'jwt'." + ) -def _load_runtime_auth_config() -> RuntimeAuthConfig | None: +def _load_runtime_auth_config(*, require_secret: bool = False) -> RuntimeAuthConfig | None: """Parse, validate, and return the runtime-auth config from env. - Returns ``None`` when no runtime secret is configured. Raises - ``RuntimeError`` when the secret is too short or the TTL is invalid - so misconfiguration surfaces at startup, not on the first - request-time mint. + Returns ``None`` when no runtime secret is configured and + ``require_secret`` is false. Raises ``RuntimeError`` when the + secret is required, too short, or the TTL is invalid so + misconfiguration surfaces at startup, not on the first request-time + mint. """ secret = os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) if not secret: + if require_secret: + raise RuntimeError( + f"{_RUNTIME_MODE_ENV}=jwt requires {_RUNTIME_TOKEN_SECRET_ENV} to be set." + ) return None if len(secret.encode("utf-8")) < _RUNTIME_TOKEN_SECRET_MIN_BYTES: raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 9299b441..e0ea6da7 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -42,14 +42,21 @@ class Operation(StrEnum): CONTROL_BINDINGS_READ = "control_bindings.read" CONTROL_BINDINGS_WRITE = "control_bindings.write" - # Runtime token exchange — wired on the exchange endpoint. + # Runtime token exchange - wired on the exchange endpoint. RUNTIME_TOKEN_EXCHANGE = "runtime.token_exchange" - # Reserved for follow-up migrations; not yet wired on endpoints. CONTROLS_READ = "controls.read" CONTROLS_CREATE = "controls.create" CONTROLS_UPDATE = "controls.update" CONTROLS_DELETE = "controls.delete" + POLICIES_READ = "policies.read" + POLICIES_CREATE = "policies.create" + POLICIES_UPDATE = "policies.update" + POLICIES_DELETE = "policies.delete" + AGENTS_READ = "agents.read" + AGENTS_CREATE = "agents.create" + AGENTS_UPDATE = "agents.update" + AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" @@ -61,8 +68,7 @@ class Principal: namespace_key: The namespace the request runs in. Endpoints use this to scope every read and write. is_admin: Whether the caller has admin privileges in the - current namespace. Mostly informational for endpoints that - still gate on the legacy admin-key contract. + current namespace. caller_id: Opaque, provider-supplied identifier for the caller (e.g., a key fingerprint or user id). Useful for audit logging; never echo back to clients. @@ -122,7 +128,7 @@ def set_authorizer( Without ``operation``, this becomes the default authorizer used by every operation that does not have a specific override. With - ``operation``, it overrides the default for that operation only — + ``operation``, it overrides the default for that operation only - used to route a different family (e.g., runtime) through a different provider. diff --git a/server/src/agent_control_server/auth_framework/providers/__init__.py b/server/src/agent_control_server/auth_framework/providers/__init__.py index e8a68486..ad5d6b38 100644 --- a/server/src/agent_control_server/auth_framework/providers/__init__.py +++ b/server/src/agent_control_server/auth_framework/providers/__init__.py @@ -3,10 +3,12 @@ from .header import AccessLevel, HeaderAuthProvider from .http_upstream import HttpUpstreamAuthProvider from .local_jwt import LocalJwtVerifyProvider +from .no_auth import NoAuthProvider __all__ = [ "AccessLevel", "HeaderAuthProvider", "HttpUpstreamAuthProvider", "LocalJwtVerifyProvider", + "NoAuthProvider", ] diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index f76936a1..228ec443 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -1,23 +1,14 @@ """Default :class:`RequestAuthorizer` that uses local credentials only. -Resolves the namespace from a header (or falls back to -``DEFAULT_NAMESPACE_KEY``) and enforces a per-operation access level -using the legacy API-key + session-cookie credential check from -:mod:`agent_control_server.auth`. Behavior matches the pre-framework -local auth path verbatim: +Returns ``DEFAULT_NAMESPACE_KEY`` and enforces a per-operation access +level using the local API-key + session-cookie credential check from +:mod:`agent_control_server.auth`: - ``ADMIN`` operations require an admin key (or admin session). - ``AUTHENTICATED`` operations require any valid credential. - ``PUBLIC`` operations are open. -- When ``api_key_enabled`` is ``False`` (no-auth mode), every - operation succeeds with a non-admin :class:`Principal` — preserved - by the underlying credential check. - -The header lookup is wired but currently inert: the provider always -returns the default namespace because non-binding write endpoints -still hardcode it. The header is kept here so a follow-up that -threads namespace resolution through the rest of the API can flip it -on without changing the provider contract. +- When the underlying local credential layer is disabled, every + operation succeeds with a non-admin :class:`Principal`. """ from __future__ import annotations @@ -51,6 +42,14 @@ class AccessLevel(Enum): Operation.CONTROLS_CREATE: AccessLevel.ADMIN, Operation.CONTROLS_UPDATE: AccessLevel.ADMIN, Operation.CONTROLS_DELETE: AccessLevel.ADMIN, + Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, + Operation.POLICIES_CREATE: AccessLevel.ADMIN, + Operation.POLICIES_UPDATE: AccessLevel.ADMIN, + Operation.POLICIES_DELETE: AccessLevel.ADMIN, + Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, + Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, + Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } @@ -60,7 +59,7 @@ class HeaderAuthProvider(RequestAuthorizer): """Default authorizer. For each operation's configured access level, validates the - request's credentials via the legacy local check; on success, + request's credentials via the local credential check; on success, returns a :class:`Principal` scoped to the resolved namespace. """ @@ -100,8 +99,7 @@ async def authorize( ) # Runtime token exchange returns a normalized scope grant so the # exchange endpoint can require ``runtime.use`` uniformly across - # providers; an upstream that explicitly grants no scopes ends - # up with an empty tuple and is rejected. + # providers. scopes: tuple[str, ...] = ( (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () ) @@ -113,10 +111,7 @@ async def authorize( ) def _resolve_namespace_key(self, request: Request) -> str: - # The provider always returns the default namespace because - # non-binding write endpoints still hardcode it; serving - # anything else here would create rows the rest of the API - # cannot find. The branch is preserved so a future change can - # lift the lock without touching the provider contract. + # Local credentials do not carry namespace metadata. Providers + # that resolve a namespace can return a different principal. del request return self._default_namespace_key diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index a97a3de8..8d5c850c 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -67,8 +67,8 @@ class _UpstreamGrant(BaseModel): """Strict schema for the upstream authorization-service response. Unknown fields are tolerated (so the upstream can evolve), but every - *known* field is type-checked. A wrong type on any field — or a - half-supplied target binding — causes the provider to fail closed + *known* field is type-checked. A wrong type on any field - or a + half-supplied target binding - causes the provider to fail closed with a 502. """ @@ -108,7 +108,7 @@ def _target_must_be_paired(self) -> _UpstreamGrant: A target is meaningful only as a ``(target_type, target_id)`` pair; allowing one side without the other would let a malformed grant pass and the exchange endpoint mint a token for the - request's value of the missing half — outside the upstream's + request's value of the missing half - outside the upstream's intended authorization. """ if (self.target_type is None) != (self.target_id is None): diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index bb448503..8620d3b6 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -6,7 +6,7 @@ returns a :class:`Principal` carrying the bound target. When a ``context_builder`` on the dependency surfaces ``target_type`` / ``target_id``, the provider also enforces that they match the token's -binding — runtime endpoints get the request-target check for free. +binding - runtime endpoints get the request-target check for free. """ from __future__ import annotations diff --git a/server/src/agent_control_server/auth_framework/providers/no_auth.py b/server/src/agent_control_server/auth_framework/providers/no_auth.py new file mode 100644 index 00000000..509ca4f3 --- /dev/null +++ b/server/src/agent_control_server/auth_framework/providers/no_auth.py @@ -0,0 +1,29 @@ +"""Authorizer for deployments that intentionally disable authentication.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Request + +from ...models import DEFAULT_NAMESPACE_KEY +from ..core import Operation, Principal, RequestAuthorizer + + +class NoAuthProvider(RequestAuthorizer): + """Allows every operation and returns the default namespace.""" + + def __init__(self, *, default_namespace_key: str = DEFAULT_NAMESPACE_KEY) -> None: + self._default_namespace_key = default_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + scopes: tuple[str, ...] = ( + (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () + ) + return Principal(namespace_key=self._default_namespace_key, scopes=scopes) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 034ae35f..ac099911 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey, require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -53,7 +53,6 @@ Policy, agent_policies, ) -from ..namespace import get_namespace_key from ..services.agent_names import normalize_agent_name_or_422 from ..services.controls import ( AgentControlEnabledState, @@ -112,7 +111,7 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - # Skip unrendered template controls — they have no evaluators to validate. + # Skip unrendered template controls - they have no evaluators to validate. if ( isinstance(control.data, dict) and control.data.get("template") is not None @@ -286,7 +285,7 @@ async def list_agents( limit: int = _DEFAULT_PAGINATION_LIMIT, name: str | None = None, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListAgentsResponse: """ List all registered agents with cursor-based pagination. @@ -300,11 +299,13 @@ async def list_agents( limit: Pagination limit (default 20, max 100) name: Optional name filter (case-insensitive partial match) db: Database session (injected) - namespace_key: Resolved namespace for the request + principal: Authorized request principal Returns: ListAgentsResponse with agent summaries and pagination info """ + namespace_key = principal.namespace_key + # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -377,14 +378,20 @@ async def list_agents( agent_policies.c.agent_name, agent_policies.c.policy_id, ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) .order_by(agent_policies.c.agent_name, agent_policies.c.policy_id) ) policy_ids_result = await db.execute(policy_ids_query) for assoc_agent_name, policy_id in policy_ids_result.all(): policy_ids_map.setdefault(assoc_agent_name, []).append(policy_id) - control_counts_map = await control_service.list_active_control_counts_by_agent(agent_names) + control_counts_map = await control_service.list_active_control_counts_by_agent( + agent_names, + namespace_key=namespace_key, + ) # Build summaries summaries: list[AgentSummary] = [] @@ -436,9 +443,8 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -462,10 +468,13 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) + principal: Authorized request principal Returns: InitAgentResponse with created flag and the effective controls """ + namespace_key = principal.namespace_key + # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() for ev in request.evaluators: @@ -835,7 +844,7 @@ async def init_agent( async def get_agent( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentResponse: """ Retrieve agent metadata and all registered steps. @@ -845,8 +854,7 @@ async def get_agent( Args: agent_name: Agent identifier db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: GetAgentResponse with agent metadata and step list @@ -855,6 +863,7 @@ async def get_agent( HTTPException 404: Agent not found HTTPException 422: Agent data is corrupted """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -917,7 +926,7 @@ async def _get_agent_or_404( The lookup is always namespace-scoped: an agent that exists only in another namespace surfaces as 404 (non-disclosing) so duplicate - names across namespaces — which the schema explicitly permits — + names across namespaces - which the schema explicitly permits - cannot be addressed across the namespace boundary. """ normalized_agent_name = normalize_agent_name_or_422(agent_name) @@ -940,7 +949,6 @@ async def _get_agent_or_404( @router.post( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate policy with agent", response_description="Success confirmation", @@ -949,9 +957,10 @@ async def add_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a policy with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1017,7 +1026,6 @@ async def add_agent_policy( @router.post( "/{agent_name}/policy/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=SetPolicyResponse, summary="Assign policy to agent (compatibility)", response_description="Success status with previous policy ID", @@ -1026,9 +1034,10 @@ async def set_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> SetPolicyResponse: """Compatibility endpoint that replaces all policy associations with one policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1117,9 +1126,10 @@ async def set_agent_policy( async def get_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentPoliciesResponse: """List policy IDs associated with an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) result = await db.execute( select(agent_policies.c.policy_id) @@ -1141,9 +1151,10 @@ async def get_agent_policies( async def get_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetPolicyResponse: """Compatibility endpoint that returns the first associated policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( select(Policy.id) @@ -1172,7 +1183,6 @@ async def get_agent_policy( @router.delete( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove policy association from agent", response_description="Success confirmation", @@ -1181,13 +1191,14 @@ async def remove_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove a policy association from an agent. Idempotent for existing resources: removing a non-associated link is a no-op. Missing agent/policy resources still return 404. """ + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1230,7 +1241,6 @@ async def remove_agent_policy( @router.delete( "/{agent_name}/policies", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove all policy associations from agent", response_description="Success confirmation", @@ -1238,9 +1248,10 @@ async def remove_agent_policy( async def remove_all_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove all policy associations from an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) try: @@ -1271,7 +1282,6 @@ async def remove_all_agent_policies( @router.delete( "/{agent_name}/policy", - dependencies=[Depends(require_admin_key)], response_model=DeletePolicyResponse, summary="Remove agent's policy assignment (compatibility)", response_description="Success confirmation", @@ -1279,9 +1289,10 @@ async def remove_all_agent_policies( async def delete_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> DeletePolicyResponse: """Compatibility endpoint that removes all policy associations.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) existing_policy_result = await db.execute( @@ -1328,7 +1339,6 @@ async def delete_agent_policy( @router.post( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate control directly with agent", response_description="Success confirmation", @@ -1337,9 +1347,10 @@ async def add_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a control directly with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) control = await control_service.get_active_control_or_404( @@ -1389,7 +1400,6 @@ async def add_agent_control( @router.delete( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=RemoveAgentControlResponse, summary="Remove direct control association from agent", response_description="Success confirmation", @@ -1398,9 +1408,10 @@ async def remove_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> RemoveAgentControlResponse: """Remove a direct control association from an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) await control_service.get_active_control_or_404(control_id, namespace_key=namespace_key) @@ -1481,7 +1492,7 @@ async def list_agent_controls( description="Optional opaque target identifier. Required when target_type is supplied.", ), db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1506,7 +1517,7 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - namespace_key: Namespace scoping for the resolution (injected) + principal: Authorized request principal Returns: AgentControlsResponse with controls matching the requested state filters @@ -1515,6 +1526,8 @@ async def list_agent_controls( HTTPException 400: target_type and target_id were not supplied together HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key + if (target_type is None) != (target_id is None): raise BadRequestError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1572,7 +1585,7 @@ async def list_agent_evaluators( cursor: str | None = None, limit: int = _DEFAULT_PAGINATION_LIMIT, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListEvaluatorsResponse: """ List all evaluator schemas registered with an agent. @@ -1586,8 +1599,7 @@ async def list_agent_evaluators( cursor: Optional cursor for pagination (name of last evaluator from previous page) limit: Pagination limit (default 20, max 100) db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: ListEvaluatorsResponse with evaluator schemas and pagination @@ -1595,6 +1607,7 @@ async def list_agent_evaluators( Raises: HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -1672,7 +1685,7 @@ async def get_agent_evaluator( agent_name: str, evaluator_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> EvaluatorSchemaItem: """ Get a specific evaluator schema registered with an agent. @@ -1681,8 +1694,7 @@ async def get_agent_evaluator( agent_name: Agent identifier evaluator_name: Name of the evaluator db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: EvaluatorSchemaItem with schema details @@ -1690,6 +1702,7 @@ async def get_agent_evaluator( Raises: HTTPException 404: Agent or evaluator not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -1734,7 +1747,6 @@ async def get_agent_evaluator( @router.patch( "/{agent_name}", - dependencies=[Depends(require_admin_key)], response_model=PatchAgentResponse, summary="Modify agent (remove steps/evaluators)", response_description="Lists of removed items", @@ -1743,7 +1755,7 @@ async def patch_agent( agent_name: str, request: PatchAgentRequest, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> PatchAgentResponse: """ Remove steps and/or evaluators from an agent. @@ -1755,6 +1767,7 @@ async def patch_agent( agent_name: Agent identifier request: Lists of step/evaluator identifiers to remove db: Database session (injected) + principal: Authorized request principal Returns: PatchAgentResponse with lists of actually removed items @@ -1763,6 +1776,7 @@ async def patch_agent( HTTPException 404: Agent not found HTTPException 500: Database error during update """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 1a23baa8..f80cd2fa 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,9 +2,8 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer (typically -:class:`HttpUpstreamAuthProvider` in production) authenticates the -credential and authorizes the implied +target_id)``; the default authorizer authenticates the credential and +authorizes the implied :data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the @@ -130,8 +129,8 @@ async def runtime_token_exchange( actor_id = principal.caller_id or "anonymous" # The exchange endpoint requires the authorizer to explicitly grant - # runtime.use. Providers that do not surface scopes (legacy local - # provider) supply a normalized grant for ``RUNTIME_TOKEN_EXCHANGE``; + # runtime.use. Local providers supply a normalized grant for + # ``RUNTIME_TOKEN_EXCHANGE``; # upstream providers that return an explicit empty scopes array fail # closed here rather than escalating to runtime.use. if Operation.RUNTIME_USE.value not in principal.scopes: @@ -155,7 +154,7 @@ async def runtime_token_exchange( ) except UpstreamGrantExpiredError as exc: # Upstream returned a grant whose ``expires_at`` is already in - # the past — minting would hand the caller a token that's dead + # the past - minting would hand the caller a token that's dead # on arrival. Distinguished from the misconfigured case so the # error code and status reflect "upstream returned bad data." raise APIError( diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 92798ae1..d2fe4b44 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -26,7 +26,6 @@ from ..db import get_async_db from ..errors import BadRequestError from ..models import ControlBinding -from ..namespace import get_namespace_key from ..services.control_bindings import ControlBindingsService router = APIRouter(prefix="/control-bindings", tags=["control-bindings"]) @@ -94,26 +93,21 @@ def _to_response(binding: ControlBinding) -> GetControlBindingResponse: async def create_control_binding( request: CreateControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. Each binding row is scoped to the request namespace as resolved by - ``get_namespace_key``. The auth chain still runs via - ``require_operation`` for authentication and authorization, but the - storage namespace is taken from the same resolver the rest of the - server uses so binding writes and runtime reads stay in lockstep - until auth-derived namespace resolution lands across every endpoint. + the active authorizer. """ service = ControlBindingsService(db) binding = await service.create_binding( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -148,20 +142,18 @@ async def list_control_bindings( target_id: str | None = None, control_id: int | None = None, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_READ, context_builder=_binding_list_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> ListControlBindingsResponse: """Return bindings in the request namespace with optional filters and cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by ``get_namespace_key`` so this - listing stays in lockstep with the rest of the server's reads. + storage namespace is resolved by the active authorizer. """ parsed_cursor: int | None if cursor is None: @@ -177,7 +169,7 @@ async def list_control_bindings( ) from exc service = ControlBindingsService(db) page = await service.list_bindings( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, cursor=parsed_cursor, limit=limit, target_type=target_type, @@ -204,8 +196,7 @@ async def list_control_bindings( async def get_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), ) -> GetControlBindingResponse: """Read a single control binding by surrogate ID. @@ -218,7 +209,9 @@ async def get_control_binding( of which forward ``(target_type, target_id)`` to the authorizer. """ service = ControlBindingsService(db) - binding = await service.get_binding_or_404(namespace_key=namespace_key, binding_id=binding_id) + binding = await service.get_binding_or_404( + namespace_key=principal.namespace_key, binding_id=binding_id + ) return _to_response(binding) @@ -232,8 +225,7 @@ async def patch_control_binding( binding_id: int, request: PatchControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> PatchControlBindingResponse: """Update the ``enabled`` flag on a control binding. @@ -244,7 +236,7 @@ async def patch_control_binding( """ service = ControlBindingsService(db) binding = await service.set_enabled( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, binding_id=binding_id, enabled=request.enabled, ) @@ -261,8 +253,7 @@ async def patch_control_binding( async def delete_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> DeleteControlBindingResponse: """Delete a control binding by surrogate ID. @@ -272,7 +263,7 @@ async def delete_control_binding( target-scoped detach that forwards the target to the authorizer. """ service = ControlBindingsService(db) - await service.delete_binding(namespace_key=namespace_key, binding_id=binding_id) + await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) await db.commit() return DeleteControlBindingResponse(success=True) @@ -286,13 +277,12 @@ async def delete_control_binding( async def upsert_control_binding_by_key( request: UpsertControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> UpsertControlBindingResponse: """Idempotent attach using ``(target_type, target_id, control_id)`` as the natural key. Updates ``enabled`` on an existing match; creates a new row @@ -300,7 +290,7 @@ async def upsert_control_binding_by_key( """ service = ControlBindingsService(db) binding, created = await service.upsert_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -324,20 +314,19 @@ async def upsert_control_binding_by_key( async def delete_control_binding_by_key( request: DeleteControlBindingByKeyRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> DeleteControlBindingByKeyResponse: """Idempotent detach by natural key. Returns ``deleted=False`` when no matching binding exists. """ service = ControlBindingsService(db) deleted = await service.delete_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index fcb7cb18..5b01593c 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -229,7 +229,7 @@ async def _materialize_control_input( enabled=enabled, ) - # Incomplete values — only allowed for new controls or already-unrendered + # Incomplete values - only allowed for new controls or already-unrendered # templates. Updating a rendered control with incomplete values is # rejected to prevent silently stripping rendered fields. current_is_rendered = ( @@ -470,7 +470,7 @@ async def render_control_template( async def create_control( request: CreateControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -492,7 +492,10 @@ async def create_control( control_service = ControlService(db) # Uniqueness check - if await control_service.active_control_name_exists(request.name): + namespace_key = principal.namespace_key + if await control_service.active_control_name_exists( + request.name, namespace_key=namespace_key + ): raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, detail=f"Control with name '{request.name}' already exists", @@ -504,7 +507,11 @@ async def create_control( control_def = await _materialize_control_input(request.data, db=db) control_data = _serialize_control_data(control_def) - control = control_service.create_control(name=request.name, data=control_data) + control = control_service.create_control( + namespace_key=namespace_key, + name=request.name, + data=control_data, + ) try: await control_service.create_version( control, @@ -569,7 +576,7 @@ async def get_control_schema() -> GetControlSchemaResponse: async def get_control( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -584,7 +591,9 @@ async def get_control( Raises: HTTPException 404: Control not found """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -608,7 +617,7 @@ async def get_control( async def get_control_data( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -626,7 +635,9 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -648,10 +659,15 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" - page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) + page = await ControlService(db).list_versions( + control_id, + namespace_key=principal.namespace_key, + cursor=cursor, + limit=limit, + ) return ListControlVersionsResponse( versions=[ @@ -682,10 +698,12 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" - version = await ControlService(db).get_version_or_404(control_id, version_num) + version = await ControlService(db).get_version_or_404( + control_id, version_num, namespace_key=principal.namespace_key + ) return GetControlVersionResponse( version_num=version.version_num, event_type=version.event_type, @@ -705,7 +723,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -726,7 +744,9 @@ async def set_control_data( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=principal.namespace_key, for_update=True + ) control_def = await _materialize_control_input( request.data, @@ -767,11 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Validation uses the authoring path, so require create access. +# Authorized as CONTROLS_READ: validate exercises the materialization +# path but does not mutate stored control data. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -811,7 +832,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -837,7 +858,9 @@ async def list_controls( GET /controls?limit=10&enabled=true&step_type=tool """ control_service = ControlService(db) + namespace_key = principal.namespace_key page = await control_service.list_controls_page( + namespace_key=namespace_key, cursor=cursor, limit=limit, name=name, @@ -849,7 +872,8 @@ async def list_controls( tag=tag, ) usage_by_control_id = await control_service.list_control_usage( - [control.id for control in page.controls] + [control.id for control in page.controls], + namespace_key=namespace_key, ) # Build summaries (filtering already done at DB level) @@ -910,7 +934,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -933,13 +957,18 @@ async def delete_control( """ control_service = ControlService(db) bindings_service = ControlBindingsService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) - associations = await control_service.list_control_associations(control_id) + associations = await control_service.list_control_associations( + control_id, namespace_key=namespace_key + ) associated_policy_ids = associations.policy_ids associated_agent_names = associations.agent_names target_binding_ids = await bindings_service.list_binding_ids_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if ( @@ -996,13 +1025,15 @@ async def delete_control( dissociated_from_policies: list[int] = [] dissociated_from_agents: list[str] = [] if associated_policy_ids or associated_agent_names: - dissociated = await control_service.remove_all_control_associations(control_id) + dissociated = await control_service.remove_all_control_associations( + control_id, namespace_key=namespace_key + ) dissociated_from_policies = dissociated.policy_ids dissociated_from_agents = dissociated.agent_names detached_target_bindings: list[int] = [] if target_binding_ids: detached_target_bindings = await bindings_service.delete_bindings_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if dissociated_from_policies or dissociated_from_agents or detached_target_bindings: _logger.info( @@ -1057,7 +1088,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). @@ -1081,7 +1112,10 @@ async def patch_control( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) parsed_control = _parse_stored_control_data( control.data, control_name=control.name, @@ -1096,6 +1130,7 @@ async def patch_control( # Check for name collision if await control_service.active_control_name_exists( request.name, + namespace_key=namespace_key, exclude_control_id=control_id, ): raise ConflictError( diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index e018796e..437af8b5 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -10,16 +10,15 @@ EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..namespace import get_namespace_key from ..services.controls import ControlService router = APIRouter(prefix="/evaluation", tags=["evaluation"]) @@ -118,6 +117,20 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) +async def _evaluation_context(request: Request) -> dict[str, object]: + """Surface target identifiers to the runtime authorizer.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return {} + if not isinstance(body, dict): + return {} + return { + "target_type": body.get("target_type"), + "target_id": body.get("target_id"), + } + + @router.post( "", response_model=EvaluationResponse, @@ -126,9 +139,10 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) async def evaluate( request: EvaluationRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends( + require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) + ), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -144,7 +158,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - del client # Authentication is still required by dependency injection. + namespace_key = principal.namespace_key agent_result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..ddda7127 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -9,7 +9,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ConflictError, DatabaseError, NotFoundError from ..logging_utils import get_logger @@ -23,13 +23,14 @@ @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreatePolicyResponse, summary="Create a new policy", response_description="Created policy ID", ) async def create_policy( - request: CreatePolicyRequest, db: AsyncSession = Depends(get_async_db) + request: CreatePolicyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_CREATE)), ) -> CreatePolicyResponse: """ Create a new empty policy with a unique name. @@ -48,8 +49,14 @@ async def create_policy( HTTPException 409: Policy with this name already exists HTTPException 500: Database error during creation """ + namespace_key = principal.namespace_key # Uniqueness check - existing = await db.execute(select(Policy.id).where(Policy.name == request.name)) + existing = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.name == request.name, + ) + ) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.POLICY_NAME_CONFLICT, @@ -59,7 +66,7 @@ async def create_policy( hint="Choose a different name or update the existing policy.", ) - policy = Policy(name=request.name) + policy = Policy(namespace_key=namespace_key, name=request.name) db.add(policy) try: await db.commit() @@ -80,13 +87,15 @@ async def create_policy( @router.post( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Add control to policy", response_description="Success confirmation", ) async def add_control_to_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Associate a control with a policy. @@ -106,8 +115,14 @@ async def add_control_to_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ + namespace_key = principal.namespace_key # Find policy and control - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -119,11 +134,17 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: - await control_service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await control_service.add_control_to_policy( + policy_id=policy_id, + control_id=control_id, + namespace_key=namespace_key, + ) await db.commit() except Exception: await db.rollback() @@ -149,13 +170,15 @@ async def add_control_to_policy( @router.delete( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove control from policy", response_description="Success confirmation", ) async def remove_control_from_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Remove a control from a policy. @@ -175,7 +198,13 @@ async def remove_control_from_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -187,13 +216,16 @@ async def remove_control_from_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Remove association (idempotent - deleting non-existent is no-op) try: await control_service.remove_control_from_policy( policy_id=policy_id, control_id=control_id, + namespace_key=namespace_key, ) await db.commit() except Exception: @@ -222,7 +254,9 @@ async def remove_control_from_policy( response_description="List of control IDs", ) async def list_policy_controls( - policy_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_READ)), ) -> GetPolicyControlsResponse: """ List all controls associated with a policy. @@ -237,7 +271,13 @@ async def list_policy_controls( Raises: HTTPException 404: Policy not found """ - pol_res = await db.execute(select(Policy.id).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) if pol_res.first() is None: raise NotFoundError( error_code=ErrorCode.POLICY_NOT_FOUND, @@ -247,5 +287,8 @@ async def list_policy_controls( hint="Verify the policy ID is correct and the policy has been created.", ) - control_ids = await ControlService(db).list_policy_control_ids(policy_id) + control_ids = await ControlService(db).list_policy_control_ids( + policy_id, + namespace_key=namespace_key, + ) return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index bc1bf04b..a1561e63 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -252,7 +252,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # Register handler for FastAPI's RequestValidationError (Pydantic validation) app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] -# Register handler for standard HTTPException (legacy code, FastAPI internals) +# Register handler for standard HTTPException (older routes, FastAPI internals) app.add_exception_handler(HTTPException, http_exception_handler) # type: ignore[arg-type] # Register catch-all handler for unexpected exceptions @@ -261,16 +261,18 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # API v1 prefix for all routes api_v1_prefix = f"{settings.api_prefix}/{settings.api_version}" -# Protected routes (require valid API key) +# API routers. Routers migrated to the auth framework mount the +# non-validating header extractor only so OpenAPI advertises X-API-Key; +# each endpoint's ``require_operation`` dependency owns authn + authz. app.include_router( agent_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( policy_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # Endpoint dependencies handle auth; this advertises X-API-Key. @@ -281,11 +283,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( # The auth framework on each endpoint owns authentication and # authorization for control bindings, so this router is mounted - # without the legacy router-level gate. See ``auth_framework`` for + # without the router-level auth gate. See ``auth_framework`` for # the provider contract. ``get_api_key_from_header`` is a non- # validating extractor (``auto_error=False``); it is attached purely # so the generated OpenAPI spec advertises the X-API-Key requirement - # on these routes — without it, downstream SDK generators would treat + # on these routes - without it, downstream SDK generators would treat # the routes as unauthenticated. control_binding_router, prefix=api_v1_prefix, @@ -309,9 +311,10 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( evaluation_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) +# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, @@ -324,7 +327,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, ) -# System routes (config, login, logout) — no auth required +# System routes (config, login, logout) - no auth required app.include_router( system_router, prefix=settings.api_prefix, diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 263120b7..41a62282 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,6 +20,7 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( + DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -96,9 +97,15 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: + def create_control( + self, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + name: str, + data: dict[str, Any], + ) -> Control: """Create a new pending control row.""" - control = Control(name=name, data=data) + control = Control(namespace_key=namespace_key, name=name, data=data) self._db.add(control) return control @@ -128,10 +135,13 @@ async def get_control_or_404( self, control_id: int, *, + namespace_key: str | None = None, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" stmt = select(Control).where(Control.id == control_id) + if namespace_key is not None: + stmt = stmt.where(Control.namespace_key == namespace_key) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -180,10 +190,15 @@ async def active_control_name_exists( self, name: str, *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" - stmt = select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + stmt = select(Control.id).where( + Control.namespace_key == namespace_key, + Control.name == name, + Control.deleted_at.is_(None), + ) if exclude_control_id is not None: stmt = stmt.where(Control.id != exclude_control_id) result = await self._db.execute(stmt) @@ -216,11 +231,12 @@ async def list_versions( self, control_id: int, *, + namespace_key: str, cursor: int | None, limit: int, ) -> ControlVersionPage: """Return control versions newest-first with cursor pagination.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) total_result = await self._db.execute( select(func.count()) @@ -255,9 +271,11 @@ async def list_versions( next_cursor=next_cursor, ) - async def get_version_or_404(self, control_id: int, version_num: int) -> ControlVersion: + async def get_version_or_404( + self, control_id: int, version_num: int, *, namespace_key: str + ) -> ControlVersion: """Load a specific version row for a control.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) result = await self._db.execute( select(ControlVersion).where( @@ -303,12 +321,17 @@ async def list_controls_for_policy( result = await self._db.execute(stmt) return list(result.scalars().unique().all()) - async def list_policy_control_ids(self, policy_id: int) -> list[int]: + async def list_policy_control_ids(self, policy_id: int, *, namespace_key: str) -> list[int]: """Return active control IDs directly associated with a policy.""" result = await self._db.execute( select(policy_controls.c.control_id) .join(Control, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) + .where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.policy_id == policy_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) .order_by(policy_controls.c.control_id) ) return [cast(int, row[0]) for row in result.all()] @@ -396,6 +419,7 @@ async def list_runtime_controls_for_agent( async def list_controls_page( self, *, + namespace_key: str, cursor: int | None, limit: int, name: str | None, @@ -407,7 +431,11 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = ( + select(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + .order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -424,7 +452,11 @@ async def list_controls_page( result = await self._db.execute(query.limit(limit + 1)) controls = list(result.scalars().all()) - total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) + total_query = ( + select(func.count()) + .select_from(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + ) total_query = self._apply_control_list_filters( total_query, name=name, @@ -453,7 +485,9 @@ async def list_controls_page( next_cursor=next_cursor, ) - async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: + async def list_control_usage( + self, control_ids: Sequence[int], *, namespace_key: str + ) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: return {} @@ -465,8 +499,16 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_policies.c.agent_name, ) .select_from(policy_controls) - .join(agent_policies, policy_controls.c.policy_id == agent_policies.c.policy_id) - .where(policy_controls.c.control_id.in_(control_ids)) + .join( + agent_policies, + (policy_controls.c.policy_id == agent_policies.c.policy_id) + & (policy_controls.c.namespace_key == agent_policies.c.namespace_key), + ) + .where( + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(control_ids), + ) ) direct_agents_query = ( select( @@ -474,7 +516,10 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_controls.c.agent_name, ) .select_from(agent_controls) - .where(agent_controls.c.control_id.in_(control_ids)) + .where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(control_ids), + ) ) agents_result = await self._db.execute(union_all(policy_agents_query, direct_agents_query)) for control_id, agent_name in agents_result.all(): @@ -491,6 +536,8 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: @@ -503,15 +550,24 @@ async def list_active_control_counts_by_agent( ) .select_from( agent_policies.join( - policy_controls, agent_policies.c.policy_id == policy_controls.c.policy_id + policy_controls, + (agent_policies.c.policy_id == policy_controls.c.policy_id) + & (agent_policies.c.namespace_key == policy_controls.c.namespace_key), ) ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) ) direct_associations = select( agent_controls.c.agent_name.label("agent_name"), agent_controls.c.control_id.label("control_id"), - ).where(agent_controls.c.agent_name.in_(agent_names)) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.agent_name.in_(agent_names), + ) all_associations = union_all(policy_associations, direct_associations).subquery() result = await self._db.execute( @@ -521,6 +577,7 @@ async def list_active_control_counts_by_agent( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.namespace_key == namespace_key, Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", @@ -531,19 +588,28 @@ async def list_active_control_counts_by_agent( ) return {cast(str, row[0]): cast(int, row[1]) for row in result.all()} - async def add_control_to_policy(self, *, policy_id: int, control_id: int) -> None: + async def add_control_to_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Create a policy-control association if it does not already exist.""" await self._db.execute( pg_insert(policy_controls) - .values(policy_id=policy_id, control_id=control_id) + .values( + namespace_key=namespace_key, + policy_id=policy_id, + control_id=control_id, + ) .on_conflict_do_nothing() ) - async def remove_control_from_policy(self, *, policy_id: int, control_id: int) -> None: + async def remove_control_from_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Remove a policy-control association if it exists.""" await self._db.execute( delete(policy_controls).where( - (policy_controls.c.policy_id == policy_id) + (policy_controls.c.namespace_key == namespace_key) + & (policy_controls.c.policy_id == policy_id) & (policy_controls.c.control_id == control_id) ) ) @@ -613,16 +679,24 @@ async def remove_control_from_agent( control_still_active=policy_inheritance_result.first() is not None, ) - async def list_control_associations(self, control_id: int) -> ControlAssociations: + async def list_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Return all policy and direct agent associations for a control.""" policy_assoc_query = select( policy_controls.c.policy_id.label("policy_id"), literal(None, type_=String).label("agent_name"), - ).where(policy_controls.c.control_id == control_id) + ).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) agent_assoc_query = select( literal(None, type_=Integer).label("policy_id"), agent_controls.c.agent_name.label("agent_name"), - ).where(agent_controls.c.control_id == control_id) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) assoc_result = await self._db.execute(union_all(policy_assoc_query, agent_assoc_query)) policy_ids: set[int] = set() @@ -638,16 +712,26 @@ async def list_control_associations(self, control_id: int) -> ControlAssociation agent_names=sorted(agent_names), ) - async def remove_all_control_associations(self, control_id: int) -> ControlAssociations: + async def remove_all_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Remove all policy and direct agent associations for a control.""" - associations = await self.list_control_associations(control_id) + associations = await self.list_control_associations( + control_id, namespace_key=namespace_key + ) if associations.policy_ids: await self._db.execute( - delete(policy_controls).where(policy_controls.c.control_id == control_id) + delete(policy_controls).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) ) if associations.agent_names: await self._db.execute( - delete(agent_controls).where(agent_controls.c.control_id == control_id) + delete(agent_controls).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) ) return associations diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 96c4aad8..2d39bfa3 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( DEFAULT_OPERATION_ACCESS, @@ -64,6 +65,35 @@ def test_default_operation_access_covers_every_operation(): assert not missing, f"Operations missing default access mapping: {missing}" +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_provider_allows_any_operation(): + provider = NoAuthProvider(default_namespace_key="ns-local") + + principal = await provider.authorize( + _build_request(), + Operation.CONTROLS_DELETE, + ) + + assert principal == Principal(namespace_key="ns-local") + + +@pytest.mark.asyncio +async def test_no_auth_provider_grants_runtime_exchange_scope(): + provider = NoAuthProvider() + + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + ) + + assert principal.scopes == (Operation.RUNTIME_USE.value,) + + # --------------------------------------------------------------------------- # HeaderAuthProvider # --------------------------------------------------------------------------- @@ -101,7 +131,7 @@ async def test_header_provider_public_returns_default_namespace(): @pytest.mark.asyncio -async def test_header_provider_authenticated_calls_legacy_validator(): +async def test_header_provider_authenticated_calls_local_validator(): provider = HeaderAuthProvider() expected_client = MagicMock(is_admin=False, key_id="abc12345") @@ -945,6 +975,70 @@ def test_runtime_ttl_loader_accepts_max(monkeypatch): ) +def test_build_default_provider_accepts_none_mode(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + assert auth_config._resolve_runtime_mode() == "api_key" + + +def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + assert auth_config._resolve_runtime_mode() == "jwt" + + +def test_configure_runtime_none_installs_no_auth_provider(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), HeaderAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_jwt_requires_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + with pytest.raises(RuntimeError, match="requires AGENT_CONTROL_RUNTIME_TOKEN_SECRET"): + auth_config.configure_auth_from_env() + + def test_configure_then_reconfigure_clears_runtime_override(monkeypatch): """Reconfiguring without a runtime secret must drop the override.""" from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index b4922b9d..dfbb15f5 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,19 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode +from agent_control_server.auth_framework import Principal from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -1106,7 +1106,12 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: request = SimpleNamespace(data=DummyData(payload)) # When: updating the control data with a non-Pydantic selector - response = await controls_module.set_control_data(control.id, request, async_db) + response = await controls_module.set_control_data( + control.id, + request, + async_db, + principal=Principal(namespace_key=DEFAULT_NAMESPACE_KEY), + ) # Then: the update succeeds and uses the original selector serialization assert response.success is True diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 1a2af21f..c0f17754 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,14 +4,13 @@ import uuid -import pytest from fastapi.testclient import TestClient -from agent_control_server.config import auth_settings +from agent_control_server.auth_framework import set_authorizer +from agent_control_server.auth_framework.providers import NoAuthProvider from .utils import VALID_CONTROL_PAYLOAD - _CONTROLS_URL = "/api/v1/controls" _TEMPLATES_URL = "/api/v1/control-templates" @@ -199,18 +198,19 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_cannot_validate_control_data( +def test_non_admin_can_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_CREATE``.""" + """``/controls/validate`` requires ``CONTROLS_READ``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is admin-only - assert resp.status_code == 403, resp.text + # Then: validation is allowed for authenticated non-admin callers + assert resp.status_code == 200, resp.text + assert resp.json()["success"] is True def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: @@ -283,21 +283,16 @@ def test_unauthenticated_cannot_render_template( # --------------------------------------------------------------------------- -# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# No-auth deployment mode: explicit provider bypasses every gate. # --------------------------------------------------------------------------- def test_no_auth_mode_allows_writes_without_credentials( unauthenticated_client: TestClient, - monkeypatch: pytest.MonkeyPatch, ) -> None: - """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` - short-circuits to a non-admin ``Principal`` for every operation, - including admin-level writes. This pins the "no auth" deployment - path so a future refactor can't silently start enforcing. - """ - # Given: api_key_enabled is False (single-tenant OSS dev mode) - monkeypatch.setattr(auth_settings, "api_key_enabled", False) + """Explicit no-auth provider allows every operation without credentials.""" + # Given: the request-auth framework is in no-auth mode + set_authorizer(NoAuthProvider()) # When: an unauthenticated client creates a control resp = unauthenticated_client.put( @@ -311,4 +306,3 @@ def test_no_auth_mode_allows_writes_without_credentials( # Then: the create succeeds because auth is disabled at the provider assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py new file mode 100644 index 00000000..40ecd216 --- /dev/null +++ b/server/tests/test_principal_namespace_flow.py @@ -0,0 +1,141 @@ +"""HTTP-level coverage for principal-derived namespace scoping.""" + +from __future__ import annotations + +import uuid +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from agent_control_server.auth_framework import ( + Operation, + Principal, + set_authorizer, +) + +from .utils import VALID_CONTROL_PAYLOAD + + +class HeaderNamespaceAuthorizer: + """Test authorizer that maps a request header to ``Principal.namespace_key``.""" + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del context + scopes = ( + (Operation.RUNTIME_USE.value,) + if operation is Operation.RUNTIME_TOKEN_EXCHANGE + else () + ) + return Principal( + namespace_key=request.headers.get("X-Test-Namespace", "default"), + is_admin=True, + scopes=scopes, + ) + + +def _client(app: FastAPI, namespace_key: str) -> TestClient: + return TestClient( + app, + raise_server_exceptions=True, + headers={"X-Test-Namespace": namespace_key}, + ) + + +def _agent_payload(agent_name: str) -> dict[str, Any]: + return { + "agent": { + "agent_name": agent_name, + "agent_description": "test agent", + "agent_version": "1.0", + }, + "steps": [], + } + + +def _evaluation_payload(agent_name: str) -> dict[str, Any]: + return { + "agent_name": agent_name, + "step": { + "type": "llm", + "name": "test-step", + "input": "x marks the spot", + "context": {}, + }, + "stage": "pre", + "target_type": "env", + "target_id": "prod", + } + + +def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_a = ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + register_b = ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + assert register_a.status_code == 200, register_a.text + assert register_b.status_code == 200, register_b.text + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + attach_to_policy = ns_a.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert attach_to_policy.status_code == 200, attach_to_policy.text + + binding = ns_a.put( + "/api/v1/control-bindings", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.get(f"/api/v1/controls/{control_id}").status_code == 404 + assert ns_b.get(f"/api/v1/policies/{policy_id}/controls").status_code == 404 + assert ns_b.get("/api/v1/control-bindings").json()["bindings"] == [] + + eval_a = ns_a.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_a.status_code == 200, eval_a.text + assert eval_a.json()["is_safe"] is False + assert eval_a.json()["matches"][0]["control_id"] == control_id + + eval_b = ns_b.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_b.status_code == 200, eval_b.text + assert eval_b.json()["is_safe"] is True + + +def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + control_name = f"control-{uuid.uuid4().hex[:12]}" + payload = {"name": control_name, "data": VALID_CONTROL_PAYLOAD} + + assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 + assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 295a85e2..62891ba5 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -232,9 +232,9 @@ def test_target_binding_de_duplicated_against_direct_attachment( async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) -> None: """Insert an Agent row directly so the test can simulate a foreign namespace. - The endpoint's ``get_namespace_key`` returns the default namespace; this - helper sidesteps the resolver to seed an agent that the request-time - code path should not be able to reach. + The default test authorizer returns the default namespace; this helper + sidesteps the authorizer to seed an agent that the request-time code + path should not be able to reach. """ from agent_control_server.models import Agent From 9251f40f7f6a2fef875a492ee68d8176fac7d9a5 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:20:00 +0530 Subject: [PATCH 02/12] chore(sdk-ts): regenerate client docs --- .../src/generated/funcs/agents-get-evaluator.ts | 3 +-- sdks/typescript/src/generated/funcs/agents-get.ts | 3 +-- .../typescript/src/generated/funcs/agents-init.ts | 1 + .../src/generated/funcs/agents-list-controls.ts | 2 +- .../src/generated/funcs/agents-list-evaluators.ts | 3 +-- .../typescript/src/generated/funcs/agents-list.ts | 2 +- .../src/generated/funcs/agents-update.ts | 1 + .../generated/funcs/control-bindings-create.ts | 6 +----- .../src/generated/funcs/control-bindings-list.ts | 3 +-- sdks/typescript/src/generated/sdk/agents.ts | 15 +++++++-------- .../src/generated/sdk/control-bindings.ts | 9 ++------- 11 files changed, 18 insertions(+), 30 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts index acb364eb..ceca1ec0 100644 --- a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts +++ b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts @@ -37,8 +37,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/funcs/agents-get.ts b/sdks/typescript/src/generated/funcs/agents-get.ts index 9724edbf..142f3062 100644 --- a/sdks/typescript/src/generated/funcs/agents-get.ts +++ b/sdks/typescript/src/generated/funcs/agents-get.ts @@ -38,8 +38,7 @@ import { Result } from "../types/fp.js"; * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 9d63358d..7150b2a4 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,6 +51,7 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index 661c5509..d1e5b27d 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,7 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts index c4d8a4b2..4217e752 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts @@ -42,8 +42,7 @@ import { Result } from "../types/fp.js"; * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination diff --git a/sdks/typescript/src/generated/funcs/agents-list.ts b/sdks/typescript/src/generated/funcs/agents-list.ts index fda7574d..f887d0b5 100644 --- a/sdks/typescript/src/generated/funcs/agents-list.ts +++ b/sdks/typescript/src/generated/funcs/agents-list.ts @@ -42,7 +42,7 @@ import { Result } from "../types/fp.js"; * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info diff --git a/sdks/typescript/src/generated/funcs/agents-update.ts b/sdks/typescript/src/generated/funcs/agents-update.ts index e82644cf..aff9d827 100644 --- a/sdks/typescript/src/generated/funcs/agents-update.ts +++ b/sdks/typescript/src/generated/funcs/agents-update.ts @@ -40,6 +40,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 8412487e..71dee5a0 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -33,11 +33,7 @@ import { Result } from "../types/fp.js"; * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5e7e87c3..5c90c7c2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,8 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index a22f4209..0a70e128 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -39,7 +39,7 @@ export class Agents extends ClientSDK { * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info @@ -80,6 +80,7 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls @@ -106,8 +107,7 @@ export class Agents extends ClientSDK { * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list @@ -140,6 +140,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items @@ -185,7 +186,7 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters @@ -256,8 +257,7 @@ export class Agents extends ClientSDK { * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination @@ -287,8 +287,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5101ce74..dc6f20d3 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,8 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ async list( request?: @@ -46,11 +45,7 @@ export class ControlBindings extends ClientSDK { * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ async create( request: models.CreateControlBindingRequest, From 3e0d2abd96b08d278376863f6e35e703b96a6766 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 23:07:04 +0530 Subject: [PATCH 03/12] fix(server): address runtime auth review feedback --- .../funcs/auth-runtime-token-exchange.ts | 9 ++--- .../funcs/control-bindings-create.ts | 4 +- .../funcs/control-bindings-delete.ts | 2 +- .../generated/funcs/control-bindings-get.ts | 5 +-- .../generated/funcs/control-bindings-list.ts | 2 +- .../funcs/control-bindings-update.ts | 2 +- sdks/typescript/src/generated/sdk/auth.ts | 9 ++--- .../src/generated/sdk/control-bindings.ts | 15 ++++--- .../auth_framework/core.py | 2 - .../auth_framework/providers/header.py | 2 - .../agent_control_server/endpoints/auth.py | 19 +++++---- .../endpoints/control_bindings.py | 15 ++++--- .../endpoints/controls.py | 6 +-- server/src/agent_control_server/namespace.py | 23 ----------- .../agent_control_server/services/controls.py | 23 ++++++----- server/tests/test_auth_framework.py | 24 +++++++++++ server/tests/test_controls_auth.py | 12 +++--- .../test_runtime_token_exchange_endpoint.py | 36 ++++++++++++++++- server/tests/test_services_controls.py | 40 ++++++++++++++----- 19 files changed, 146 insertions(+), 104 deletions(-) delete mode 100644 server/src/agent_control_server/namespace.py diff --git a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts index 176693e3..7e8679c8 100644 --- a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts +++ b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts @@ -32,11 +32,10 @@ import { Result } from "../types/fp.js"; * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 71dee5a0..faf99923 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -32,8 +32,8 @@ import { Result } from "../types/fp.js"; * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts index 9e4d1293..9872a9b4 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ export function controlBindingsDelete( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-get.ts b/sdks/typescript/src/generated/funcs/control-bindings-get.ts index dafb7c7c..88b4e419 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-get.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-get.ts @@ -34,12 +34,11 @@ import { Result } from "../types/fp.js"; * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ export function controlBindingsGet( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5c90c7c2..a87ca89f 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,7 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update.ts b/sdks/typescript/src/generated/funcs/control-bindings-update.ts index b3faf800..b94520a2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-update.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-update.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ export function controlBindingsUpdate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/auth.ts b/sdks/typescript/src/generated/sdk/auth.ts index cf6de9ba..2d0cf74e 100644 --- a/sdks/typescript/src/generated/sdk/auth.ts +++ b/sdks/typescript/src/generated/sdk/auth.ts @@ -14,11 +14,10 @@ export class Auth extends ClientSDK { * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index dc6f20d3..5a5bcf2b 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,7 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ async list( request?: @@ -44,8 +44,8 @@ export class ControlBindings extends ClientSDK { * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ async create( request: models.CreateControlBindingRequest, @@ -104,7 +104,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ async delete( request: @@ -125,12 +125,11 @@ export class ControlBindings extends ClientSDK { * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ async get( request: @@ -153,7 +152,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ async update( request: diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index e0ea6da7..058169de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -52,11 +52,9 @@ class Operation(StrEnum): POLICIES_READ = "policies.read" POLICIES_CREATE = "policies.create" POLICIES_UPDATE = "policies.update" - POLICIES_DELETE = "policies.delete" AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" - AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 228ec443..16760768 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -45,11 +45,9 @@ class AccessLevel(Enum): Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, Operation.POLICIES_CREATE: AccessLevel.ADMIN, Operation.POLICIES_UPDATE: AccessLevel.ADMIN, - Operation.POLICIES_DELETE: AccessLevel.ADMIN, Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, - Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index f80cd2fa..b1ade969 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,13 +2,13 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer authenticates the credential and -authorizes the implied -:data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint +target_id)``; the configured authorization provider authenticates the +credential and authorizes the implied +``runtime.token_exchange`` operation. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the returned token, which is verified locally by -:class:`LocalJwtVerifyProvider`. +the runtime JWT provider. """ from __future__ import annotations @@ -56,7 +56,7 @@ class RuntimeTokenExchangeResponse(BaseModel): async def _exchange_context(request: Request) -> dict[str, Any]: - """Surface target identifiers to the authorizer's context. + """Surface target identifiers to the authorization context. Reads the request body once. FastAPI caches the parsed body, so the endpoint's own Pydantic body model still binds normally. @@ -89,11 +89,10 @@ async def runtime_token_exchange( ) -> RuntimeTokenExchangeResponse: """Mint a short-lived runtime token for the requested target. - The caller's credential is authenticated and authorized by the - installed default authorizer; the resulting :class:`Principal` - supplies the actor identity and (when the upstream surfaces it) - the grant scopes and expiry. This endpoint then mints a local HS256 - token whose lifetime cannot outlive the upstream grant. + The caller's credential is authenticated and authorized before the + resolved principal supplies the actor identity, grant scopes, and + expiry. This endpoint then mints a local HS256 token whose lifetime + cannot outlive the grant. Runtime auth must be enabled via ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index d2fe4b44..87386723 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -102,8 +102,8 @@ async def create_control_binding( ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. - Each binding row is scoped to the request namespace as resolved by - the active authorizer. + Each binding row is scoped to the namespace associated with the + authenticated request. """ service = ControlBindingsService(db) binding = await service.create_binding( @@ -153,7 +153,7 @@ async def list_control_bindings( cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by the active authorizer. + storage namespace is resolved from the authenticated request. """ parsed_cursor: int | None if cursor is None: @@ -201,12 +201,11 @@ async def get_control_binding( """Read a single control binding by surrogate ID. Authorization is namespace-wide: the binding's target identifiers - are not forwarded to the upstream because they are only discoverable - after the row is loaded, and ``require_operation`` is single-pass. + are not available until after the row is loaded. Callers whose authorization model requires per-target permissions should use the natural-key endpoints (``PUT /by-key``, ``POST /by-key:delete``) and the target-filtered list endpoint, all - of which forward ``(target_type, target_id)`` to the authorizer. + of which include ``(target_type, target_id)`` in the request context. """ service = ControlBindingsService(db) binding = await service.get_binding_or_404( @@ -232,7 +231,7 @@ async def patch_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``PUT /by-key`` for target-scoped - upserts that forward the target to the authorizer. + upserts that include the target in the request context. """ service = ControlBindingsService(db) binding = await service.set_enabled( @@ -260,7 +259,7 @@ async def delete_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``POST /by-key:delete`` for - target-scoped detach that forwards the target to the authorizer. + target-scoped detach that includes the target in the request context. """ service = ControlBindingsService(db) await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 5b01593c..00d2b710 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,12 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_READ: validate exercises the materialization -# path but does not mutate stored control data. +# Authorized as CONTROLS_CREATE: validate exercises the same materialization +# path as create/update authoring flows, even though it does not save. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. diff --git a/server/src/agent_control_server/namespace.py b/server/src/agent_control_server/namespace.py deleted file mode 100644 index 30e30be5..00000000 --- a/server/src/agent_control_server/namespace.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Namespace resolution for request-scoped scoping. - -V1 always resolves to the default namespace. The function exists as a -single seam so a future change can switch every namespace-scoped -endpoint to a real per-request resolver without touching each call -site. Overriding the dependency in V1 is not supported: only this -binding/evaluation layer reads it; controls, agents, and policies still -write under the default namespace, so an override here would create -inconsistent rows. Future work will thread a single resolver through -every write path together. -""" - -from __future__ import annotations - -from .models import DEFAULT_NAMESPACE_KEY - - -def get_namespace_key() -> str: - """Return the namespace_key for the current request. - - V1 returns ``DEFAULT_NAMESPACE_KEY`` unconditionally. - """ - return DEFAULT_NAMESPACE_KEY diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 41a62282..e3a5fd26 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,7 +20,6 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( - DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -100,7 +99,7 @@ def __init__(self, db: AsyncSession) -> None: def create_control( self, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, name: str, data: dict[str, Any], ) -> Control: @@ -161,17 +160,19 @@ async def get_active_control_or_404( control_id: int, *, for_update: bool = False, - namespace_key: str | None = None, + namespace_key: str, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND. - When ``namespace_key`` is supplied, the lookup is scoped to that - namespace; a control that exists only in another namespace - surfaces as 404 (non-disclosing). + The lookup is scoped to the supplied namespace; a control that + exists only in another namespace surfaces as 404 + (non-disclosing). """ - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -190,7 +191,7 @@ async def active_control_name_exists( self, name: str, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" @@ -537,7 +538,7 @@ async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 2d39bfa3..799b2d52 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + LocalJwtVerifyProvider, NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( @@ -1029,6 +1030,29 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +@pytest.mark.asyncio +async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + try: + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.CONTROLS_READ), HttpUpstreamAuthProvider) + assert isinstance(get_authorizer(Operation.RUNTIME_USE), LocalJwtVerifyProvider) + runtime_config = auth_config.runtime_auth_config() + assert runtime_config is not None + assert runtime_config.secret == _TEST_SECRET + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index c0f17754..04f44ca4 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,10 +4,9 @@ import uuid -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import set_authorizer from agent_control_server.auth_framework.providers import NoAuthProvider +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -198,19 +197,18 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_can_validate_control_data( +def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_READ``.""" + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is allowed for authenticated non-admin callers - assert resp.status_code == 200, resp.text - assert resp.json()["success"] is True + # Then: validation is admin-only + assert resp.status_code == 403, resp.text def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 8d333a5c..1b1edae2 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -11,8 +11,6 @@ from datetime import UTC, datetime, timedelta import pytest -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( RuntimeAuthConfig, @@ -25,6 +23,7 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) +from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -180,6 +179,39 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ assert principal.caller_id == "actor-rt" +def test_evaluation_rejects_runtime_jwt_for_wrong_target( + client: TestClient, + runtime_config_enabled, +): + """A runtime JWT minted for one target cannot be used for another target.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + "target_type": "log_stream", + "target_id": "ls-other", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_id does not match the request." + + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b858c527..3815f26b 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -8,10 +8,6 @@ import pytest from agent_control_models.errors import ErrorCode -from sqlalchemy import insert, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session - from agent_control_server.errors import APIValidationError from agent_control_server.models import ( DEFAULT_NAMESPACE_KEY, @@ -27,6 +23,9 @@ from agent_control_server.services.controls import ( ControlService, ) +from sqlalchemy import insert, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from .conftest import AsyncSessionTest, engine from .utils import VALID_CONTROL_PAYLOAD @@ -70,7 +69,11 @@ async def _create_versioned_control( async with AsyncSessionTest() as session: service = ControlService(session) - control = service.create_control(name=control_name, data=control_data) + control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=control_name, + data=control_data, + ) await service.create_version( control, event_type="created", @@ -143,6 +146,7 @@ async def test_create_control_transaction_rollback_does_not_persist_control_or_v async with AsyncSessionTest() as session: service = ControlService(session) control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=control_name, data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -167,7 +171,10 @@ async def test_replace_control_data_transaction_rollback_preserves_prior_state() async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = "Should not persist" service.replace_control_data(control, data=updated_data) @@ -194,7 +201,10 @@ async def test_patch_mutation_transaction_rollback_preserves_prior_state() -> No async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.rename_control(control, name=f"{control_name}-renamed") service.set_control_enabled(control, enabled=False) await service.create_version( @@ -221,7 +231,10 @@ async def test_delete_control_transaction_rollback_preserves_active_state() -> N async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) await service.create_version( control, @@ -511,7 +524,10 @@ async def test_list_active_control_counts_by_agent_deduplicates_and_filters_inac await async_db.commit() # When: counting active controls for the agent - counts = await ControlService(async_db).list_active_control_counts_by_agent([agent.name]) + counts = await ControlService(async_db).list_active_control_counts_by_agent( + [agent.name], + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: active controls are deduplicated and inactive controls are excluded assert counts == {agent.name: 2} @@ -572,6 +588,7 @@ async def test_create_version_allocates_sequential_numbers_under_concurrent_muta async with AsyncSessionTest() as setup_session: setup_service = ControlService(setup_session) control = setup_service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=f"control-{uuid.uuid4()}", data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -592,7 +609,10 @@ async def mutate_and_version(description: str) -> None: async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = description service.replace_control_data(control, data=updated_data) From d770ae17c39c7ae2bb7577bfd5c780f4ce180329 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 16:43:39 +0530 Subject: [PATCH 04/12] feat(server): operator-configurable extra forwarded headers on HttpUpstream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default forward set (X-API-Key, Authorization, Cookie) only covers credential headers Agent Control itself reads. Deployments whose upstream authenticates against a different header name (e.g., a deployer-specific API-key header) had no way to surface that credential through HttpUpstreamAuthProvider — the inbound header reached AC but never crossed the upstream call. Add an extra_forward_headers config field on HttpUpstreamConfig (defaulting to the empty tuple) that operators populate via the new AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS env var (comma- separated). The provider's _forward_headers iterates over the union of the default set and the extras, deduplicating case-insensitively so a duplicate name (cross-set or within extras) does not produce two copies on the wire. Tests: - forwards a configured extra header alongside defaults - default forward set unchanged when extras are empty - extras dedupe against defaults case-insensitively - _parse_extra_forward_headers parametric: None / empty / single / multiple / whitespace / empty-entries / case-folded duplicates - configure_auth_from_env threads the parsed tuple onto the provider Lint clean, typecheck clean, full server suite (747) green. --- .../auth_framework/config.py | 29 +++++ .../auth_framework/providers/http_upstream.py | 20 ++- .../endpoints/controls.py | 3 +- server/tests/test_auth_framework.py | 115 ++++++++++++++++++ 4 files changed, 163 insertions(+), 4 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index c8f428dc..8c39a2ec 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -46,6 +46,7 @@ _UPSTREAM_TIMEOUT_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_TIMEOUT_SECONDS" _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" +_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" # Runtime flow. _RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" @@ -196,6 +197,9 @@ def _build_default_provider() -> RequestAuthorizer: timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0")) token = os.environ.get(_UPSTREAM_TOKEN_ENV) token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token") + extra_forward_headers = _parse_extra_forward_headers( + os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) + ) _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -203,6 +207,7 @@ def _build_default_provider() -> RequestAuthorizer: timeout_seconds=timeout, service_token=token, service_token_header=token_header, + extra_forward_headers=extra_forward_headers, ) ) raise RuntimeError( @@ -210,6 +215,30 @@ def _build_default_provider() -> RequestAuthorizer: ) +def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: + """Parse a comma-separated header list into a deduplicated tuple. + + Empty / unset env var returns an empty tuple. Whitespace around each + name is stripped. Empty entries (e.g. ``"X-A,,X-B"``) are dropped. + Order is preserved; duplicates (case-insensitive) are dropped after + the first occurrence. + """ + if not raw or not raw.strip(): + return () + seen: set[str] = set() + result: list[str] = [] + for raw_name in raw.split(","): + name = raw_name.strip() + if not name: + continue + lower = name.lower() + if lower in seen: + continue + seen.add(lower) + result.append(name) + return tuple(result) + + def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 8d5c850c..78ed9ae2 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -60,7 +60,7 @@ _logger = get_logger(__name__) -_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") +_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") class _UpstreamGrant(BaseModel): @@ -136,6 +136,17 @@ class HttpUpstreamConfig: service_token_header: str = "X-Agent-Control-Service-Token" + extra_forward_headers: tuple[str, ...] = () + """Additional inbound request headers to forward to the upstream + on top of the default ``(X-API-Key, Authorization, Cookie)`` set. + + Use this when the upstream authenticates via a header the provider + does not forward by default (e.g., a deployer-specific API-key + header). Header lookups against the inbound request are + case-insensitive; an empty or absent inbound header is silently + dropped. Names duplicating the default set or each other (after + case-folding) are deduplicated.""" + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -190,7 +201,12 @@ async def authorize( def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} - for name in _FORWARDED_HEADERS: + seen: set[str] = set() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self._config.extra_forward_headers): + lower = name.lower() + if lower in seen: + continue + seen.add(lower) value = request.headers.get(name) if value is not None: headers[name] = value diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 00d2b710..b4fa8d0b 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,8 +787,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_CREATE: validate exercises the same materialization -# path as create/update authoring flows, even though it does not save. +# Validation uses the authoring path, so require create access. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 799b2d52..dc3a1787 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -261,6 +261,75 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +@pytest.mark.asyncio +async def test_http_upstream_forwards_extra_headers(): + # Given: a provider configured with an extra header in its forward list + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("X-Deployer-Auth",)}, + ) + + # When: the inbound request carries the extra header + inbound = _build_request(headers={"X-Deployer-Auth": "k_abc", "X-API-Key": "k1"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: both the default and the extra header reach the upstream + assert captured["headers"]["x-deployer-auth"] == "k_abc" + assert captured["headers"]["x-api-key"] == "k1" + + +@pytest.mark.asyncio +async def test_http_upstream_default_forward_set_unchanged(): + # Given: a provider with no extra_forward_headers + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream(factory) + + # When: the inbound carries an unlisted header alongside a default one + inbound = _build_request( + headers={"X-API-Key": "k1", "X-Deployer-Auth": "should-not-forward"} + ) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: only the default-set header reaches the upstream + assert captured["headers"].get("x-api-key") == "k1" + assert "x-deployer-auth" not in captured["headers"] + + +@pytest.mark.asyncio +async def test_http_upstream_extra_forward_dedupes_against_defaults(): + # Given: extra list duplicates a default header (different case) + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("x-api-key", "Authorization")}, + ) + + # When: inbound has both + inbound = _build_request(headers={"X-API-Key": "k1", "Authorization": "Bearer t"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: each header appears exactly once on the upstream request + forwarded = captured["headers"] + assert sum(1 for k in forwarded if k.lower() == "x-api-key") == 1 + assert sum(1 for k in forwarded if k.lower() == "authorization") == 1 + + @pytest.mark.asyncio @pytest.mark.parametrize( "status, expected", @@ -1053,6 +1122,52 @@ async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): await auth_config.teardown_auth() +@pytest.mark.parametrize( + "raw, expected", + [ + (None, ()), + ("", ()), + (" ", ()), + ("X-One", ("X-One",)), + ("X-One,X-Two", ("X-One", "X-Two")), + (" X-One , X-Two ", ("X-One", "X-Two")), + ("X-One,,X-Two", ("X-One", "X-Two")), + ("X-One,x-one,X-One", ("X-One",)), + ("X-A,X-B,x-a,X-C,X-b", ("X-A", "X-B", "X-C")), + ], +) +def test_parse_extra_forward_headers(raw, expected): + from agent_control_server.auth_framework.config import _parse_extra_forward_headers + + assert _parse_extra_forward_headers(raw) == expected + + +@pytest.mark.asyncio +async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): + """Setting the env var threads extra_forward_headers into the provider.""" + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS", + "X-Deployer-Auth, X-Deployer-Trace", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.extra_forward_headers == ( + "X-Deployer-Auth", + "X-Deployer-Trace", + ) + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config From cbb098b4ebabb70a40656223ce2f77a2b13281e6 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:36:45 +0530 Subject: [PATCH 05/12] fix(server): preserve default runtime auth fallback --- .../auth_framework/config.py | 37 +++++++++------- server/tests/test_auth_framework.py | 44 +++++++++++++++++-- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 8c39a2ec..595c3117 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -16,8 +16,8 @@ :class:`NoAuthProvider`, ``api_key`` uses :class:`HeaderAuthProvider`, and ``jwt`` uses :class:`LocalJwtVerifyProvider`. When the mode is unset, startup - preserves historical behavior by selecting ``jwt`` if - ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + selects ``jwt`` if ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set; + otherwise runtime falls through to the default authorizer. The ``runtime.token_exchange`` operation continues to flow through the default authorizer because the exchange itself is shaped like a management call (forward credential, get grant). @@ -96,10 +96,11 @@ def configure_auth_from_env() -> None: Runtime flow: - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. - - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime - token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token secret is configured): :class:`LocalJwtVerifyProvider`. + - unset mode without a runtime token secret: fall through to the default + authorizer. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -121,20 +122,26 @@ def configure_auth_from_env() -> None: set_authorizer(default) _active_providers.append(default) - runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) - if runtime_mode == "jwt": + if runtime_mode == "default": _logger.info( - "Runtime auth provider: jwt override installed for %s", + "Runtime auth provider: default authorizer handles %s", Operation.RUNTIME_USE.value, ) else: - _logger.info( - "Runtime auth provider: %s override installed for %s", - runtime_mode, - Operation.RUNTIME_USE.value, - ) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": + _logger.info( + "Runtime auth provider: jwt override installed for %s", + Operation.RUNTIME_USE.value, + ) + else: + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, + Operation.RUNTIME_USE.value, + ) async def teardown_auth() -> None: @@ -242,7 +249,7 @@ def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): - return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "default" mode = raw.strip().lower() if mode in {"none", "no_auth"}: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index dc3a1787..20c58aed 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -700,7 +699,6 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1053,13 +1051,13 @@ def test_build_default_provider_accepts_none_mode(monkeypatch): assert isinstance(auth_config._build_default_provider(), NoAuthProvider) -def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): +def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) - assert auth_config._resolve_runtime_mode() == "api_key" + assert auth_config._resolve_runtime_mode() == "default" def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): @@ -1099,6 +1097,44 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +@pytest.mark.asyncio +async def test_configure_runtime_unset_preserves_http_upstream_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + try: + auth_config.configure_auth_from_env() + + default_provider = get_authorizer(Operation.CONTROLS_READ) + runtime_provider = get_authorizer(Operation.RUNTIME_USE) + assert isinstance(default_provider, HttpUpstreamAuthProvider) + assert runtime_provider is default_provider + assert auth_config.runtime_auth_config() is None + finally: + await auth_config.teardown_auth() + + @pytest.mark.asyncio async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 5008ab37a0dcb8a17142cc888b412fb64ff5981a Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 14:39:01 +0530 Subject: [PATCH 06/12] fix(server): harden auth scoping --- docs/README.md | 1 + docs/auth.md | 148 ++++++++++++++++++ models/src/agent_control_models/server.py | 3 +- .../agent_control_server/endpoints/agents.py | 84 +++++++++- .../agent_control_server/endpoints/auth.py | 17 +- .../endpoints/controls.py | 44 +++++- server/tests/test_principal_namespace_flow.py | 33 +++- server/tests/test_target_merged_contract.py | 96 +++++++++++- 8 files changed, 402 insertions(+), 24 deletions(-) create mode 100644 docs/auth.md diff --git a/docs/README.md b/docs/README.md index 9b7cb757..e53dcf13 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,6 +10,7 @@ This repository keeps documentation concise. The full documentation lives on the - [Controls](https://docs.agentcontrol.dev/concepts/controls) — Define and configure control rules - [Reference](https://docs.agentcontrol.dev/core/reference) — SDK and server API reference - [Configuration](https://docs.agentcontrol.dev/core/configuration) — Environment variables, auth, and database settings +- [Server auth contract](auth.md) - Pluggable auth modes, HTTP upstream contract, and runtime JWT claims - [UI Quickstart](https://docs.agentcontrol.dev/core/ui-quickstart) — Run the dashboard and manage controls visually ## Examples diff --git a/docs/auth.md b/docs/auth.md new file mode 100644 index 00000000..5002faa8 --- /dev/null +++ b/docs/auth.md @@ -0,0 +1,148 @@ +# Server Auth Contract + +Agent Control keeps authentication and authorization provider-neutral. The server asks a configured provider whether a request may perform an operation, then scopes all data access with the returned `Principal`. + +## Operations + +Operations are stable strings. Deployers map them to their own permission model. + +```text +controls.read +controls.create +controls.update +controls.delete +policies.read +policies.create +policies.update +agents.read +agents.create +agents.update +control_bindings.read +control_bindings.write +runtime.token_exchange +runtime.use +``` + +## Principal + +Providers return a generic principal. Agent Control treats `namespace_key`, `caller_id`, `target_type`, and `target_id` as opaque strings. + +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +`namespace_key` is the tenancy boundary. Server queries filter by it, and namespace-aware foreign keys prevent cross-namespace references. + +## Auth Modes + +Management auth is selected by `AGENT_CONTROL_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| `none` | No credentials required. Intended for local development only. | +| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS`. This is the default. `header` is accepted as a backwards-compatible alias. | +| `http_upstream` | POST each management authorization decision to `AGENT_CONTROL_AUTH_UPSTREAM_URL`. | + +Runtime auth is selected by `AGENT_CONTROL_RUNTIME_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| unset | Use `jwt` when `AGENT_CONTROL_RUNTIME_TOKEN_SECRET` is set. Otherwise runtime requests fall through to management auth. | +| `none` | No runtime credentials required. Intended for local development only. | +| `api_key` | Validate runtime requests with the same local API-key mechanism. | +| `jwt` | Require target-bound runtime tokens minted by `/api/v1/auth/runtime-token-exchange`. | + +Common combinations: + +| Management | Runtime | Use case | +| --- | --- | --- | +| `api_key` | unset | Existing standalone deployments. | +| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. | +| `http_upstream` | `jwt` | External identity or authorization service for management, local token verify for high-volume runtime calls. | +| `none` | `none` | Single-process local development. Do not use in production. | + +## HTTP Upstream Contract + +When `AGENT_CONTROL_AUTH_MODE=http_upstream`, the server sends: + +```http +POST {AGENT_CONTROL_AUTH_UPSTREAM_URL} +``` + +```json +{ + "operation": "control_bindings.write", + "context": { + "target_type": "session", + "target_id": "target-123" + } +} +``` + +The provider forwards inbound `X-API-Key`, `Authorization`, and `Cookie` headers. Add deployer-specific header names with `AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS`, for example: + +```text +AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS=Vendor-API-Key,X-Workspace-Id +``` + +If `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN` is set, it is forwarded on `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER` or `X-Agent-Control-Service-Token` by default. + +A successful upstream response is: + +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +Only `namespace_key` is always required. `target_type` and `target_id` must be returned together when present. `expires_at` must include timezone information. + +Status handling: + +| Upstream status | Agent Control result | +| --- | --- | +| `200` | Parse the principal grant. | +| `401` | Authentication error. | +| `403` | Forbidden error. | +| `404` | Not found error. | +| `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | +| Other statuses or malformed JSON | Fail closed with `503` or `502`. | + +## Runtime JWT Claims + +`/api/v1/auth/runtime-token-exchange` is a management-style request. The configured management provider authorizes `runtime.token_exchange` for the requested target. Agent Control then mints its own HS256 JWT with `AGENT_CONTROL_RUNTIME_TOKEN_SECRET`. + +The token payload contains: + +```json +{ + "iss": "agent-control/server", + "domain": "runtime", + "namespace_key": "tenant-a", + "actor_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "iat": 1778509800, + "exp": 1778510100, + "jti": "opaque-token-id" +} +``` + +Verification requires the expected issuer, `domain="runtime"`, a valid signature, an unexpired `exp`, and `runtime.use` in `scopes`. The token is accepted only for requests whose `target_type` and `target_id` match the bound target. + +The expiry is the earlier of `AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS` and the upstream grant's `expires_at` when supplied. Runtime token TTLs are capped at 86400 seconds. diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index 9b890b91..3529a5d4 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -640,7 +640,7 @@ class CreateControlBindingRequest(BaseModel): target_type: ControlBindingTargetField = Field( ..., - description="Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream').", + description="Opaque attachment kind (caller-defined; e.g. 'environment', 'session').", ) target_id: ControlBindingTargetField = Field( ..., description="Opaque external identifier within the target_type." @@ -760,4 +760,3 @@ class DeleteControlBindingByKeyResponse(BaseModel): ), ) - diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index ac099911..57ca1ebc 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -29,20 +29,21 @@ SetPolicyResponse, StepKey, ) -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Request from jsonschema_rs import ValidationError as JSONSchemaValidationError from pydantic import BaseModel, ValidationError from sqlalchemy import delete, func, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth_framework import Operation, Principal, require_operation +from ..auth_framework import Operation, Principal, get_authorizer, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, BadRequestError, ConflictError, DatabaseError, + ForbiddenError, NotFoundError, ) from ..logging_utils import get_logger @@ -85,6 +86,81 @@ type StepKeyTuple = tuple[str, str] +def _complete_target_context( + target_type: object | None, + target_id: object | None, +) -> dict[str, str] | None: + """Return target context only when both halves are present strings.""" + if not isinstance(target_type, str) or not isinstance(target_id, str): + return None + if not target_type or not target_id: + return None + return {"target_type": target_type, "target_id": target_id} + + +async def _init_agent_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from an ``initAgent`` body.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return None + if not isinstance(body, dict): + return None + return _complete_target_context(body.get("target_type"), body.get("target_id")) + + +def _agent_controls_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from ``GET /agents/{name}/controls``.""" + return _complete_target_context( + request.query_params.get("target_type"), + request.query_params.get("target_id"), + ) + + +async def _authorize_target_read_if_present( + request: Request, + context: dict[str, str] | None, +) -> Principal | None: + """Require target read authorization before returning target-merged controls.""" + if context is None: + return None + return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( + request, + Operation.CONTROL_BINDINGS_READ, + context, + ) + + +async def _init_agent_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + await _init_agent_target_context(request), + ) + + +async def _agent_controls_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + _agent_controls_target_context(request), + ) + + +def _ensure_target_principal_matches_namespace( + principal: Principal, + target_principal: Principal | None, +) -> None: + """Fail closed if the target authorization resolves to a different namespace.""" + if target_principal is None: + return + if target_principal.namespace_key == principal.namespace_key: + return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Target authorization resolved to a different namespace.", + hint="Ensure the credential is scoped to the requested target and namespace.", + ) + + # ============================================================================= # List Agents Models # ============================================================================= @@ -445,6 +521,7 @@ async def init_agent( request: InitAgentRequest, db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), + target_principal: Principal | None = Depends(_init_agent_target_principal), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -474,6 +551,7 @@ async def init_agent( InitAgentResponse with created flag and the effective controls """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() @@ -1493,6 +1571,7 @@ async def list_agent_controls( ), db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), + target_principal: Principal | None = Depends(_agent_controls_target_principal), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1527,6 +1606,7 @@ async def list_agent_controls( HTTPException 404: Agent not found """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) if (target_type is None) != (target_id is None): raise BadRequestError( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index b1ade969..7125b64d 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -28,8 +28,10 @@ mint_runtime_token, ) from ..errors import APIError, BadRequestError +from ..logging_utils import get_logger router = APIRouter(prefix="/auth", tags=["auth"]) +_logger = get_logger(__name__) class RuntimeTokenExchangeRequest(BaseModel): @@ -38,7 +40,7 @@ class RuntimeTokenExchangeRequest(BaseModel): model_config = ConfigDict(extra="forbid") target_type: str = Field( - ..., description="Opaque target kind (e.g., ``log_stream``).", min_length=1 + ..., description="Opaque target kind (e.g., ``session``).", min_length=1 ) target_id: str = Field(..., description="Opaque target identifier.", min_length=1) @@ -175,6 +177,19 @@ async def runtime_token_exchange( hint="Check the runtime token configuration.", ) from exc + _logger.info( + "Runtime token exchanged", + extra={ + "namespace_key": claims.namespace_key, + "actor_id": claims.actor_id, + "target_type": claims.target_type, + "target_id": claims.target_id, + "scopes": list(claims.scopes), + "expires_at": claims.expires_at.isoformat(), + "jti": claims.jti, + }, + ) + return RuntimeTokenExchangeResponse( token=token, expires_at=claims.expires_at, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index b4fa8d0b..6e6441e9 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -195,12 +195,17 @@ async def _render_and_validate_template_input( template_input: TemplateControlInput, *, db: AsyncSession, + namespace_key: str, enabled: bool = True, ) -> ControlDefinition: """Render a template-backed input and validate evaluator config.""" rendered = render_template_control_input(template_input, enabled=enabled) try: - await _validate_control_definition(rendered.control, db) + await _validate_control_definition( + rendered.control, + db, + namespace_key=namespace_key, + ) except APIValidationError as exc: raise remap_template_api_error( exc, @@ -214,6 +219,7 @@ async def _materialize_control_input( control_input: ControlDefinition | TemplateControlInput, *, db: AsyncSession, + namespace_key: str, current_payload: object | None = None, control_id: int | None = None, ) -> ControlDefinition | UnrenderedTemplateControl: @@ -226,6 +232,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -244,6 +251,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -262,12 +270,19 @@ async def _materialize_control_input( raise RuntimeError("control_id is required for template-backed raw updates") raise _template_backed_raw_update_conflict(control_id) - await _validate_control_definition(control_input, db) + await _validate_control_definition( + control_input, + db, + namespace_key=namespace_key, + ) return control_input async def _validate_control_definition( - control_def: ControlDefinition, db: AsyncSession + control_def: ControlDefinition, + db: AsyncSession, + *, + namespace_key: str, ) -> None: """Validate evaluator config for definitions referencing known global evaluators. @@ -296,7 +311,10 @@ async def _validate_control_definition( agent_data = agent_data_by_name.get(agent_namespace) if agent_data is None: agent_result = await db.execute( - select(Agent).where(Agent.name == agent_namespace) + select(Agent).where( + Agent.name == agent_namespace, + Agent.namespace_key == namespace_key, + ) ) agent = agent_result.scalars().first() if agent is None: @@ -447,7 +465,7 @@ async def _validate_control_definition( async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: """Render a template-backed control without persisting it.""" control_def = await _render_and_validate_template_input( @@ -456,6 +474,7 @@ async def render_control_template( template_values=request.template_values, ), db=db, + namespace_key=principal.namespace_key, enabled=True, ) return RenderControlTemplateResponse(control=control_def) @@ -504,7 +523,11 @@ async def create_control( hint="Choose a different name or update the existing control.", ) - control_def = await _materialize_control_input(request.data, db=db) + control_def = await _materialize_control_input( + request.data, + db=db, + namespace_key=namespace_key, + ) control_data = _serialize_control_data(control_def) control = control_service.create_control( @@ -751,6 +774,7 @@ async def set_control_data( control_def = await _materialize_control_input( request.data, db=db, + namespace_key=principal.namespace_key, current_payload=control.data, control_id=control_id, ) @@ -791,7 +815,7 @@ async def set_control_data( async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -805,7 +829,11 @@ async def validate_control_data( """ # Validate mirrors create: complete template values trigger a full render, # incomplete values validate structure only (matching unrendered create). - await _materialize_control_input(request.data, db=db) + await _materialize_control_input( + request.data, + db=db, + namespace_key=principal.namespace_key, + ) return ValidateControlDataResponse(success=True) diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 40ecd216..14d2d874 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -3,16 +3,16 @@ from __future__ import annotations import uuid +from copy import deepcopy from typing import Any -from fastapi import FastAPI, Request -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import ( Operation, Principal, set_authorizer, ) +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -139,3 +139,30 @@ def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAP assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 + + +def test_agent_scoped_evaluator_validation_uses_principal_namespace(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_b = ns_b.post( + "/api/v1/agents/initAgent", + json={ + **_agent_payload(agent_name), + "evaluators": [{"name": "custom", "config_schema": {"type": "object"}}], + }, + ) + assert register_b.status_code == 200, register_b.text + + control_data = deepcopy(VALID_CONTROL_PAYLOAD) + control_data["condition"]["evaluator"] = { + "name": f"{agent_name}:custom", + "config": {}, + } + + resp = ns_a.post("/api/v1/controls/validate", json={"data": control_data}) + assert resp.status_code == 404, resp.text + assert resp.json()["detail"] == f"Agent '{agent_name}' not found" diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 62891ba5..6bc4ab0f 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -18,11 +18,37 @@ from copy import deepcopy from typing import Any +import pytest +from agent_control_server.auth_framework import Operation, Principal, set_authorizer +from fastapi import Request from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD, canonicalize_control_payload +class RecordingAuthorizer: + """Authorizer that records operation/context pairs for endpoint contract tests.""" + + def __init__(self, *, target_namespace_key: str = "default") -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + self.target_namespace_key = target_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + namespace_key = ( + self.target_namespace_key + if operation is Operation.CONTROL_BINDINGS_READ and context is not None + else "default" + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + def _agent_payload( agent_name: str, *, @@ -115,7 +141,7 @@ def _list_effective_via_get( # --------------------------------------------------------------------------- -def test_initAgent_with_target_merges_direct_and_target_controls( +def test_init_agent_with_target_merges_direct_and_target_controls( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -134,7 +160,7 @@ def test_initAgent_with_target_merges_direct_and_target_controls( assert returned_ids == {direct_id, target_id_ctrl} -def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( +def test_init_agent_newly_created_with_target_picks_up_pre_existing_bindings( client: TestClient, ) -> None: """Bindings can pre-exist the agent row. @@ -154,7 +180,7 @@ def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( assert returned_ids == [pre_existing] -def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: +def test_init_agent_partial_target_pair_rejected(client: TestClient) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" payload = _agent_payload(agent_name) payload["target_type"] = "env" # target_id omitted @@ -162,12 +188,28 @@ def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: assert resp.status_code == 422 +def test_init_agent_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + body = _register_agent(client, agent_name, target_type="env", target_id="prod") + + assert body["created"] is True + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + # --------------------------------------------------------------------------- # GET /agents/{name}/controls contract. # --------------------------------------------------------------------------- -def test_get_agent_controls_with_target_matches_initAgent_response( +def test_get_agent_controls_with_target_matches_init_agent_response( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -200,6 +242,45 @@ def test_get_agent_controls_partial_target_pair_returns_400( assert resp.status_code == 400 +def test_get_agent_controls_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + authorizer.calls.clear() + + ids = _list_effective_via_get( + client, + agent_name, + target_type="env", + target_id="prod", + ) + + assert ids == [] + assert (Operation.AGENTS_READ, None) in authorizer.calls + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + +def test_get_agent_controls_rejects_target_namespace_mismatch( + client: TestClient, +) -> None: + set_authorizer(RecordingAuthorizer(target_namespace_key="other-ns")) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + + resp = client.get( + f"/api/v1/agents/{agent_name}/controls", + params={"target_type": "env", "target_id": "prod"}, + ) + + assert resp.status_code == 403, resp.text + + def test_get_agent_controls_no_target_omits_target_bindings( client: TestClient, ) -> None: @@ -243,11 +324,10 @@ async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) await async_db.commit() -import pytest # noqa: E402 (kept local; the rest of the file is sync) - - @pytest.mark.asyncio -async def test_get_agent_controls_cross_namespace_returns_404(client: TestClient, async_db) -> None: +async def test_get_agent_controls_cross_namespace_returns_404( + client: TestClient, async_db +) -> None: """Agent existing only in another namespace must not surface here. The merged-resolver contract is namespace-scoped end-to-end; if the From d9ea19ee9b1aa53428eee400c8e67de393c26bc0 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:20:43 +0530 Subject: [PATCH 07/12] docs(server): clarify upstream auth failure mapping --- docs/auth.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/auth.md b/docs/auth.md index 5002faa8..7aafd2ad 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -120,7 +120,8 @@ Status handling: | `403` | Forbidden error. | | `404` | Not found error. | | `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | -| Other statuses or malformed JSON | Fail closed with `503` or `502`. | +| Other statuses or upstream network errors | Fail closed with `503`. | +| Malformed `200` principal response | Fail closed with `502`. | ## Runtime JWT Claims From 9561df45d905c7d7e0bb7e4f8be9030eab59a1a6 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:50:43 +0530 Subject: [PATCH 08/12] docs(server): explain target principal authorization --- .../agent_control_server/endpoints/agents.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 57ca1ebc..1b380026 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -121,7 +121,20 @@ async def _authorize_target_read_if_present( request: Request, context: dict[str, str] | None, ) -> Principal | None: - """Require target read authorization before returning target-merged controls.""" + """Require target read authorization before returning target-merged controls. + + Agent endpoints that accept optional target context have two separate + authorization decisions: + + - the endpoint operation itself (for example, ``agents.create``), whose + result is exposed to the route as ``principal``; + - the target binding read (``control_bindings.read``), whose result is + exposed as ``target_principal``. + + Keeping the results separate lets the route verify that the caller's + namespace and the target's resolved namespace agree before merging + target-bound controls into the response. + """ if context is None: return None return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( @@ -545,7 +558,8 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent create operation + target_principal: Optional principal from the target binding read check Returns: InitAgentResponse with created flag and the effective controls @@ -1596,7 +1610,8 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent read operation + target_principal: Optional principal from the target binding read check Returns: AgentControlsResponse with controls matching the requested state filters From 31a2d02ef07a32aaa11df73d41a3d660ad1bba7f Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:59:55 +0530 Subject: [PATCH 09/12] chore(sdk-ts): refresh generated client docs --- sdks/typescript/src/generated/funcs/agents-init.ts | 3 ++- sdks/typescript/src/generated/funcs/agents-list-controls.ts | 3 ++- .../src/generated/models/create-control-binding-request.ts | 2 +- .../src/generated/models/runtime-token-exchange-request.ts | 2 +- sdks/typescript/src/generated/sdk/agents.ts | 6 ++++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 7150b2a4..d1136c2f 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,7 +51,8 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index d1e5b27d..619a45d6 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,8 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/models/create-control-binding-request.ts b/sdks/typescript/src/generated/models/create-control-binding-request.ts index ace9f49b..f4e0c940 100644 --- a/sdks/typescript/src/generated/models/create-control-binding-request.ts +++ b/sdks/typescript/src/generated/models/create-control-binding-request.ts @@ -22,7 +22,7 @@ export type CreateControlBindingRequest = { */ targetId: string; /** - * Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream'). + * Opaque attachment kind (caller-defined; e.g. 'environment', 'session'). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts index 65e02bda..e20ed22e 100644 --- a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts +++ b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts @@ -14,7 +14,7 @@ export type RuntimeTokenExchangeRequest = { */ targetId: string; /** - * Opaque target kind (e.g., ``log_stream``). + * Opaque target kind (e.g., ``session``). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index 0a70e128..bed5b41f 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -80,7 +80,8 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls @@ -186,7 +187,8 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters From 4b778e3c94bf3f7e616f64b28d23fe55b8967e9c Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 12 May 2026 13:44:38 +0530 Subject: [PATCH 10/12] fix(server): sanitize jsonvalue openapi variants --- server/src/agent_control_server/main.py | 15 +++++++++++-- server/tests/test_main_lifespan.py | 28 ++++++++++++++++++------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index a1561e63..364778ef 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -334,6 +334,16 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- ) +JSON_VALUE_SCHEMA_NAMES = ( + "JSONValue", + "JSONValue-Input", + "JSONValue-Output", + "JsonValue", + "JsonValue-Input", + "JsonValue-Output", +) + + # Override OpenAPI to avoid recursive JSONValue schema issues in TS generators. def custom_openapi() -> dict[str, Any]: if app.openapi_schema: @@ -347,8 +357,9 @@ def custom_openapi() -> dict[str, Any]: ) schemas = openapi_schema.get("components", {}).get("schemas", {}) - if "JSONValue" in schemas: - schemas["JSONValue"] = {"description": "Any JSON value"} + for schema_name in JSON_VALUE_SCHEMA_NAMES: + if schema_name in schemas: + schemas[schema_name] = {"description": "Any JSON value"} app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/server/tests/test_main_lifespan.py b/server/tests/test_main_lifespan.py index 5a557743..e6e6f595 100644 --- a/server/tests/test_main_lifespan.py +++ b/server/tests/test_main_lifespan.py @@ -1,5 +1,8 @@ from __future__ import annotations +from fastapi import FastAPI +from fastapi.testclient import TestClient + from agent_control_server import main as main_module from agent_control_server.config import observability_settings, settings from agent_control_server.main import lifespan @@ -8,8 +11,6 @@ register_control_event_sink_factory, unregister_control_event_sink_factory, ) -from fastapi import FastAPI -from fastapi.testclient import TestClient def test_lifespan_initializes_observability_when_enabled(monkeypatch) -> None: @@ -156,11 +157,22 @@ def test_lifespan_skips_observability_when_disabled(monkeypatch) -> None: assert not hasattr(app.state, "event_ingestor") -def test_custom_openapi_replaces_jsonvalue(monkeypatch) -> None: - # Given: a custom openapi generator that includes JSONValue +def test_custom_openapi_replaces_jsonvalue_variants(monkeypatch) -> None: + # Given: a custom openapi generator that includes Pydantic JSONValue schemas + json_value_schema_names = ( + "JSONValue", + "JSONValue-Input", + "JSONValue-Output", + "JsonValue", + "JsonValue-Input", + "JsonValue-Output", + ) + def fake_get_openapi(*, title, version, description, routes): return { - "components": {"schemas": {"JSONValue": {"type": "object"}}}, + "components": { + "schemas": {name: {"type": "object"} for name in json_value_schema_names} + }, "info": {"title": title, "version": version, "description": description}, "paths": {}, } @@ -171,8 +183,10 @@ def fake_get_openapi(*, title, version, description, routes): # When: generating openapi schema = main_module.app.openapi() - # Then: JSONValue is replaced with safe description - assert schema["components"]["schemas"]["JSONValue"]["description"] == "Any JSON value" + # Then: JSONValue schemas are replaced with a non-recursive schema + schemas = schema["components"]["schemas"] + for schema_name in json_value_schema_names: + assert schemas[schema_name] == {"description": "Any JSON value"} def test_custom_openapi_is_cached(monkeypatch) -> None: From d0dd12c16154e41db7b20a987edae0228df50e9c Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 12 May 2026 15:18:35 +0530 Subject: [PATCH 11/12] fix(server): address runtime auth review feedback --- sdks/typescript/src/generated/models/index.ts | 4 - .../src/generated/models/json-value-input.ts | 40 ---------- .../src/generated/models/json-value-input1.ts | 47 ------------ .../src/generated/models/json-value-output.ts | 44 ----------- .../generated/models/json-value-output1.ts | 48 ------------ sdks/typescript/src/generated/models/step.ts | 30 +++----- .../models/template-definition-input.ts | 14 ++-- .../models/template-definition-output.ts | 11 ++- .../auth_framework/config.py | 6 +- .../auth_framework/providers/local_jwt.py | 38 ++++------ .../endpoints/evaluation.py | 4 +- .../agent_control_server/services/controls.py | 9 ++- server/tests/test_auth_framework.py | 37 +++++++++- server/tests/test_principal_namespace_flow.py | 74 +++++++++++++++++++ .../test_runtime_token_exchange_endpoint.py | 43 ++++++++++- 15 files changed, 201 insertions(+), 248 deletions(-) delete mode 100644 sdks/typescript/src/generated/models/json-value-input.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-input1.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-output.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-output1.ts diff --git a/sdks/typescript/src/generated/models/index.ts b/sdks/typescript/src/generated/models/index.ts index a31abbbe..595a9501 100644 --- a/sdks/typescript/src/generated/models/index.ts +++ b/sdks/typescript/src/generated/models/index.ts @@ -63,10 +63,6 @@ export * from "./init-agent-evaluator-removal.js"; export * from "./init-agent-overwrite-changes.js"; export * from "./init-agent-request.js"; export * from "./init-agent-response.js"; -export * from "./json-value-input.js"; -export * from "./json-value-input1.js"; -export * from "./json-value-output.js"; -export * from "./json-value-output1.js"; export * from "./list-agents-response.js"; export * from "./list-control-bindings-response.js"; export * from "./list-control-versions-response.js"; diff --git a/sdks/typescript/src/generated/models/json-value-input.ts b/sdks/typescript/src/generated/models/json-value-input.ts deleted file mode 100644 index 4f448073..00000000 --- a/sdks/typescript/src/generated/models/json-value-input.ts +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. - */ - -import * as z from "zod/v4-mini"; -import { smartUnion } from "../types/smart-union.js"; - -export type JSONValueInput = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput | null }; - -/** @internal */ -export type JSONValueInput$Outbound = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput$Outbound | null }; - -/** @internal */ -export const JSONValueInput$outboundSchema: z.ZodMiniType< - JSONValueInput$Outbound, - JSONValueInput -> = smartUnion([ - z.string(), - z.int(), - z.number(), - z.boolean(), - z.array(z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), - z.record(z.string(), z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), -]); - -export function jsonValueInputToJSON(jsonValueInput: JSONValueInput): string { - return JSON.stringify(JSONValueInput$outboundSchema.parse(jsonValueInput)); -} diff --git a/sdks/typescript/src/generated/models/json-value-input1.ts b/sdks/typescript/src/generated/models/json-value-input1.ts deleted file mode 100644 index b613f2e4..00000000 --- a/sdks/typescript/src/generated/models/json-value-input1.ts +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. - */ - -import * as z from "zod/v4-mini"; -import { smartUnion } from "../types/smart-union.js"; -import { - JSONValueInput, - JSONValueInput$Outbound, - JSONValueInput$outboundSchema, -} from "./json-value-input.js"; - -export type JsonValueInput1 = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput | null }; - -/** @internal */ -export type JsonValueInput1$Outbound = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput$Outbound | null }; - -/** @internal */ -export const JsonValueInput1$outboundSchema: z.ZodMiniType< - JsonValueInput1$Outbound, - JsonValueInput1 -> = smartUnion([ - z.string(), - z.int(), - z.number(), - z.boolean(), - z.array(z.nullable(JSONValueInput$outboundSchema)), - z.record(z.string(), z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), -]); - -export function jsonValueInput1ToJSON( - jsonValueInput1: JsonValueInput1, -): string { - return JSON.stringify(JsonValueInput1$outboundSchema.parse(jsonValueInput1)); -} diff --git a/sdks/typescript/src/generated/models/json-value-output.ts b/sdks/typescript/src/generated/models/json-value-output.ts deleted file mode 100644 index f50e2790..00000000 --- a/sdks/typescript/src/generated/models/json-value-output.ts +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. - */ - -import * as z from "zod/v4-mini"; -import { safeParse } from "../lib/schemas.js"; -import { Result as SafeParseResult } from "../types/fp.js"; -import * as types from "../types/primitives.js"; -import { smartUnion } from "../types/smart-union.js"; -import { SDKValidationError } from "./errors/sdk-validation-error.js"; - -export type JSONValueOutput = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueOutput | null }; - -/** @internal */ -export const JSONValueOutput$inboundSchema: z.ZodMiniType< - JSONValueOutput, - unknown -> = smartUnion([ - types.string(), - types.number(), - types.number(), - types.boolean(), - z.array(types.nullable(z.lazy(() => JSONValueOutput$inboundSchema))), - z.record( - z.string(), - types.nullable(z.lazy(() => JSONValueOutput$inboundSchema)), - ), -]); - -export function jsonValueOutputFromJSON( - jsonString: string, -): SafeParseResult { - return safeParse( - jsonString, - (x) => JSONValueOutput$inboundSchema.parse(JSON.parse(x)), - `Failed to parse 'JSONValueOutput' from JSON`, - ); -} diff --git a/sdks/typescript/src/generated/models/json-value-output1.ts b/sdks/typescript/src/generated/models/json-value-output1.ts deleted file mode 100644 index 877520c3..00000000 --- a/sdks/typescript/src/generated/models/json-value-output1.ts +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. - */ - -import * as z from "zod/v4-mini"; -import { safeParse } from "../lib/schemas.js"; -import { Result as SafeParseResult } from "../types/fp.js"; -import * as types from "../types/primitives.js"; -import { smartUnion } from "../types/smart-union.js"; -import { SDKValidationError } from "./errors/sdk-validation-error.js"; -import { - JSONValueOutput, - JSONValueOutput$inboundSchema, -} from "./json-value-output.js"; - -export type JsonValueOutput1 = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueOutput | null }; - -/** @internal */ -export const JsonValueOutput1$inboundSchema: z.ZodMiniType< - JsonValueOutput1, - unknown -> = smartUnion([ - types.string(), - types.number(), - types.number(), - types.boolean(), - z.array(types.nullable(JSONValueOutput$inboundSchema)), - z.record( - z.string(), - types.nullable(z.lazy(() => JSONValueOutput$inboundSchema)), - ), -]); - -export function jsonValueOutput1FromJSON( - jsonString: string, -): SafeParseResult { - return safeParse( - jsonString, - (x) => JsonValueOutput1$inboundSchema.parse(JSON.parse(x)), - `Failed to parse 'JsonValueOutput1' from JSON`, - ); -} diff --git a/sdks/typescript/src/generated/models/step.ts b/sdks/typescript/src/generated/models/step.ts index 8c3d4468..132cf9c9 100644 --- a/sdks/typescript/src/generated/models/step.ts +++ b/sdks/typescript/src/generated/models/step.ts @@ -3,11 +3,6 @@ */ import * as z from "zod/v4-mini"; -import { - JSONValueInput, - JSONValueInput$Outbound, - JSONValueInput$outboundSchema, -} from "./json-value-input.js"; /** * Runtime payload for an agent step invocation. @@ -16,8 +11,11 @@ export type Step = { /** * Optional context (conversation history, metadata, etc.) */ - context?: { [k: string]: JSONValueInput | null } | null | undefined; - input: JSONValueInput | null; + context?: { [k: string]: any } | null | undefined; + /** + * Any JSON value + */ + input: any; /** * Step name (tool name or model/chain id) */ @@ -25,7 +23,7 @@ export type Step = { /** * Output content for this step (None for pre-checks) */ - output?: JSONValueInput | null | undefined; + output?: any | null | undefined; /** * Step type (e.g., 'tool', 'llm') */ @@ -34,24 +32,20 @@ export type Step = { /** @internal */ export type Step$Outbound = { - context?: { [k: string]: JSONValueInput$Outbound | null } | null | undefined; - input: JSONValueInput$Outbound | null; + context?: { [k: string]: any } | null | undefined; + input: any; name: string; - output?: JSONValueInput$Outbound | null | undefined; + output?: any | null | undefined; type: string; }; /** @internal */ export const Step$outboundSchema: z.ZodMiniType = z.object( { - context: z.optional( - z.nullable( - z.record(z.string(), z.nullable(JSONValueInput$outboundSchema)), - ), - ), - input: z.nullable(JSONValueInput$outboundSchema), + context: z.optional(z.nullable(z.record(z.string(), z.any()))), + input: z.any(), name: z.string(), - output: z.optional(z.nullable(JSONValueInput$outboundSchema)), + output: z.optional(z.nullable(z.any())), type: z.string(), }, ); diff --git a/sdks/typescript/src/generated/models/template-definition-input.ts b/sdks/typescript/src/generated/models/template-definition-input.ts index 61e40755..e27f379e 100644 --- a/sdks/typescript/src/generated/models/template-definition-input.ts +++ b/sdks/typescript/src/generated/models/template-definition-input.ts @@ -4,11 +4,6 @@ import * as z from "zod/v4-mini"; import { remap as remap$ } from "../lib/primitives.js"; -import { - JsonValueInput1, - JsonValueInput1$Outbound, - JsonValueInput1$outboundSchema, -} from "./json-value-input1.js"; import { TemplateParameterDefinition, TemplateParameterDefinition$Outbound, @@ -19,7 +14,10 @@ import { * Reusable template with typed parameters and a JSON definition template. */ export type TemplateDefinitionInput = { - definitionTemplate: JsonValueInput1 | null; + /** + * Any JSON value + */ + definitionTemplate: any; /** * Metadata describing the template itself */ @@ -32,7 +30,7 @@ export type TemplateDefinitionInput = { /** @internal */ export type TemplateDefinitionInput$Outbound = { - definition_template: JsonValueInput1$Outbound | null; + definition_template: any; description?: string | null | undefined; parameters?: | { [k: string]: TemplateParameterDefinition$Outbound } @@ -45,7 +43,7 @@ export const TemplateDefinitionInput$outboundSchema: z.ZodMiniType< TemplateDefinitionInput > = z.pipe( z.object({ - definitionTemplate: z.nullable(JsonValueInput1$outboundSchema), + definitionTemplate: z.any(), description: z.optional(z.nullable(z.string())), parameters: z.optional( z.record(z.string(), TemplateParameterDefinition$outboundSchema), diff --git a/sdks/typescript/src/generated/models/template-definition-output.ts b/sdks/typescript/src/generated/models/template-definition-output.ts index b246dd7d..15cc9140 100644 --- a/sdks/typescript/src/generated/models/template-definition-output.ts +++ b/sdks/typescript/src/generated/models/template-definition-output.ts @@ -8,10 +8,6 @@ import { safeParse } from "../lib/schemas.js"; import { Result as SafeParseResult } from "../types/fp.js"; import * as types from "../types/primitives.js"; import { SDKValidationError } from "./errors/sdk-validation-error.js"; -import { - JsonValueOutput1, - JsonValueOutput1$inboundSchema, -} from "./json-value-output1.js"; import { TemplateParameterDefinition, TemplateParameterDefinition$inboundSchema, @@ -21,7 +17,10 @@ import { * Reusable template with typed parameters and a JSON definition template. */ export type TemplateDefinitionOutput = { - definitionTemplate: JsonValueOutput1 | null; + /** + * Any JSON value + */ + definitionTemplate: any; /** * Metadata describing the template itself */ @@ -38,7 +37,7 @@ export const TemplateDefinitionOutput$inboundSchema: z.ZodMiniType< unknown > = z.pipe( z.object({ - definition_template: types.nullable(JsonValueOutput1$inboundSchema), + definition_template: z.any(), description: z.optional(z.nullable(types.string())), parameters: types.optional( z.record(z.string(), TemplateParameterDefinition$inboundSchema), diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 595c3117..73852248 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -218,7 +218,8 @@ def _build_default_provider() -> RequestAuthorizer: ) ) raise RuntimeError( - f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', 'header', " + "or 'http_upstream'." ) @@ -259,7 +260,8 @@ def _resolve_runtime_mode() -> str: if mode == "jwt": return mode raise RuntimeError( - f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', " + "'header', or 'jwt'." ) diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index 8620d3b6..3f39e6fd 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -4,9 +4,8 @@ ``Authorization`` header, verifies the signature against the runtime secret, checks the token's scope covers the requested operation, and returns a :class:`Principal` carrying the bound target. When a -``context_builder`` on the dependency surfaces ``target_type`` / -``target_id``, the provider also enforces that they match the token's -binding - runtime endpoints get the request-target check for free. +``context_builder`` on the dependency must surface matching +``target_type`` / ``target_id`` values for target-bound tokens. """ from __future__ import annotations @@ -55,25 +54,20 @@ async def authorize( hint="Request a token with the required scope.", ) - if context is not None: - requested_target_type = context.get("target_type") - requested_target_id = context.get("target_id") - if requested_target_type is not None and requested_target_type != claims.target_type: - raise ForbiddenError( - error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, - detail=( - "Runtime token target_type does not match the request." - ), - hint="Re-exchange a token bound to the request target.", - ) - if requested_target_id is not None and requested_target_id != claims.target_id: - raise ForbiddenError( - error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, - detail=( - "Runtime token target_id does not match the request." - ), - hint="Re-exchange a token bound to the request target.", - ) + requested_target_type = context.get("target_type") if context is not None else None + requested_target_id = context.get("target_id") if context is not None else None + if requested_target_type != claims.target_type: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Runtime token target_type does not match the request.", + hint="Re-exchange a token bound to the request target.", + ) + if requested_target_id != claims.target_id: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Runtime token target_id does not match the request.", + hint="Re-exchange a token bound to the request target.", + ) return Principal( namespace_key=claims.namespace_key, diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index 437af8b5..30779c5c 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -1,5 +1,6 @@ """Evaluation analysis endpoints.""" +import json from dataclasses import dataclass from agent_control_engine.core import ControlEngine @@ -121,7 +122,8 @@ async def _evaluation_context(request: Request) -> dict[str, object]: """Surface target identifiers to the runtime authorizer.""" try: body = await request.json() - except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + except (json.JSONDecodeError, UnicodeDecodeError): + _logger.debug("Unable to decode evaluation request body for auth context") return {} if not isinstance(body, dict): return {} diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index e3a5fd26..6c015310 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -134,13 +134,14 @@ async def get_control_or_404( self, control_id: int, *, - namespace_key: str | None = None, + namespace_key: str, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" - stmt = select(Control).where(Control.id == control_id) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 20c58aed..06f1be89 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -742,7 +742,11 @@ async def test_local_jwt_provider_returns_target_bound_principal(): provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = _build_request(headers={"Authorization": f"Bearer {token}"}) - principal = await provider.authorize(request, Operation.RUNTIME_USE) + principal = await provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-42"}, + ) assert principal.target_type == "log_stream" assert principal.target_id == "ls-42" @@ -815,10 +819,39 @@ async def test_local_jwt_provider_carries_token_namespace_to_principal(): provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = _build_request(headers={"Authorization": f"Bearer {token}"}) - principal = await provider.authorize(request, Operation.RUNTIME_USE) + principal = await provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls"}, + ) assert principal.namespace_key == "org-7" +@pytest.mark.asyncio +async def test_local_jwt_provider_rejects_missing_target_context(): + """A target-bound runtime token requires matching request target context.""" + from agent_control_server.auth_framework.providers import LocalJwtVerifyProvider + from agent_control_server.auth_framework.runtime_token import ( + mint_runtime_token, + ) + from agent_control_server.errors import ForbiddenError + + token, _ = mint_runtime_token( + namespace_key="default", + actor_id="a", + target_type="log_stream", + target_id="bound-target", + scopes=("runtime.use",), + secret=_TEST_SECRET, + ttl_seconds=60, + ) + provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) + request = _build_request(headers={"Authorization": f"Bearer {token}"}) + + with pytest.raises(ForbiddenError, match="target_type does not match"): + await provider.authorize(request, Operation.RUNTIME_USE) + + @pytest.mark.asyncio async def test_local_jwt_provider_enforces_target_context_match(): """When the dependency surfaces a target context, the provider enforces it.""" diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 14d2d874..0ca1bca8 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -129,6 +129,80 @@ def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None assert eval_b.json()["is_safe"] is True +def test_principal_namespace_scopes_cross_namespace_writes(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + assert ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)).status_code == 200 + assert ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)).status_code == 200 + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + + binding = ns_a.put( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.patch(f"/api/v1/controls/{control_id}", json={"enabled": False}).status_code == 404 + assert ( + ns_b.put( + f"/api/v1/controls/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ).status_code + == 404 + ) + assert ( + ns_b.put( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + }, + ).status_code + == 404 + ) + delete_binding = ns_b.post( + "/api/v1/control-bindings/by-key:delete", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + }, + ) + assert delete_binding.status_code == 200, delete_binding.text + assert delete_binding.json()["deleted"] is False + assert ns_a.get("/api/v1/control-bindings").json()["bindings"] + + assert ns_b.post(f"/api/v1/agents/{agent_name}/policies/{policy_id}").status_code == 404 + assert ns_b.post(f"/api/v1/agents/{agent_name}/controls/{control_id}").status_code == 404 + + def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: set_authorizer(HeaderNamespaceAuthorizer()) diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 1b1edae2..0863c9a0 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -172,7 +172,11 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ verify_provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = MagicMock() request.headers = {"Authorization": f"Bearer {token}"} - principal = await verify_provider.authorize(request, Operation.RUNTIME_USE) + principal = await verify_provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-99"}, + ) assert principal.target_type == "log_stream" assert principal.target_id == "ls-99" @@ -212,6 +216,37 @@ def test_evaluation_rejects_runtime_jwt_for_wrong_target( assert response.json()["detail"] == "Runtime token target_id does not match the request." +def test_evaluation_rejects_runtime_jwt_without_bound_target_context( + client: TestClient, + runtime_config_enabled, +): + """A target-bound runtime JWT must not authorize a target-less evaluation.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_type does not match the request." + + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, @@ -316,7 +351,11 @@ async def authorize(self, request, operation, context=None): verify_provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) req = MagicMock() req.headers = {"Authorization": f"Bearer {token}"} - principal = await verify_provider.authorize(req, Operation.RUNTIME_USE) + principal = await verify_provider.authorize( + req, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-org-a"}, + ) assert principal.namespace_key == "org-A" assert principal.target_id == "ls-org-a" From aa0bccef8375892a9a65f1dc567aa984aac94cc3 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 12 May 2026 15:48:33 +0530 Subject: [PATCH 12/12] fix(server): route evaluator and observability auth through framework --- docs/auth.md | 3 + .../auth_framework/core.py | 6 +- .../auth_framework/providers/header.py | 3 + .../endpoints/evaluators.py | 5 +- .../endpoints/observability.py | 35 ++++++-- server/src/agent_control_server/main.py | 7 +- server/tests/test_auth.py | 78 ++++++++++++----- server/tests/test_observability_endpoints.py | 83 ++++++++++++++++--- 8 files changed, 170 insertions(+), 50 deletions(-) diff --git a/docs/auth.md b/docs/auth.md index 7aafd2ad..9d2f6efd 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -17,6 +17,9 @@ policies.update agents.read agents.create agents.update +evaluators.read +observability.read +observability.write control_bindings.read control_bindings.write runtime.token_exchange diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 058169de..011c62de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -55,6 +55,9 @@ class Operation(StrEnum): AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" + EVALUATORS_READ = "evaluators.read" + OBSERVABILITY_READ = "observability.read" + OBSERVABILITY_WRITE = "observability.write" RUNTIME_USE = "runtime.use" @@ -109,8 +112,7 @@ async def authorize( request: Request, operation: Operation, context: dict[str, Any] | None = None, - ) -> Principal: - ... + ) -> Principal: ... _default_authorizer: RequestAuthorizer | None = None diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 16760768..2d917d91 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -48,6 +48,9 @@ class AccessLevel(Enum): Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.EVALUATORS_READ: AccessLevel.AUTHENTICATED, + Operation.OBSERVABILITY_READ: AccessLevel.AUTHENTICATED, + Operation.OBSERVABILITY_WRITE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/evaluators.py b/server/src/agent_control_server/endpoints/evaluators.py index a9cdaa2a..6bbeddfc 100644 --- a/server/src/agent_control_server/endpoints/evaluators.py +++ b/server/src/agent_control_server/endpoints/evaluators.py @@ -3,9 +3,11 @@ from typing import Any from agent_control_engine import list_evaluators -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, Field +from ..auth_framework import Operation, require_operation + router = APIRouter(prefix="/evaluators", tags=["evaluators"]) @@ -25,6 +27,7 @@ class EvaluatorInfo(BaseModel): response_model=dict[str, EvaluatorInfo], summary="List available evaluators", response_description="Dictionary of evaluator name to evaluator info", + dependencies=[Depends(require_operation(Operation.EVALUATORS_READ))], ) async def get_evaluators() -> dict[str, EvaluatorInfo]: """List all available evaluators. diff --git a/server/src/agent_control_server/endpoints/observability.py b/server/src/agent_control_server/endpoints/observability.py index 5de90c0a..3296ca1c 100644 --- a/server/src/agent_control_server/endpoints/observability.py +++ b/server/src/agent_control_server/endpoints/observability.py @@ -5,7 +5,7 @@ 2. Event queries (POST /events/query) - Query raw events by trace_id, etc. 3. Stats (GET /stats) - Aggregated statistics for dashboards -All endpoints require API key authentication. +All endpoints declare operation-based auth dependencies. Dependencies are stored on app.state during server lifespan (see main.py): - app.state.event_ingestor: EventIngestor @@ -27,7 +27,7 @@ ) from fastapi import APIRouter, Depends, Request -from ..auth import require_api_key +from ..auth_framework import Operation, require_operation from ..observability.ingest.base import EventIngestor from ..observability.store.base import ( EventStore, @@ -42,7 +42,6 @@ router = APIRouter( prefix="/observability", tags=["observability"], - dependencies=[Depends(require_api_key)], ) @@ -72,7 +71,12 @@ def get_event_store(request: Request) -> EventStore: # ============================================================================= -@router.post("/events", status_code=202, response_model=BatchEventsResponse) +@router.post( + "/events", + status_code=202, + response_model=BatchEventsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_WRITE))], +) async def ingest_events( request: BatchEventsRequest, ingestor: EventIngestor = Depends(get_event_ingestor), @@ -121,7 +125,11 @@ async def ingest_events( # ============================================================================= -@router.post("/events/query", response_model=EventQueryResponse) +@router.post( + "/events/query", + response_model=EventQueryResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def query_events( request: EventQueryRequest, store: EventStore = Depends(get_event_store), @@ -158,7 +166,11 @@ async def query_events( # ============================================================================= -@router.get("/stats", response_model=StatsResponse) +@router.get( + "/stats", + response_model=StatsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_stats( agent_name: str, time_range: TimeRange = "5m", @@ -207,7 +219,11 @@ async def get_stats( ) -@router.get("/stats/controls/{control_id}", response_model=ControlStatsResponse) +@router.get( + "/stats/controls/{control_id}", + response_model=ControlStatsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_control_stats( control_id: int, agent_name: str, @@ -266,7 +282,10 @@ async def get_control_stats( # ============================================================================= -@router.get("/status") +@router.get( + "/status", + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_status(request: Request) -> dict: """ Get observability system status. diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 364778ef..005923aa 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -17,7 +17,7 @@ from starlette_exporter import PrometheusMiddleware, handle_metrics from . import __version__ as server_version -from .auth import get_api_key_from_header, require_api_key +from .auth import get_api_key_from_header from .config import observability_settings, settings from .db import AsyncSessionLocal from .endpoints.agents import router as agent_router @@ -314,17 +314,16 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) -# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) -# Observability routes (already has auth dependency in router) app.include_router( observability_router, prefix=api_v1_prefix, + dependencies=[Depends(get_api_key_from_header)], ) # System routes (config, login, logout) - no auth required diff --git a/server/tests/test_auth.py b/server/tests/test_auth.py index 44f2de27..fba5088c 100644 --- a/server/tests/test_auth.py +++ b/server/tests/test_auth.py @@ -1,16 +1,36 @@ """Tests for API key authentication.""" import uuid +from typing import Any import pytest +from fastapi import Request from fastapi.testclient import TestClient from agent_control_server import __version__ as server_version +from agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.config import auth_settings from .utils import VALID_CONTROL_PAYLOAD +class _RecordingAuthorizer: + """Test authorizer that records the operation requested by a route.""" + + def __init__(self) -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + return Principal(namespace_key="default") + + class TestHealthEndpoint: """Health endpoint should always be accessible without authentication.""" @@ -40,9 +60,7 @@ class TestProtectedEndpoints: def test_missing_api_key_returns_401(self, unauthenticated_client: TestClient) -> None: """Given no API key, when requesting protected endpoint, then returns 401.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: assert response.status_code == 401 @@ -111,6 +129,20 @@ def test_missing_key_returns_401_on_evaluators( # Then: assert response.status_code == 401 + def test_evaluators_use_auth_framework_provider(self, app: object) -> None: + """Given a custom authorizer, when listing evaluators, then route uses it.""" + # Given: + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + + # When: + response = client.get("/api/v1/evaluators") + + # Then: + assert response.status_code == 200 + assert authorizer.calls == [(Operation.EVALUATORS_READ, None)] + class TestAuthDisabled: """When auth is disabled, all requests should succeed.""" @@ -120,21 +152,15 @@ def disable_auth(self, monkeypatch: pytest.MonkeyPatch) -> None: """Disable auth for tests in this class.""" monkeypatch.setattr(auth_settings, "api_key_enabled", False) - def test_no_key_allowed_when_disabled( - self, unauthenticated_client: TestClient - ) -> None: + def test_no_key_allowed_when_disabled(self, unauthenticated_client: TestClient) -> None: """Given auth disabled, when requesting without API key, then request succeeds.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: (404 for non-existent resource, but NOT 401) assert response.status_code == 404 - def test_evaluators_accessible_when_disabled( - self, unauthenticated_client: TestClient - ) -> None: + def test_evaluators_accessible_when_disabled(self, unauthenticated_client: TestClient) -> None: """Given auth disabled, when listing evaluators without API key, then returns 200.""" # When: response = unauthenticated_client.get("/api/v1/evaluators") @@ -264,9 +290,7 @@ def test_admin_key_allowed_on_representative_mutations(self, admin_client: TestC init_response = admin_client.post("/api/v1/agents/initAgent", json=init_payload) assert init_response.status_code == 200 - set_policy_response = admin_client.post( - f"/api/v1/agents/{agent_name}/policy/{policy_id}" - ) + set_policy_response = admin_client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") assert set_policy_response.status_code == 200 @@ -344,9 +368,7 @@ def setup_no_keys(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_misconfigured_returns_500(self, unauthenticated_client: TestClient) -> None: """Given auth enabled but no keys configured, when requesting, then returns 500.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: assert response.status_code == 500 @@ -360,6 +382,7 @@ class TestOptionalApiKey: def _make_optional_app(self) -> TestClient: from fastapi import Depends, FastAPI + from agent_control_server.auth import optional_api_key app = FastAPI() @@ -374,7 +397,9 @@ def maybe_auth(client=Depends(optional_api_key)) -> dict[str, object]: return TestClient(app) - def test_optional_api_key_auth_disabled_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_auth_disabled_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth disabled monkeypatch.setattr(auth_settings, "api_key_enabled", False) @@ -386,7 +411,9 @@ def test_optional_api_key_auth_disabled_returns_none(self, monkeypatch: pytest.M assert response.status_code == 200 assert response.json()["auth"] is False - def test_optional_api_key_missing_header_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_missing_header_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with configured keys monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -402,7 +429,9 @@ def test_optional_api_key_missing_header_returns_none(self, monkeypatch: pytest. assert response.status_code == 200 assert response.json()["auth"] is False - def test_optional_api_key_invalid_header_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_invalid_header_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with configured keys monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -418,7 +447,9 @@ def test_optional_api_key_invalid_header_returns_none(self, monkeypatch: pytest. assert response.status_code == 200 assert response.json()["auth"] is False - def test_optional_api_key_admin_header_sets_admin(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_admin_header_sets_admin( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with admin key monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -449,6 +480,7 @@ def test_require_admin_key_rejects_non_admin( # When: requiring admin key on an endpoint from fastapi import Depends, FastAPI + from agent_control_server.auth import require_admin_key local_app = FastAPI() @@ -483,6 +515,7 @@ def test_authenticated_client_key_id_masks_short_key(self) -> None: def test_get_api_key_from_header_extracts_value(self) -> None: # Given: a route that returns raw API key header from fastapi import Depends, FastAPI + from agent_control_server.auth import get_api_key_from_header app = FastAPI() @@ -503,6 +536,7 @@ def raw_key(key: str | None = Depends(get_api_key_from_header)) -> dict[str, str def test_get_api_key_from_header_allows_missing(self) -> None: # Given: a route that returns raw API key header from fastapi import Depends, FastAPI + from agent_control_server.auth import get_api_key_from_header app = FastAPI() diff --git a/server/tests/test_observability_endpoints.py b/server/tests/test_observability_endpoints.py index 476cf00c..97fc0f7c 100644 --- a/server/tests/test_observability_endpoints.py +++ b/server/tests/test_observability_endpoints.py @@ -2,17 +2,20 @@ import json from datetime import datetime, timedelta, timezone +from typing import Any from uuid import UUID, uuid4 import pytest -from fastapi.testclient import TestClient -from sqlalchemy import text - from agent_control_models import ( BatchEventsRequest, ControlExecutionEvent, EventQueryRequest, ) +from fastapi import Request +from fastapi.testclient import TestClient +from sqlalchemy import text + +from agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.main import app from agent_control_server.observability.ingest.base import IngestResult @@ -42,6 +45,64 @@ def create_test_event( ) +class _RecordingAuthorizer: + """Test authorizer that records the operation requested by a route.""" + + def __init__(self) -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + return Principal(namespace_key="default") + + +class TestObservabilityAuthFramework: + """Tests observability routes declare operation-based authorization.""" + + def test_status_uses_read_operation(self, app: object) -> None: + """Given a custom authorizer, when getting status, then read is authorized.""" + # Given: + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + + # When: + response = client.get("/api/v1/observability/status") + + # Then: + assert response.status_code == 200 + assert authorizer.calls == [(Operation.OBSERVABILITY_READ, None)] + + def test_ingest_events_uses_write_operation( + self, + app: object, + setup_observability: object, + ) -> None: + """Given a custom authorizer, when ingesting events, then write is authorized.""" + # Given: + _ = setup_observability + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + request = BatchEventsRequest(events=[create_test_event()]) + + # When: + response = client.post( + "/api/v1/observability/events", + json=request.model_dump(mode="json"), + ) + + # Then: + assert response.status_code == 202 + assert authorizer.calls == [(Operation.OBSERVABILITY_WRITE, None)] + + class TestEventIngestion: """Tests for POST /events endpoint.""" @@ -155,7 +216,7 @@ def test_event_with_all_fields(self): event = ControlExecutionEvent( trace_id="a" * 32, span_id="b" * 16, - agent_name="test-agent", + agent_name="test-agent", control_id=1, control_name="test-control", check_stage="post", @@ -441,9 +502,7 @@ async def test_timeseries_aggregates_events_per_bucket( total_exec = sum(b["execution_count"] for b in buckets_with_events) total_match = sum(b["match_count"] for b in buckets_with_events) total_non_match = sum(b["non_match_count"] for b in buckets_with_events) - total_observe = sum( - b["action_counts"].get("observe", 0) for b in buckets_with_events - ) + total_observe = sum(b["action_counts"].get("observe", 0) for b in buckets_with_events) total_deny = sum(b["action_counts"].get("deny", 0) for b in buckets_with_events) assert total_exec == 3 @@ -453,9 +512,7 @@ async def test_timeseries_aggregates_events_per_bucket( assert total_deny == 1 @pytest.mark.asyncio - async def test_timeseries_empty_buckets_included( - self, client: TestClient, setup_observability - ): + async def test_timeseries_empty_buckets_included(self, client: TestClient, setup_observability): """Empty buckets are included with zero counts.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" @@ -594,9 +651,7 @@ async def test_control_stats_with_timeseries(self, client: TestClient, setup_obs assert data["stats"]["execution_count"] == 2 # Sum timeseries buckets should equal total - total_from_buckets = sum( - b["execution_count"] for b in data["stats"]["timeseries"] - ) + total_from_buckets = sum(b["execution_count"] for b in data["stats"]["timeseries"]) assert total_from_buckets == 2 @pytest.mark.asyncio @@ -797,6 +852,7 @@ class TestObservabilityIngestStatus: def test_ingest_events_partial_status(self, client: TestClient, setup_observability): """Test partial status when some events are dropped.""" + # Given: a stub ingestor that drops some events class StubIngestor: async def ingest(self, events): @@ -826,6 +882,7 @@ async def ingest(self, events): def test_ingest_events_failed_status(self, client: TestClient, setup_observability): """Test failed status when all events are dropped.""" + # Given: a stub ingestor that drops all events class StubIngestor: async def ingest(self, events):