diff --git a/sdk/agentserver/.gitignore b/sdk/agentserver/.gitignore index 89f79044a692..9eee79d740a8 100644 --- a/sdk/agentserver/.gitignore +++ b/sdk/agentserver/.gitignore @@ -2,3 +2,4 @@ specs/ .specify/ .github/ +.vscode/ diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index ab460125bd8c..c188328c5657 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -4,6 +4,42 @@ ### Features Added +- **Durable long-running agents** — New `@durable_task` decorator and supporting types for building crash-resilient, long-running agents that survive container crashes, OOM kills, and redeployments. Key capabilities: + - **Lifecycle automation** — `.run()` and `.start()` automatically start, resume, or recover tasks based on their current state in the task store. + - **Entry mode awareness** — `ctx.entry_mode` tells the function whether it was entered `"fresh"`, `"resumed"` from suspension, or `"recovered"` from a crash. + - **Suspend & resume** — `ctx.suspend(output=..., reason=...)` pauses execution for multi-turn agent patterns (e.g., waiting for user input). + - **TaskResult wrapper** — `run()` and `result()` return `TaskResult[Output]` with `.is_completed` / `.is_suspended` properties, making suspension a normal return value instead of an exception. + - **Streaming** — `ctx.stream(chunk)` emits incremental output; consumers iterate with `async for chunk in task_run`. + - **Cancellation & timeout** — Cooperative cancel via `ctx.cancel` event, configurable `timeout`, and `terminate()` for forced shutdown. + - **RetryPolicy** — Configurable retry with factory presets: `.exponential_backoff()`, `.fixed_delay()`, `.linear_backoff()`, `.no_retry()`. 
+ - **Source auto-stamping** — The framework automatically stamps every task with provenance metadata: `type` (`agentserver.durable_task`), `name` (the decorator `name` option — the stable identity anchor), and `server_version` (the `x-platform-server` header value). Source is framework-owned and not user-overridable. A reserved tag `_durable_task_name` is also auto-stamped for LIST API filtering by function name. + - **Callable factories** — `tags`, `title`, and `description` accept `Callable[[Input, task_id], T]` for dynamic metadata computed at task creation time. + - **TaskMetadata** — Dict-like mutable progress metadata (`ctx.metadata["key"] = value`) with debounced auto-flush to the task store. Supports `[]`, `in`, `for`, `len`, `del`, plus convenience methods `.increment()` and `.append()`. + - **Handle operations** — `TaskRun.metadata` for progress snapshot reads, `TaskRun.delete()` for task cleanup, `TaskRun.refresh()` for re-fetching state from the store, `TaskRun.lease_expiry_count` for monitoring ownership churn. + - **TaskContext.description** — `ctx.description` exposes the task description string within the running function. + - **Configurable shutdown grace** — `DurableTaskManager(shutdown_grace_seconds=25.0)` controls how long the manager waits for tasks to checkpoint before force-expiring leases during shutdown. + - **Task listing** — `my_task.list(status=...)` returns all tasks for a specific durable task function, automatically scoped by function name (via tag) and source type. Supports `status` and `session_id` filters. +- **Steerable durable tasks** — New `steerable=True` parameter on `@durable_task` enables mid-flight steering where new inputs can be queued while a task is still running. Key capabilities: + - **Input queue** — `start()` on an in-progress steerable task queues the new input and returns a `TaskRun` handle immediately, instead of raising `TaskConflictError`. 
+ - **Cancel signal** — `ctx.cancel` is automatically set when new inputs arrive, giving the function a cooperative signal to short-circuit. + - **Automatic drain** — The framework drains the queue after the function suspends or completes, re-entering with the next queued input using `entry_mode="resumed"` and `was_steered=True`. + - **Superseded results** — Previous generation's `TaskRun.result()` resolves with `status="superseded"` and `is_superseded=True`. + - **Context enrichment** — `ctx.was_steered`, `ctx.previous_input`, `ctx.pending_inputs`, and `ctx.generation` provide full steering context. + - **Queue limits** — `max_pending` (default 10) prevents unbounded queue growth; raises `SteeringQueueFull` when exceeded. + - **Crash recovery** — `drain_in_progress` flag in persisted state enables recovery from mid-drain crashes. + - **Distributed steering** — Lease renewal loop polls for pending inputs from other processes and sets `ctx.cancel` accordingly. + - **Etag-aware completion** — Steerable tasks use optimistic concurrency on completion to detect concurrent steering. + +### Breaking Changes + +- **`source` parameter removed** — The `source` keyword argument has been removed from `@durable_task()`, `.run()`, `.start()`, and `.options()`. Source provenance is now auto-stamped by the framework and cannot be overridden by developers. Use `tags` for custom metadata. + +### Bugs Fixed + +- **Local provider payload merge** — Fixed `_local_provider.py` to use strict shallow merge per Protocol Spec §11: root-level keys are now always replaced, not recursively merged. Previously nested dicts were merged with `dict.update()`, which was more forgiving than the real Task Storage API. +- **Task recovery routing** — `_find_resume_callback()` now matches by `source.name` (the auto-stamped function name) first, then falls back to title prefix match. Previously relied only on fragile title prefix heuristic. 
+ +### Other Changes - Added `_platform_headers` module with cross-cutting protocol header name constants (`x-request-id`, `x-platform-server`, `x-agent-session-id`, `x-platform-error-source`, `x-platform-error-detail`, and others). Protocol packages now import shared header name strings from core instead of maintaining their own copies. - Added `TraceContextMiddleware` — a lightweight pure-ASGI middleware that propagates W3C trace context (`traceparent`, `tracestate`) and baggage from incoming HTTP requests. Any spans created by downstream frameworks (e.g. MAF / agent-framework) are automatically children of the caller's trace without additional framework spans. - Added `enable_sensitive_data` parameter to `configure_observability()` to control whether prompts, tool arguments, and results are recorded in telemetry. Respects `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` environment variable. diff --git a/sdk/agentserver/azure-ai-agentserver-core/README.md b/sdk/agentserver/azure-ai-agentserver-core/README.md index add29e0bb57b..bc72ac7400f0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/README.md +++ b/sdk/agentserver/azure-ai-agentserver-core/README.md @@ -113,6 +113,54 @@ export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." python my_agent.py ``` +### Durable long-running agents + +The `@durable_task` decorator builds crash-resilient agents that survive container restarts, OOM kills, and redeployments. Task state is persisted to a task store, enabling automatic recovery and multi-turn suspend/resume patterns. 
+ +```python +from datetime import timedelta +from azure.ai.agentserver.core.durable import durable_task, TaskContext, RetryPolicy + +@durable_task( + timeout=timedelta(minutes=30), + retry=RetryPolicy.exponential_backoff(max_attempts=3), + tags={"priority": "high"}, +) +async def process_document(ctx: TaskContext[dict]) -> dict: + ctx.metadata["phase"] = "processing" + result = await analyze(ctx.input["document_url"]) + ctx.metadata["phase"] = "complete" + return {"summary": result} +``` + +**Start and await a task:** + +```python +result = await process_document.run(task_id="doc-42", input={"document_url": "..."}) +print(result.output) # {"summary": "..."} +``` + +**Multi-turn suspend/resume (e.g., conversational agents):** + +```python +@durable_task() +async def chat_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + history = ctx.metadata.get("history", []) + reply = await generate_reply(message, history) + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history + return await ctx.suspend(output={"reply": reply}) + +# Each call resumes the same session: +result = await chat_session.run(task_id="session-1", input={"message": "Hello"}) +print(result.output) # {"reply": "Hi! How can I help?"} +print(result.is_suspended) # True +``` + +See the [Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-developer-guide.md) for the full API reference. + ## Troubleshooting ### Logging @@ -130,6 +178,7 @@ To report an issue with the client library, or request additional features, plea ## Next steps - Install [`azure-ai-agentserver-invocations`](https://pypi.org/project/azure-ai-agentserver-invocations/) to add the invocation protocol endpoints. 
+- Read the [Durable Task Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-developer-guide.md) for crash-resilient long-running agents. - See the [container image spec](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver) for the full hosted agent contract. ## Contributing diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py index d360a00966a8..9e034a69d087 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py @@ -21,6 +21,7 @@ trace_stream, ) """ + __path__ = __import__("pkgutil").extend_path(__path__, __name__) from ._base import AgentServerHost diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py index 84a7ccd06c24..2fb8d4e45588 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py @@ -244,6 +244,26 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF protocols, ) + # --- Durable task manager auto-initialization --- + durable_manager = None + try: + from .durable._manager import ( # pylint: disable=import-outside-toplevel + DurableTaskManager, + set_task_manager, + ) + + durable_manager = DurableTaskManager( + config=cfg, + shutdown_event=asyncio.Event(), + ) + set_task_manager(durable_manager) + await durable_manager.startup() + logger.info("DurableTaskManager initialized automatically") + except ImportError: + pass # durable module not available + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to initialize 
DurableTaskManager", exc_info=True) + yield # --- SHUTDOWN: runs once when the server is stopping --- @@ -251,6 +271,20 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF "AgentServerHost shutting down (graceful timeout=%ss)", self._graceful_shutdown_timeout, ) + + # Shutdown durable task manager + if durable_manager is not None: + try: + await durable_manager.shutdown() + from .durable._manager import ( # pylint: disable=import-outside-toplevel + set_task_manager as _clear_manager, + ) + + _clear_manager(None) + logger.info("DurableTaskManager shut down") + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Error shutting down DurableTaskManager", exc_info=True) + if self._graceful_shutdown_timeout == 0: logger.info("Graceful shutdown drain period disabled (timeout=0)") else: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py index f1d8dd3db86f..22113f413bc9 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py @@ -129,7 +129,8 @@ def from_env(cls) -> Self: session_id=os.environ.get(_ENV_FOUNDRY_AGENT_SESSION_ID, ""), port=resolve_port(None), appinsights_connection_string=os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, ""), + _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, "" + ), otlp_endpoint=os.environ.get(_ENV_OTEL_EXPORTER_OTLP_ENDPOINT, ""), sse_keepalive_interval=resolve_sse_keepalive_interval(None), ws_ping_interval=resolve_ws_ping_interval(), @@ -168,9 +169,7 @@ def _require_int(name: str, value: object) -> int: :raises ValueError: If *value* is not an integer. 
""" if isinstance(value, bool) or not isinstance(value, int): - raise ValueError( - f"Invalid value for {name}: {value!r} (expected an integer)" - ) + raise ValueError(f"Invalid value for {name}: {value!r} (expected an integer)") return value @@ -186,9 +185,7 @@ def _validate_port(value: int, source: str) -> int: :raises ValueError: If the port is outside 1-65535. """ if not 1 <= value <= 65535: - raise ValueError( - f"Invalid value for {source}: {value} (expected 1-65535)" - ) + raise ValueError(f"Invalid value for {source}: {value} (expected 1-65535)") return value @@ -249,9 +246,7 @@ def resolve_appinsights_connection_string( """ if connection_string is not None: return connection_string - return os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING - ) + return os.environ.get(_ENV_APPLICATIONINSIGHTS_CONNECTION_STRING) def resolve_log_level(level: Optional[str]) -> str: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py index c5b1c9e01efe..9268e24df81c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py @@ -58,6 +58,4 @@ def create_error_response( body["type"] = error_type if details is not None: body["details"] = details - return JSONResponse( - {"error": body}, status_code=status_code, headers=headers - ) + return JSONResponse({"error": body}, status_code=status_code, headers=headers) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py index 4fb3fe78a9cd..63b0d320a771 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py @@ -76,7 +76,9 @@ def 
_get_trace_id(headers: list[tuple[bytes, bytes]] | None = None) -> str | Non :rtype: str | None """ try: - from opentelemetry import trace as _trace # pylint: disable=import-outside-toplevel + from opentelemetry import ( + trace as _trace, + ) # pylint: disable=import-outside-toplevel span = _trace.get_current_span() ctx = span.get_span_context() @@ -147,7 +149,10 @@ async def _send_wrapper(message: MutableMapping[str, Any]) -> None: elapsed_ms = (time.monotonic() - start) * 1000 logger.warning( "Inbound %s %s failed with status 500 in %.1fms%s", - method, path, elapsed_ms, extra_str, + method, + path, + elapsed_ms, + extra_str, ) raise @@ -156,10 +161,18 @@ async def _send_wrapper(message: MutableMapping[str, Any]) -> None: if status_code is not None and status_code >= 400: logger.warning( "Inbound %s %s completed with status %d in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) else: logger.info( "Inbound %s %s completed with status %s in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py new file mode 100644 index 000000000000..53b1935f68c6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/__init__.py @@ -0,0 +1,89 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Durable task subsystem for crash-resilient long-running agents. + +Provides the :func:`durable_task` decorator and supporting types for +building Azure AI Hosted Agents that survive container crashes, +OOM kills, and redeployments. 
+ +Key features: + +- **Lifecycle automation** — ``.run()`` and ``.start()`` automatically + start, resume, or recover tasks based on their current state. +- **Entry mode** — ``ctx.entry_mode`` tells the function whether it was + entered fresh, resumed from suspension, or recovered from a crash. +- **RetryPolicy** — configurable retry with exponential, fixed, or linear + backoff (see :class:`RetryPolicy` presets). +- **Streaming** — emit incremental output via ``ctx.stream()`` and consume + with ``async for chunk in task_run``. +- **Source tracking** — immutable provenance metadata is auto-stamped by + the framework at task creation time (not user-overridable). + +Public API:: + + from azure.ai.agentserver.core.durable import ( + durable_task, + DurableTask, + RetryPolicy, + TaskContext, + TaskMetadata, + TaskResult, + TaskRun, + Suspended, + TaskStatus, + TaskFailed, + TaskSuspended, + TaskCancelled, + TaskNotFound, + TaskConflictError, + TaskTerminated, + EntryMode, + TaskInfo, + ) +""" + +from ._context import EntryMode, TaskContext +from ._decorator import DurableTask, DurableTaskOptions, durable_task +from ._exceptions import ( + EtagConflict, + SteeringQueueFull, + TaskCancelled, + TaskConflictError, + TaskFailed, + TaskNotFound, + TaskSuspended, + TaskTerminated, +) +from ._metadata import TaskMetadata +from ._models import TaskInfo, TaskStatus +from ._result import TaskResult +from ._retry import RetryPolicy +from ._run import Suspended, TaskRun +from ._stream import QueueStreamHandler, StreamHandler, StreamHandlerFactory + +__all__ = [ + "durable_task", + "DurableTask", + "DurableTaskOptions", + "QueueStreamHandler", + "RetryPolicy", + "StreamHandler", + "StreamHandlerFactory", + "TaskContext", + "TaskMetadata", + "TaskResult", + "TaskRun", + "Suspended", + "TaskStatus", + "TaskFailed", + "TaskSuspended", + "TaskCancelled", + "TaskNotFound", + "TaskConflictError", + "TaskTerminated", + "EtagConflict", + "SteeringQueueFull", + "EntryMode", + "TaskInfo", +] diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py new file mode 100644 index 000000000000..53eccbcabea9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_client.py @@ -0,0 +1,241 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Hosted durable task provider — HTTP client for the Foundry Task Storage API. + +Communicates with ``{FOUNDRY_PROJECT_ENDPOINT}/tasks`` using +``httpx.AsyncClient``. Bearer tokens are obtained lazily from +``DefaultAzureCredential`` when running in a hosted environment. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import httpx + +from ._exceptions import TaskNotFound +from ._models import ( + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, + TaskStatus, +) + +logger = logging.getLogger("azure.ai.agentserver.durable") + +_AUTH_SCOPE = "https://ai.azure.com/.default" +_API_VERSION = "v1" + + +class HostedDurableTaskProvider: + """HTTP-backed provider for the Foundry Task Storage API. + + :param project_endpoint: The ``FOUNDRY_PROJECT_ENDPOINT`` base URL. + :type project_endpoint: str + :param credential: An ``azure.identity.aio.DefaultAzureCredential`` + instance, or any token credential supporting ``get_token(scope)``. 
+ :type credential: Any + """ + + def __init__(self, project_endpoint: str, credential: Any) -> None: + self._base_url = f"{project_endpoint.rstrip('/')}/tasks" + self._credential = credential + self._client = httpx.AsyncClient(timeout=30.0) + + async def _get_headers(self) -> dict[str, str]: + token = await self._credential.get_token(_AUTH_SCOPE) + return { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json", + } + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task via POST /tasks. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. + :rtype: TaskInfo + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if request.lease_owner is not None: + params["lease_owner"] = request.lease_owner + if request.lease_instance_id is not None: + params["lease_instance_id"] = request.lease_instance_id + if request.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(request.lease_duration_seconds) + + body: dict[str, Any] = { + "agent_name": request.agent_name, + "session_id": request.session_id, + } + if request.id is not None: + body["id"] = request.id + if request.status != "pending": + body["status"] = request.status + if request.title is not None: + body["title"] = request.title + if request.description is not None: + body["description"] = request.description + if request.payload is not None: + body["payload"] = request.payload + if request.tags is not None: + body["tags"] = request.tags + if request.source is not None: + body["source"] = request.source + + response = await self._client.post( + self._base_url, json=body, headers=headers, params=params + ) + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a task by ID via GET /tasks/{id}. 
+ + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + headers = await self._get_headers() + response = await self._client.get( + f"{self._base_url}/{task_id}", + headers=headers, + params={"api-version": _API_VERSION}, + ) + if response.status_code == 404: + return None + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH /tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if patch.lease_owner is not None: + params["lease_owner"] = patch.lease_owner + if patch.lease_instance_id is not None: + params["lease_instance_id"] = patch.lease_instance_id + if patch.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(patch.lease_duration_seconds) + + body: dict[str, Any] = {} + if patch.status is not None: + body["status"] = patch.status + if patch.payload is not None: + body["payload"] = patch.payload + if patch.tags is not None: + body["tags"] = patch.tags + if patch.error is not None: + body["error"] = patch.error + if patch.suspension_reason is not None: + body["suspension_reason"] = patch.suspension_reason + + if patch.if_match is not None: + headers["If-Match"] = f'"{patch.if_match}"' + + response = await self._client.patch( + f"{self._base_url}/{task_id}", + json=body, + headers=headers, + params=params, + ) + if response.status_code == 404: + raise TaskNotFound(task_id) + response.raise_for_status() + return TaskInfo.from_dict(response.json()) + + async def delete( + self, + task_id: str, + *, + force: bool = False, + 
cascade: bool = False, + ) -> None: + """Delete a task via DELETE /tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. + :paramtype cascade: bool + """ + headers = await self._get_headers() + params: dict[str, str] = {"api-version": _API_VERSION} + if force: + params["force"] = "true" + if cascade: + params["cascade"] = "true" + + response = await self._client.delete( + f"{self._base_url}/{task_id}", + headers=headers, + params=params, + ) + if response.status_code == 404: + raise TaskNotFound(task_id) + response.raise_for_status() + + async def list( + self, + *, + agent_name: str, + session_id: str, + status: TaskStatus | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + ) -> list[TaskInfo]: + """List tasks via GET /tasks. + + :keyword agent_name: Filter by agent name. + :paramtype agent_name: str + :keyword session_id: Filter by session ID. + :paramtype session_id: str + :keyword status: Filter by task status. + :paramtype status: TaskStatus | None + :keyword lease_owner: Filter by lease owner. + :paramtype lease_owner: str | None + :keyword tag: Filter by tag key-value pairs. + :paramtype tag: dict[str, str] | None + :return: Matching task records. 
+ :rtype: list[TaskInfo] + """ + headers = await self._get_headers() + params: dict[str, str] = { + "api-version": _API_VERSION, + "agent_name": agent_name, + "session_id": session_id, + } + if status is not None: + params["status"] = status + if lease_owner is not None: + params["lease_owner"] = lease_owner + if tag: + for key, value in tag.items(): + params[f"tag.{key}"] = value + + response = await self._client.get( + self._base_url, headers=headers, params=params + ) + response.raise_for_status() + data = response.json() + items: list[dict[str, Any]] = data.get("data", data.get("items", [])) + return [TaskInfo.from_dict(item) for item in items] + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self._client.aclose() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py new file mode 100644 index 000000000000..3d357d429d3b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_context.py @@ -0,0 +1,184 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskContext — the single parameter to a durable task function. + +Provides identity, typed input, mutable metadata, cancellation signals, +and the ``suspend()`` method for pausing execution. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Generic, Literal, Sequence, TypeVar + +from ._metadata import TaskMetadata +from ._stream import StreamHandler + +Input = TypeVar("Input") +Output = TypeVar("Output") + +EntryMode = Literal["fresh", "resumed", "recovered"] +"""Why the durable function was entered. + +- ``"fresh"`` — First execution. Task was just created or started from pending. 
+- ``"resumed"`` — Re-entered after suspension. On developer-initiated resume + (via ``.run()``), ``ctx.input`` contains the new input. On platform-initiated + resume (via ``/tasks/{task_id}/resume``), ``ctx.input`` contains the task's + persisted input. Also used when a steering input drains from the queue — + check ``ctx.was_steered`` to distinguish steering re-entry from normal resume. +- ``"recovered"`` — Re-entered after stale task detection. The previous execution + crashed or timed out. ``ctx.input`` contains the task's persisted input. + If a steerable task crashed mid-drain, ``ctx.was_steered`` will be ``True`` + and steering context (``previous_input``, ``generation``) is meaningful. +""" + + +class _Suspended: + """Internal sentinel for suspended tasks. See ``Suspended`` in ``_run.py``.""" + + __slots__ = ("reason", "output") + + def __init__( + self, + reason: str | None = None, + output: Any | None = None, + ) -> None: + self.reason = reason + self.output = output + + +class TaskContext(Generic[Input]): # pylint: disable=too-many-instance-attributes + """The single parameter to a durable task function. + + Provides access to the task's identity, typed input, mutable metadata + for progress tracking, cancellation signals, and the ability to + suspend execution. + + :param task_id: Unique task identifier. + :type task_id: str + :param title: Human-readable task title. + :type title: str + :param description: Optional task description. + :type description: str | None + :param session_id: Session scope identifier. + :type session_id: str + :param agent_name: Agent name from config. + :type agent_name: str + :param tags: Merged decorator + call-site tags. + :type tags: dict[str, str] + :param input: Typed, validated input value. + :type input: Input + :param metadata: Mutable progress metadata. + :type metadata: TaskMetadata + :param run_attempt: Framework retry attempt counter. 
+ :type run_attempt: int + :param lease_generation: Lease re-acquisition counter. + :type lease_generation: int + :param cancel: Request-level cancellation event. + :type cancel: asyncio.Event + :param shutdown: Container-level shutdown event. + :type shutdown: asyncio.Event + """ + + __slots__ = ( + "task_id", + "title", + "description", + "session_id", + "agent_name", + "tags", + "input", + "metadata", + "run_attempt", + "lease_generation", + "cancel", + "shutdown", + "_suspend_callback", + "_stream_handler", + "entry_mode", + "was_steered", + "previous_input", + "pending_inputs", + "generation", + ) + + def __init__( + self, + *, + task_id: str, + title: str, + description: str | None = None, + session_id: str, + agent_name: str, + tags: dict[str, str], + input: Input, # noqa: A002 — mirrors the spec naming + metadata: TaskMetadata, + run_attempt: int = 0, + lease_generation: int = 0, + cancel: asyncio.Event | None = None, + shutdown: asyncio.Event | None = None, + stream_handler: StreamHandler | None = None, + entry_mode: EntryMode = "fresh", + was_steered: bool = False, + previous_input: Input | None = None, + pending_inputs: Sequence[Any] | None = None, + generation: int = 0, + ) -> None: + self.task_id = task_id + self.title = title + self.description = description + self.session_id = session_id + self.agent_name = agent_name + self.tags = tags + self.input = input + self.metadata = metadata + self.run_attempt = run_attempt + self.lease_generation = lease_generation + self.cancel = cancel or asyncio.Event() + self.shutdown = shutdown or asyncio.Event() + self._suspend_callback: Any = None + self._stream_handler: StreamHandler | None = stream_handler + self.entry_mode: EntryMode = entry_mode + self.was_steered: bool = was_steered + self.previous_input: Input | None = previous_input + self.pending_inputs: Sequence[Any] = ( + pending_inputs if pending_inputs is not None else () + ) + self.generation: int = generation + + async def suspend( + self, + *, + 
reason: str | None = None, + output: Any | None = None, + ) -> Any: + """Suspend the task, releasing the lease and persisting state. + + Must be used as ``return await ctx.suspend(...)``. The framework + interprets the returned sentinel to transition the task to + ``suspended`` status. + + :keyword reason: Human-readable suspension reason. + :paramtype reason: str | None + :keyword output: Optional output snapshot for observers. + :paramtype output: Any | None + :return: A ``Suspended`` sentinel that the framework interprets. + :rtype: Suspended + """ + from ._run import Suspended # pylint: disable=import-outside-toplevel + + return Suspended(reason=reason, output=output) + + async def stream(self, item: Any) -> None: + """Emit a streaming item to observers iterating this task's output. + + When a :class:`~azure.ai.agentserver.core.durable.StreamHandler` + is configured, the item is routed through ``handler.put(item)``. + Otherwise the call is a no-op. + + :param item: The value to stream. + :type item: Any + """ + if self._stream_handler is not None: + await self._stream_handler.put(item) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py new file mode 100644 index 000000000000..f9ecafc4ef93 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_decorator.py @@ -0,0 +1,1035 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""``@durable_task`` decorator — turns an async function into a crash-resilient +unit of work with automatic task lifecycle management. + +Usage:: + + from azure.ai.agentserver.core.durable import durable_task, TaskContext + + @durable_task + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: + ... 
+ + result = await my_task.run(task_id="t1", input=MyInput(...)) +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import inspect +import logging as _logging +from collections.abc import Awaitable, Callable +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeVar, + get_args, + get_type_hints, + overload, +) + +import re + +from ._context import TaskContext +from ._result import TaskResult +from ._retry import RetryPolicy +from ._run import TaskRun +from ._stream import StreamHandler, StreamHandlerFactory + +if TYPE_CHECKING: + from ._models import TaskStatus + +Input = TypeVar("Input") +Output = TypeVar("Output") +F = TypeVar("F", bound=Callable[..., Any]) + +_VALID_TASK_ID_RE = re.compile(r"^[a-zA-Z0-9\-_.:]+$") +_MAX_TASK_ID_LENGTH = 256 + +#: Prefix for framework-reserved tags. Developer tags with this prefix are +#: silently stripped to prevent collisions with auto-stamped tags. +_RESERVED_TAG_PREFIX = "_durable_task_" + +_logger = _logging.getLogger("azure.ai.agentserver.durable") + +# Global registry of durable task descriptors for recovery purposes. +# Populated at import time when @durable_task decorates a function. +_REGISTERED_DESCRIPTORS: list[tuple[str, Callable[..., Any], "DurableTaskOptions"]] = [] + + +def _strip_reserved_tags(tags: dict[str, str]) -> dict[str, str]: + """Remove framework-reserved tags from developer-provided tags. + + Tags prefixed with ``_durable_task_`` are reserved for framework use. + If a developer provides them, they are silently dropped with a warning. + + :param tags: Developer-provided tags. + :type tags: dict[str, str] + :return: Tags with reserved keys removed. 
def _validate_task_id(task_id: str) -> None:
    """Reject task identifiers that are empty, too long, or malformed.

    :param task_id: Candidate task identifier.
    :type task_id: str
    :raises TypeError: If ``task_id`` is not a string.
    :raises ValueError: If ``task_id`` is empty, longer than
        ``_MAX_TASK_ID_LENGTH`` characters, or contains a character that
        ``_VALID_TASK_ID_RE`` does not allow.
    """
    if not isinstance(task_id, str):
        # Previously surfaced as an incidental TypeError from ``len()``
        # inside the error message; make the failure explicit.
        raise TypeError(f"task_id must be a str, got {type(task_id).__name__}")
    if not task_id or len(task_id) > _MAX_TASK_ID_LENGTH:
        raise ValueError(
            f"task_id must be 1-{_MAX_TASK_ID_LENGTH} characters, "
            f"got {len(task_id)}"
        )
    # fullmatch(), not match(): the pattern ends in ``$``, which also matches
    # just before a trailing newline, so match() would accept "abc\n".
    if not _VALID_TASK_ID_RE.fullmatch(task_id):
        raise ValueError(
            f"task_id contains invalid characters: {task_id!r}. "
            f"Allowed: [a-zA-Z0-9\\-_.:] "
        )
def _is_stale(task_updated_at: str, timeout: float) -> bool:
    """Check if an in_progress task is stale based on its updated_at timestamp.

    :param task_updated_at: ISO 8601 timestamp of the task's last update.
        A trailing ``Z`` (UTC designator) is accepted. A timestamp without
        any offset is interpreted as UTC. An empty value is never stale.
    :type task_updated_at: str
    :param timeout: Seconds after which the task is considered stale.
    :type timeout: float
    :returns: True if more than ``timeout`` seconds have elapsed since
        ``task_updated_at``.
    :rtype: bool
    :raises ValueError: If ``task_updated_at`` is not a parseable ISO 8601
        timestamp.
    """
    if not task_updated_at:
        return False
    from datetime import datetime, timezone  # pylint: disable=import-outside-toplevel

    stamp = task_updated_at
    if stamp.endswith("Z"):
        # datetime.fromisoformat() only understands the "Z" suffix from
        # Python 3.11 onwards; spell the offset out so older interpreters
        # parse it as well.
        stamp = stamp[:-1] + "+00:00"

    updated = datetime.fromisoformat(stamp)
    if updated.tzinfo is None:
        # Naive timestamps are treated as UTC so the subtraction below
        # never mixes naive and aware datetimes.
        updated = updated.replace(tzinfo=timezone.utc)
    elapsed = (datetime.now(timezone.utc) - updated).total_seconds()
    return elapsed > timeout
def __repr__(self) -> str:
    """Return a debug representation listing the scalar option values.

    Only ``name``, ``lease_duration_seconds``, ``store_input``,
    ``ephemeral``, ``retry``, ``timeout``, ``steerable`` and
    ``max_pending`` are rendered, in that order.
    """
    rendered = ", ".join(
        (
            f"name={self.name!r}",
            f"lease_duration_seconds={self.lease_duration_seconds}",
            f"store_input={self.store_input}",
            f"ephemeral={self.ephemeral}",
            f"retry={self.retry!r}",
            f"timeout={self.timeout!r}",
            f"steerable={self.steerable}",
            f"max_pending={self.max_pending}",
        )
    )
    return f"DurableTaskOptions({rendered})"
def _resolve_title(self, input_val: Input, task_id: str) -> str:
    """Resolve the decorator-level title (static string or callable factory).

    Resolution order:

    1. a callable ``title`` option is invoked with ``(input_val, task_id)``;
    2. a string ``title`` option is returned unchanged;
    3. otherwise the title defaults to ``"<name>:<first 8 chars of task_id>"``.

    :param input_val: The task input value.
    :type input_val: Input
    :param task_id: The task identifier.
    :type task_id: str
    :return: Resolved title.
    :rtype: str
    :raises TypeError: If a callable ``title`` returns something other
        than ``str``.
    """
    title = self._opts.title
    if callable(title):
        result = title(input_val, task_id)
        # Validate the factory result the same way the tags and description
        # factories are validated, so a bad return value fails here with a
        # clear message instead of propagating a non-string title.
        if not isinstance(result, str):
            raise TypeError(
                f"title callable must return str, "
                f"got {type(result).__name__}"
            )
        return result
    if isinstance(title, str):
        return title
    return f"{self.name}:{task_id[:8]}"
def _merge_tags(
    self, input_val: Input, task_id: str, call_tags: dict[str, str] | None
) -> dict[str, str]:
    """Combine decorator-level tags with per-call tag overrides.

    The decorator-level tags are resolved first; per-call tags (after
    reserved keys are stripped) are then layered on top, so a per-call
    value wins when both define the same key.

    :param input_val: The task input value.
    :type input_val: Input
    :param task_id: The task identifier.
    :type task_id: str
    :param call_tags: Per-call tag overrides, if any.
    :type call_tags: dict[str, str] | None
    :return: The merged tags dictionary.
    :rtype: dict[str, str]
    """
    combined = self._resolve_tags(input_val, task_id)
    if not call_tags:
        # Nothing supplied at the call site: decorator tags stand as-is.
        return combined
    combined.update(_strip_reserved_tags(call_tags))
    return combined
async def start(
    self,
    *,
    task_id: str,
    input: Input,  # noqa: A002
    session_id: str | None = None,
    title: str | None = None,
    tags: dict[str, str] | None = None,
    retry: RetryPolicy | None = None,
    stale_timeout: float = 300.0,
    stream_handler: StreamHandler | None = None,
) -> TaskRun[Output]:
    """Start a lifecycle-aware durable task and hand back its handle.

    The ``task_id`` is validated first; all arguments are then passed to
    the lifecycle resolver, and the :class:`TaskRun` it produces is
    returned without waiting for the task's result.

    :keyword task_id: Unique task identifier.
    :paramtype task_id: str
    :keyword input: Typed input value.
    :paramtype input: Input
    :keyword session_id: Session scope override.
    :paramtype session_id: str | None
    :keyword title: Title override.
    :paramtype title: str | None
    :keyword tags: Per-call tag overrides.
    :paramtype tags: dict[str, str] | None
    :keyword retry: Retry policy override. Overrides decorator-level retry.
    :paramtype retry: ~azure.ai.agentserver.core.durable.RetryPolicy | None
    :keyword stale_timeout: Seconds before an in-progress task is considered
        stale and eligible for recovery. Default 300 (5 minutes).
    :paramtype stale_timeout: float
    :keyword stream_handler: Custom stream handler for pluggable streaming.
    :paramtype stream_handler: ~azure.ai.agentserver.core.durable.StreamHandler | None
    :return: A handle to the running task.
    :rtype: TaskRun[Output]
    :raises ValueError: If ``task_id`` fails validation.
    :raises ~azure.ai.agentserver.core.durable.TaskConflictError: If the
        task is already in-progress or completed.
    """
    _validate_task_id(task_id)
    handle = await self._lifecycle_start(
        task_id=task_id,
        input=input,
        session_id=session_id,
        title=title,
        tags=tags,
        retry=retry,
        stale_timeout=stale_timeout,
        stream_handler=stream_handler,
    )
    return handle
+ :rtype: TaskRun[Output] | None + + Example:: + + # In another coroutine or request handler: + run = my_task.get_active_run("task-123") + if run is not None: + async for chunk in run: + print(chunk, end="") + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return manager.get_active_run(task_id) + + async def list( + self, + *, + session_id: str | None = None, + status: TaskStatus | None = None, + ) -> list[Any]: + """List tasks created by this durable task function. + + Automatically scoped to this function's ``name`` via the + ``_durable_task_name`` tag (server-side) and ``source.type`` + (client-side). Only returns tasks created by this framework. + + :keyword session_id: Session scope override. Defaults to the + manager's configured session ID. + :paramtype session_id: str | None + :keyword status: Filter by task status (e.g., ``"in_progress"``, + ``"suspended"``, ``"completed"``). + :paramtype status: ~azure.ai.agentserver.core.durable.TaskStatus | None + :return: Matching task records. 
async def _append_steering_input(  # pylint: disable=protected-access
    self,
    manager: Any,
    *,
    task_id: str,
    input_val: Any,
    existing: Any,
) -> None:
    """Append a steering input to the task's pending queue.

    The serialized input is added to ``payload["_steering"]["pending_inputs"]``,
    ``cancel_requested`` is set, and ``generation`` is initialised to ``0``
    when absent. The updated payload is written back guarded by the task's
    etag; on a conflict the task is re-fetched and the append is retried,
    up to five attempts in total. After a successful write the cancel
    event of the task's in-process context (if any) is set.

    :param manager: The task manager (provides ``provider`` and the
        in-process active-task registry).
    :type manager: Any
    :keyword task_id: The task identifier.
    :paramtype task_id: str
    :keyword input_val: The steering input to queue.
    :paramtype input_val: Any
    :keyword existing: The task record already fetched by the caller;
        used for the first attempt to avoid a redundant read.
    :paramtype existing: Any
    :raises SteeringQueueFull: If the queue already holds ``max_pending``
        inputs.
    :raises RuntimeError: If the task disappears between attempts, or if
        every attempt hit a write conflict (the last conflict is chained
        as ``__cause__``).
    """
    from ._exceptions import (  # pylint: disable=import-outside-toplevel
        SteeringQueueFull,
    )
    from ._models import (  # pylint: disable=import-outside-toplevel
        TaskPatchRequest,
    )

    max_retries = 5
    serialized = _serialize_input(input_val)
    last_conflict: ValueError | None = None

    for attempt in range(max_retries):
        # First attempt reuses the caller's snapshot; retries re-read so the
        # append is rebuilt on top of the latest payload and etag.
        task_info = existing if attempt == 0 else await manager.provider.get(task_id)
        if task_info is None:
            raise RuntimeError(
                f"Task {task_id!r} disappeared during steering append"
            )

        # Work on copies so the fetched record is never mutated in place.
        payload = dict(task_info.payload) if task_info.payload else {}
        steering = dict(payload.get("_steering", {}))
        pending: list[Any] = list(steering.get("pending_inputs", []))

        if len(pending) >= self._opts.max_pending:
            raise SteeringQueueFull(task_id, self._opts.max_pending)

        pending.append(serialized)
        steering["pending_inputs"] = pending
        steering["cancel_requested"] = True
        steering.setdefault("generation", 0)
        payload["_steering"] = steering

        etag = getattr(task_info, "etag", None) or None
        try:
            # Keep ONLY the conditional write inside the try: a ValueError
            # raised by anything after a successful write must not trigger a
            # retry, or the same input would be appended a second time.
            await manager.provider.update(
                task_id,
                TaskPatchRequest(payload=payload, if_match=etag),
            )
        except ValueError as conflict:
            # Local provider etag conflict — retry
            last_conflict = conflict
            continue

        # Signal the running task's cancel event so it can short-circuit
        active = manager._active_tasks.get(task_id)  # noqa: SLF001
        context = getattr(active, "context", None) if active else None
        if context is not None:
            context.cancel.set()
        return

    raise RuntimeError(
        f"Failed to append steering input after {max_retries} retries"
    ) from last_conflict
+ :rtype: TaskRun[Output] + """ + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskConflictError, + ) + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + existing = await manager.provider.get(task_id) + + resolved_retry = retry or self._opts.retry + + if existing is None or existing.status == "pending": + # Fresh start + if existing is not None and existing.status == "pending": + # Pending task exists — patch to in_progress and execute + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="fresh", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + stream_handler=stream_handler, + ) + # No task exists — create new + return await manager.create_and_start( + fn=self._fn, + fn_name=self.name, + task_id=task_id, + input_val=input, + input_type=self._input_type, + session_id=session_id, + title=title or self._resolve_title(input, task_id), + tags=self._merge_tags(input, task_id, tags), + description=self._resolve_description(input, task_id), + opts=self._opts, + retry=resolved_retry, + entry_mode="fresh", + stream_handler=stream_handler, + ) + + if existing.status == "suspended": + # Resume — patch input onto task, then start + serialized = _serialize_input(input) + from ._models import ( # pylint: disable=import-outside-toplevel + TaskPatchRequest, + ) + + await manager.provider.update( + task_id, + TaskPatchRequest(payload={"input": serialized}), + ) + # Re-fetch after input patch + updated_info = await manager.provider.get(task_id) + if updated_info is None: + raise RuntimeError(f"Task {task_id!r} disappeared after input patch") + return ( + await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=updated_info, + entry_mode="resumed", + input_val=input, + 
def options(
    self,
    *,
    title: str | Callable[[Any, str], str] | None = None,
    tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None,
    description: str | Callable[[Any, str], str | None] | None = None,
    timeout: timedelta | None = None,
    lease_duration_seconds: int | None = None,
    store_input: bool | None = None,
    ephemeral: bool | None = None,
    retry: RetryPolicy | None = None,
    steerable: bool | None = None,
    max_pending: int | None = None,
    stream_handler_factory: StreamHandlerFactory | None = None,
) -> DurableTask[Input, Output]:
    """Return a new DurableTask with merged options.

    The original is unchanged. Every keyword left as ``None`` inherits the
    current value, including ``stream_handler_factory``, so a factory set
    on the decorator is carried over to the derived task.

    :keyword timeout: Execution timeout override.
    :paramtype timeout: timedelta | None
    :keyword ephemeral: Whether to delete task on terminal exit.
    :paramtype ephemeral: bool | None
    :keyword tags: Tag overrides. A dict is merged over the existing dict
        (reserved keys stripped); a callable replaces an existing callable.
    :paramtype tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None
    :keyword store_input: Whether to persist input.
    :paramtype store_input: bool | None
    :keyword retry: Retry policy override.
    :paramtype retry: RetryPolicy | None
    :keyword title: Title override.
    :paramtype title: str | Callable[[Any, str], str] | None
    :keyword description: Description override.
    :paramtype description: str | Callable[[Any, str], str | None] | None
    :keyword lease_duration_seconds: Lease TTL override.
    :paramtype lease_duration_seconds: int | None
    :keyword steerable: Whether this task accepts steering inputs.
    :paramtype steerable: bool | None
    :keyword max_pending: Maximum queued steering inputs.
    :paramtype max_pending: int | None
    :keyword stream_handler_factory: Stream handler factory override.
    :paramtype stream_handler_factory: StreamHandlerFactory | None
    :return: A new DurableTask with overridden options.
    :rtype: DurableTask[Input, Output]
    :raises TypeError: If ``tags`` mixes a callable with an existing
        non-empty dict, or a dict with an existing callable.
    """
    base = self._opts

    # For tags: if both old and new are dicts, merge them.
    # Mixing callable and dict is not supported — use one or the other.
    resolved_tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None
    if tags is None:
        resolved_tags = base.tags
    else:
        if callable(tags) != callable(base.tags) and base.tags:
            raise TypeError(
                "Cannot mix callable and dict tags in options(). "
                "Pass a callable to replace a callable, or a dict to merge with a dict."
            )
        if callable(tags):
            resolved_tags = tags
        else:
            existing = base.tags if isinstance(base.tags, dict) else {}
            resolved_tags = _strip_reserved_tags({**existing, **(tags or {})})

    new_opts = DurableTaskOptions(
        name=base.name,
        title=title if title is not None else base.title,
        tags=resolved_tags,
        description=description if description is not None else base.description,
        timeout=timeout if timeout is not None else base.timeout,
        lease_duration_seconds=(
            lease_duration_seconds
            if lease_duration_seconds is not None
            else base.lease_duration_seconds
        ),
        store_input=store_input if store_input is not None else base.store_input,
        ephemeral=ephemeral if ephemeral is not None else base.ephemeral,
        retry=retry if retry is not None else base.retry,
        steerable=steerable if steerable is not None else base.steerable,
        max_pending=max_pending if max_pending is not None else base.max_pending,
        # Previously omitted, which reset the factory to None on every
        # options() call and discarded the decorator-level setting.
        stream_handler_factory=(
            stream_handler_factory
            if stream_handler_factory is not None
            else base.stream_handler_factory
        ),
    )
    return DurableTask(
        fn=self._fn,
        opts=new_opts,
        input_type=self._input_type,
        output_type=self._output_type,
    )
+ + +def durable_task( + fn: Callable[..., Any] | None = None, + *, + name: str | None = None, + title: str | Callable[[Any, str], str] | None = None, + tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None, + description: str | Callable[[Any, str], str | None] | None = None, + timeout: timedelta | None = None, + lease_duration_seconds: int = 60, + store_input: bool = True, + ephemeral: bool = True, + retry: RetryPolicy | None = None, + steerable: bool = False, + max_pending: int = 10, + stream_handler_factory: StreamHandlerFactory | None = None, +) -> Any: + """Turn an async function into a crash-resilient durable task. + + Can be used with or without arguments:: + + @durable_task + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ... + + @durable_task(name="custom-name", ephemeral=False) + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ... + + :param fn: The async function to decorate (when used without parens). + :type fn: Callable[..., Any] | None + :keyword name: **Stable identity anchor.** Used for recovery routing and + source stamping. Defaults to ``fn.__qualname__``. Always provide an + explicit name for production tasks — if you rename the function later, + existing in-flight tasks are still recovered correctly because the + framework matches on this name, not the Python function name. + :keyword title: Human-readable title (string or callable). + :keyword tags: Default tags (static dict or callable factory receiving + ``(input, task_id)``). Merged with per-call ``tags=`` overrides. + :keyword description: Task description (static string or callable factory + receiving ``(input, task_id)``). + :keyword timeout: Execution timeout. When elapsed, ``ctx.cancel`` is set + cooperatively. If the function does not exit, the lease eventually + expires and the task is recovered. + :keyword lease_duration_seconds: Lease TTL (default 60). + :keyword store_input: Whether to persist input on the task record. 
+ :keyword ephemeral: Delete task on terminal exit (default True). + :keyword retry: Default retry policy for this task. + :keyword steerable: Whether this task accepts steering inputs. When True, + calling ``start()`` on an ``in_progress`` task queues the input and + signals cancel instead of raising ``TaskConflictError``. Default False. + :keyword max_pending: Maximum number of queued steering inputs. Default 10. + :keyword stream_handler_factory: Optional factory callable that receives a + ``task_id`` and returns a :class:`StreamHandler`. When set, crash-recovery + and resume paths use this factory instead of defaulting to + :class:`QueueStreamHandler`. Call-site ``stream_handler=`` overrides the + factory for that specific call. + :return: A ``DurableTask[Input, Output]`` wrapper. + :rtype: Any + """ + + def _wrap( + func: Callable[..., Any], + ) -> DurableTask[Any, Any]: + if not asyncio.iscoroutinefunction(func): + raise TypeError( + f"@durable_task requires an async function, " + f"got {func.__qualname__!r}" + ) + + if lease_duration_seconds < 1: + raise ValueError( + f"lease_duration_seconds must be >= 1, got {lease_duration_seconds}" + ) + + if max_pending < 1: + raise ValueError(f"max_pending must be >= 1, got {max_pending}") + + input_type, output_type = _extract_generic_args(func) + + # Preserve callable tags as-is (stripped at resolve time); strip static dicts now + resolved_tags = ( + tags if callable(tags) else _strip_reserved_tags(dict(tags) if tags else {}) + ) + + opts = DurableTaskOptions( + name=name or func.__qualname__, + title=title, + tags=resolved_tags, + description=description, + timeout=timeout, + lease_duration_seconds=lease_duration_seconds, + store_input=store_input, + ephemeral=ephemeral, + retry=retry, + steerable=steerable, + max_pending=max_pending, + stream_handler_factory=stream_handler_factory, + ) + + task = DurableTask( + fn=func, + opts=opts, + input_type=input_type, + output_type=output_type, + ) + return task + + if fn is 
class TaskFailed(Exception):
    """Raised when a durable task function raises an unhandled exception.

    :param task_id: The identifier of the failed task.
    :type task_id: str
    :param error: Structured error details captured from the exception.
    :type error: dict[str, Any]
    """

    def __init__(self, task_id: str, error: dict[str, Any]) -> None:
        self.task_id = task_id
        self.error = error
        message = error.get("message", "Task failed")
        super().__init__(f"Task {task_id!r} failed: {message}")

    def __reduce__(self) -> tuple[Any, ...]:
        # The default exception reduce replays ``self.args`` (the formatted
        # message) into ``__init__``; rebuild from the real arguments so
        # pickle/copy round-trips work.
        return (type(self), (self.task_id, self.error), self.__dict__ or None)


class TaskSuspended(Exception):
    """Raised when awaiting the result of a suspended task.

    :param task_id: The identifier of the suspended task.
    :type task_id: str
    :param reason: Human-readable suspension reason, if provided.
    :type reason: str | None
    :param output: Optional output snapshot set at suspension time.
    :type output: Any | None
    """

    def __init__(
        self,
        task_id: str,
        reason: str | None = None,
        output: Any | None = None,
    ) -> None:
        self.task_id = task_id
        self.reason = reason
        self.output = output
        suffix = f": {reason}" if reason else ""
        super().__init__(f"Task {task_id!r} is suspended{suffix}")

    def __reduce__(self) -> tuple[Any, ...]:
        # Rebuild from constructor arguments, not from the formatted message.
        return (
            type(self),
            (self.task_id, self.reason, self.output),
            self.__dict__ or None,
        )


class TaskCancelled(Exception):
    """Raised when a durable task is cancelled.

    Inherits from :class:`Exception` rather than :class:`asyncio.CancelledError`
    to prevent unintentional suppression by generic ``CancelledError`` handlers
    in the asyncio event loop.

    :param task_id: The identifier of the cancelled task.
    :type task_id: str
    """

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id
        super().__init__(f"Task {task_id!r} was cancelled")

    def __reduce__(self) -> tuple[Any, ...]:
        # Without this the message would be re-wrapped on every copy.
        return (type(self), (self.task_id,), self.__dict__ or None)


class TaskNotFound(Exception):
    """Raised when a task ID is not found in the store.

    :param task_id: The identifier that was not found.
    :type task_id: str
    """

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id
        super().__init__(f"Task {task_id!r} not found")

    def __reduce__(self) -> tuple[Any, ...]:
        # Without this the message would be re-wrapped on every copy.
        return (type(self), (self.task_id,), self.__dict__ or None)


class TaskTerminated(Exception):
    """Raised when a task is forcefully terminated via ``handle.terminate()``.

    Unlike :class:`TaskCancelled`, terminated tasks go through the failure
    path and do NOT stay ``in_progress`` for recovery.

    :param task_id: The identifier of the terminated task.
    :type task_id: str
    :param reason: Optional human-readable termination reason.
    :type reason: str | None
    """

    __slots__ = ("task_id", "reason")

    def __init__(self, task_id: str, reason: str | None = None) -> None:
        self.task_id = task_id
        self.reason = reason
        suffix = f": {reason}" if reason else ""
        super().__init__(f"Task {task_id!r} was terminated{suffix}")

    def __reduce__(self) -> tuple[Any, ...]:
        # Slot values are not part of the default exception state, so they
        # must travel as constructor arguments.
        return (type(self), (self.task_id, self.reason), self.__dict__ or None)


class TaskConflictError(RuntimeError):
    """Raised when a task lifecycle conflict cannot be resolved.

    Raised by ``.run()`` or ``.start()`` when the task is already
    ``in_progress`` (non-stale) or ``completed``. The lifecycle is
    deterministic: create if none, start if pending, resume if suspended,
    throw if in-progress or completed.

    :param task_id: The conflicting task's ID.
    :type task_id: str
    :param current_status: The task's current status.
    :type current_status: str
    """

    __slots__ = ("task_id", "current_status")

    def __init__(
        self,
        task_id: str,
        current_status: str,
    ) -> None:
        self.task_id = task_id
        self.current_status = current_status
        super().__init__(f"Task '{task_id}' is already {current_status}")

    def __reduce__(self) -> tuple[Any, ...]:
        # Slot values are not part of the default exception state, so they
        # must travel as constructor arguments.
        return (
            type(self),
            (self.task_id, self.current_status),
            self.__dict__ or None,
        )


class EtagConflict(RuntimeError):
    """Raised when an optimistic concurrency (etag) check fails.

    The task record was modified between read and write. Callers should
    retry the operation with the updated etag.

    :param task_id: The task ID where the conflict occurred.
    :type task_id: str
    :param message: Optional detail message.
    :type message: str | None
    """

    __slots__ = ("task_id",)

    def __init__(self, task_id: str, message: str | None = None) -> None:
        self.task_id = task_id
        msg = message or f"Etag conflict on task '{task_id}'"
        super().__init__(msg)

    def __reduce__(self) -> tuple[Any, ...]:
        # ``str(self)`` is the effective message (custom or default); passing
        # it back as ``message`` reproduces the same text on reconstruction.
        return (type(self), (self.task_id, str(self) or None), self.__dict__ or None)


class SteeringQueueFull(RuntimeError):
    """Raised when the steering pending-input queue is at capacity.

    The caller should retry later or increase ``max_pending``.

    :param task_id: The task whose queue is full.
    :type task_id: str
    :param max_pending: The configured queue capacity.
    :type max_pending: int
    """

    __slots__ = ("task_id", "max_pending")

    def __init__(self, task_id: str, max_pending: int) -> None:
        self.task_id = task_id
        self.max_pending = max_pending
        super().__init__(
            f"Steering queue full for task '{task_id}' " f"(max_pending={max_pending})"
        )

    def __reduce__(self) -> tuple[Any, ...]:
        # Slot values are not part of the default exception state, so they
        # must travel as constructor arguments.
        return (
            type(self),
            (self.task_id, self.max_pending),
            self.__dict__ or None,
        )
logger = logging.getLogger("azure.ai.agentserver.durable")


def derive_lease_owner(session_id: str) -> str:
    """Derive a stable lease owner string from the session ID.

    The result depends only on ``session_id``, so the same session always
    yields the same owner string.

    :param session_id: The agent session identifier.
    :type session_id: str
    :return: A lease owner string in the format ``"session:{session_id}"``.
    :rtype: str
    """
    return f"session:{session_id}"


def generate_instance_id() -> str:
    """Generate an ephemeral lease instance ID unique to this process.

    Combines the PID, the first eight hex digits of a random UUID4, and the
    current Unix time in whole seconds, so two calls in the same process and
    the same second still differ.

    :return: An identifier of the form ``worker-{pid}-{hex8}-{epoch}``.
    :rtype: str
    """
    return f"worker-{os.getpid()}-{uuid.uuid4().hex[:8]}-{int(time.time())}"


def _renewal_interval(lease_duration_seconds: int) -> float:
    """Return the number of seconds to wait between two lease renewals.

    The interval is half the lease duration, floored to whole seconds with a
    minimum of one second. When that floor would make the interval as long
    as the lease itself (a 1-second lease), the exact half is used instead
    so the renewal still lands before the lease expires.

    :param lease_duration_seconds: The lease TTL in seconds.
    :type lease_duration_seconds: int
    :return: The renewal interval in seconds.
    :rtype: float
    """
    interval: float = max(1, lease_duration_seconds // 2)
    if 0 < lease_duration_seconds <= interval:
        interval = lease_duration_seconds / 2
    return interval


async def lease_renewal_loop(
    provider: DurableTaskProvider,
    task_id: str,
    *,
    lease_owner: str,
    lease_instance_id: str,
    lease_duration_seconds: int,
    cancel_event: asyncio.Event,
    on_failure_count: int = 3,
    on_cancel_callback: asyncio.Event | None = None,
    steering_poll_callback: Callable[[], Awaitable[None]] | None = None,
) -> None:
    """Run a background lease renewal loop at half the lease duration.

    Renews the lease by PATCHing the task with the same owner/instance.
    On ``on_failure_count`` consecutive failures, signals the optional
    ``on_cancel_callback`` event to give the task function a chance to
    checkpoint, then stops. Without an ``on_cancel_callback`` the loop keeps
    retrying on every interval.

    The loop exits when ``cancel_event`` is set or the task is cancelled.

    :param provider: The storage provider.
    :type provider: DurableTaskProvider
    :param task_id: The task to renew.
    :type task_id: str
    :keyword lease_owner: The stable lease owner.
    :paramtype lease_owner: str
    :keyword lease_instance_id: The ephemeral instance ID.
    :paramtype lease_instance_id: str
    :keyword lease_duration_seconds: The lease TTL in seconds.
    :paramtype lease_duration_seconds: int
    :keyword cancel_event: Event that stops the loop when set.
    :paramtype cancel_event: asyncio.Event
    :keyword on_failure_count: Consecutive failures before signalling cancel.
    :paramtype on_failure_count: int
    :keyword on_cancel_callback: Event to signal on repeated renewal failure.
    :paramtype on_cancel_callback: asyncio.Event | None
    :keyword steering_poll_callback: Async callback invoked each renewal to poll
        for steering inputs. Called after successful lease renewal.
    :paramtype steering_poll_callback: Callable[[], Awaitable[None]] | None
    """
    interval = _renewal_interval(lease_duration_seconds)
    consecutive_failures = 0

    while not cancel_event.is_set():
        # Interruptible sleep: returns early as soon as cancel_event is set.
        try:
            await asyncio.wait_for(
                _wait_for_event(cancel_event),
                timeout=interval,
            )
            # cancel_event was set — exit the loop
            break
        except asyncio.TimeoutError:
            pass

        try:
            await provider.update(
                task_id,
                TaskPatchRequest(
                    lease_owner=lease_owner,
                    lease_instance_id=lease_instance_id,
                    lease_duration_seconds=lease_duration_seconds,
                ),
            )
            consecutive_failures = 0
            logger.debug("Lease renewed for task %s", task_id)

            # Poll for steering inputs after successful renewal. A failing
            # poll is logged only; it must not count as a renewal failure.
            if steering_poll_callback is not None:
                try:
                    await steering_poll_callback()
                except Exception:  # pylint: disable=broad-exception-caught
                    logger.debug(
                        "Steering poll failed for task %s", task_id, exc_info=True
                    )
        except Exception:  # pylint: disable=broad-exception-caught
            consecutive_failures += 1
            logger.warning(
                "Lease renewal failed for task %s (attempt %d/%d)",
                task_id,
                consecutive_failures,
                on_failure_count,
                exc_info=True,
            )
            if (
                consecutive_failures >= on_failure_count
                and on_cancel_callback is not None
            ):
                logger.error(
                    "Lease renewal failed %d times for task %s — signalling cancellation",
                    on_failure_count,
                    task_id,
                )
                on_cancel_callback.set()
                break


async def _wait_for_event(event: asyncio.Event) -> None:
    """Await an asyncio event. Used with ``wait_for`` for interruptible sleep.

    :param event: The asyncio event to wait for.
    :type event: asyncio.Event
    """
    await event.wait()
logger = logging.getLogger("azure.ai.agentserver.durable")


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.datetime.now(datetime.timezone.utc).isoformat()


def _generate_etag(data: dict[str, Any]) -> str:
    """Return a ``local-`` prefixed etag derived from the record's content.

    The etag is the first 16 hex digits of the SHA-256 of the key-sorted
    JSON serialization of ``data``.
    """
    raw = json.dumps(data, sort_keys=True)
    return f"local-{hashlib.sha256(raw.encode()).hexdigest()[:16]}"


def _is_lease_expired(lease: LeaseInfo | None) -> bool:
    """Return True when ``lease`` is missing, unparsable, or past its expiry."""
    if lease is None:
        return True
    try:
        expires = datetime.datetime.fromisoformat(lease.expires_at)
        now = datetime.datetime.now(datetime.timezone.utc)
        return now >= expires
    except (ValueError, TypeError):
        # Malformed timestamps (or naive/aware mismatches) count as expired.
        return True


class LocalFileDurableTaskProvider:
    """Filesystem-backed provider for local development.

    Tasks are stored as individual JSON files under
    ``{base_dir}/{agent_name}/{session_id}/{task_id}.json``.

    :param base_dir: Root directory for task storage.
        Defaults to ``$HOME/.durable-tasks``.
    :type base_dir: Path | None
    """

    def __init__(self, base_dir: Path | None = None) -> None:
        self._base_dir = base_dir or Path.home() / ".durable-tasks"

    def _task_dir(self, agent_name: str, session_id: str) -> Path:
        """Return the directory holding one agent/session's task files."""
        return self._base_dir / agent_name / session_id

    def _task_path(self, agent_name: str, session_id: str, task_id: str) -> Path:
        """Return the JSON file path for a task."""
        return self._task_dir(agent_name, session_id) / f"{task_id}.json"

    def _find_task_path(self, task_id: str) -> Path | None:
        """Search all agent/session dirs for a task file.

        :param task_id: The task identifier.
        :type task_id: str
        :return: The path to the task file, or None.
        :rtype: ~pathlib.Path | None
        """
        if not self._base_dir.exists():
            return None
        for agent_dir in self._base_dir.iterdir():
            if not agent_dir.is_dir():
                continue
            for session_dir in agent_dir.iterdir():
                if not session_dir.is_dir():
                    continue
                path = session_dir / f"{task_id}.json"
                if path.exists():
                    return path
        return None

    def _read_task(self, path: Path) -> TaskInfo | None:
        """Load a task record from ``path``.

        :param path: The task file to read.
        :type path: ~pathlib.Path
        :return: The parsed record, or ``None`` when the file is missing or
            cannot be parsed.
        :rtype: TaskInfo | None
        """
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return TaskInfo.from_dict(data)
        except FileNotFoundError:
            # Read directly instead of checking exists() first: the file may
            # be deleted between a check and the read.
            return None
        except (json.JSONDecodeError, UnicodeDecodeError, KeyError):
            logger.warning("Corrupt task file: %s", path)
            return None

    def _write_task(self, task: TaskInfo) -> None:
        """Persist ``task`` and refresh its etag.

        The record is written to a sibling temporary file and then moved
        over the target with ``os.replace``, so an interrupted write never
        leaves a truncated task file behind.

        :param task: The record to store; ``task.etag`` is updated in place.
        :type task: TaskInfo
        """
        path = self._task_path(task.agent_name, task.session_id, task.id)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = task.to_dict()
        data["etag"] = _generate_etag(data)
        task.etag = data["etag"]
        # The ".tmp" suffix keeps the scratch file out of list()'s "*.json" glob.
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        try:
            tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
            os.replace(tmp_path, path)
        finally:
            # No-op after a successful replace; removes the scratch file on failure.
            tmp_path.unlink(missing_ok=True)

    async def create(self, request: TaskCreateRequest) -> TaskInfo:
        """Create a new task as a JSON file.

        A lease is attached only when owner, instance ID, and a non-zero
        duration are all supplied. ``started_at`` is stamped when the task is
        created directly in ``in_progress``, mirroring :meth:`update`.

        :param request: Task creation parameters.
        :type request: TaskCreateRequest
        :return: The created task record.
        :rtype: TaskInfo
        """
        now = _now_iso()
        task_id = request.id or f"task-{os.urandom(8).hex()}"

        lease: LeaseInfo | None = None
        started_at: str | None = None
        status: TaskStatus = request.status

        if (
            request.lease_owner
            and request.lease_instance_id
            and request.lease_duration_seconds
        ):
            expires_at = (
                datetime.datetime.now(datetime.timezone.utc)
                + datetime.timedelta(seconds=request.lease_duration_seconds)
            ).isoformat()
            lease = LeaseInfo(
                owner=request.lease_owner,
                instance_id=request.lease_instance_id,
                generation=0,
                expires_at=expires_at,
                expiry_count=0,
            )
        if status == "in_progress":
            started_at = now

        task = TaskInfo(
            id=task_id,
            agent_name=request.agent_name,
            session_id=request.session_id,
            status=status,
            title=request.title,
            description=request.description,
            lease=lease,
            payload=request.payload,
            tags=request.tags,
            source=request.source,
            created_at=now,
            updated_at=now,
            started_at=started_at,
        )
        self._write_task(task)
        logger.debug("Created local task %s", task_id)
        return task

    async def get(self, task_id: str) -> TaskInfo | None:
        """Get a task by ID from the filesystem.

        :param task_id: The task identifier.
        :type task_id: str
        :return: The task record, or ``None`` if not found.
        :rtype: TaskInfo | None
        """
        path = self._find_task_path(task_id)
        if path is None:
            return None
        return self._read_task(path)

    async def update(
        self, task_id: str, patch: TaskPatchRequest
    ) -> TaskInfo:  # pylint: disable=too-many-branches,too-many-statements
        """Update a task via PATCH semantics.

        :param task_id: The task identifier.
        :type task_id: str
        :param patch: Fields to update.
        :type patch: TaskPatchRequest
        :return: The updated task record.
        :rtype: TaskInfo
        :raises TaskNotFound: If the task does not exist or its file is unreadable.
        :raises ValueError: If ``patch.if_match`` is set and differs from the
            stored etag.
        """
        path = self._find_task_path(task_id)
        if path is None:
            raise TaskNotFound(task_id)

        task = self._read_task(path)
        if task is None:
            raise TaskNotFound(task_id)

        # ETag check
        if patch.if_match is not None and patch.if_match != task.etag:
            raise ValueError(
                f"ETag mismatch: expected {patch.if_match!r}, " f"got {task.etag!r}"
            )

        now = _now_iso()

        if patch.status is not None:
            task.status = patch.status

            if patch.status == "in_progress" and task.started_at is None:
                task.started_at = now
            if patch.status == "completed":
                task.completed_at = now
            if patch.status == "suspended":
                task.suspension_reason = patch.suspension_reason

            # Lease handling on status transitions
            if patch.status in ("completed", "suspended"):
                task.lease = None
            elif (
                patch.status == "in_progress"
                and patch.lease_owner
                and patch.lease_instance_id
            ):
                duration = patch.lease_duration_seconds or 60
                expires_at = (
                    datetime.datetime.now(datetime.timezone.utc)
                    + datetime.timedelta(seconds=duration)
                ).isoformat()
                # Generation advances only when a different instance takes over.
                old_gen = task.lease.generation if task.lease else -1
                new_gen = (
                    old_gen + 1
                    if patch.lease_instance_id
                    != (task.lease.instance_id if task.lease else "")
                    else max(old_gen, 0)
                )
                task.lease = LeaseInfo(
                    owner=patch.lease_owner,
                    instance_id=patch.lease_instance_id,
                    generation=new_gen,
                    expires_at=expires_at,
                    expiry_count=task.lease.expiry_count if task.lease else 0,
                )

        # Lease renewal (no status change)
        if patch.status is None and patch.lease_owner and patch.lease_instance_id:
            duration = patch.lease_duration_seconds or 60
            expires_at = (
                datetime.datetime.now(datetime.timezone.utc)
                + datetime.timedelta(seconds=duration)
            ).isoformat()
            if task.lease and patch.lease_instance_id != task.lease.instance_id:
                # Reclaim with new instance
                task.lease = LeaseInfo(
                    owner=patch.lease_owner,
                    instance_id=patch.lease_instance_id,
                    generation=task.lease.generation + 1,
                    expires_at=expires_at,
                    expiry_count=task.lease.expiry_count,
                )
            elif task.lease:
                # Simple renewal
                task.lease = LeaseInfo(
                    owner=task.lease.owner,
                    instance_id=task.lease.instance_id,
                    generation=task.lease.generation,
                    expires_at=expires_at,
                    expiry_count=task.lease.expiry_count,
                )
            else:
                task.lease = LeaseInfo(
                    owner=patch.lease_owner,
                    instance_id=patch.lease_instance_id,
                    generation=0,
                    expires_at=expires_at,
                )

        # Force-expire: lease_duration_seconds=0
        if patch.lease_duration_seconds == 0 and task.lease:
            task.lease = LeaseInfo(
                owner=task.lease.owner,
                instance_id=task.lease.instance_id,
                generation=task.lease.generation,
                expires_at=_now_iso(),
                expiry_count=task.lease.expiry_count,
            )

        # Payload shallow-merge (spec §11: root-level additive, values replaced)
        if patch.payload is not None:
            if task.payload is None:
                task.payload = {}
            for key, value in patch.payload.items():
                task.payload[key] = value

        # Tags null-as-delete merge
        if patch.tags is not None:
            if task.tags is None:
                task.tags = {}
            for key, value in patch.tags.items():
                if value is None:
                    task.tags.pop(key, None)
                else:
                    task.tags[key] = value

        if patch.error is not None:
            task.error = patch.error

        task.updated_at = now
        self._write_task(task)
        return task

    async def delete(
        self,
        task_id: str,
        *,
        force: bool = False,  # pylint: disable=unused-argument
        cascade: bool = False,  # pylint: disable=unused-argument
    ) -> None:
        """Delete a task JSON file.

        :param task_id: The task identifier.
        :type task_id: str
        :keyword force: Accepted for interface parity; not used by this provider.
        :paramtype force: bool
        :keyword cascade: Delete dependent tasks (no-op for local).
        :paramtype cascade: bool
        :raises TaskNotFound: If no file exists for ``task_id``.
        """
        path = self._find_task_path(task_id)
        if path is None:
            raise TaskNotFound(task_id)
        path.unlink(missing_ok=True)
        logger.debug("Deleted local task %s", task_id)

    async def list(
        self,
        *,
        agent_name: str,
        session_id: str,
        status: TaskStatus | None = None,
        lease_owner: str | None = None,
        tag: dict[str, str] | None = None,
    ) -> list[TaskInfo]:
        """List tasks from the filesystem.

        :keyword agent_name: Filter by agent name.
        :paramtype agent_name: str
        :keyword session_id: Filter by session ID.
        :paramtype session_id: str
        :keyword status: Filter by task status.
        :paramtype status: TaskStatus | None
        :keyword lease_owner: Filter by lease owner.
        :paramtype lease_owner: str | None
        :keyword tag: Filter by tags (AND semantics — all must match).
        :paramtype tag: dict[str, str] | None
        :return: Matching task records.
        :rtype: list[TaskInfo]
        """
        task_dir = self._task_dir(agent_name, session_id)
        if not task_dir.exists():
            return []

        results: list[TaskInfo] = []
        for path in task_dir.glob("*.json"):
            task = self._read_task(path)
            if task is None:
                continue
            if status is not None and task.status != status:
                continue
            if lease_owner is not None:
                if task.lease is None or task.lease.owner != lease_owner:
                    continue
            if tag is not None:
                task_tags = task.tags or {}
                if not all(task_tags.get(k) == v for k, v in tag.items()):
                    continue
            results.append(task)
        return results
# Module-level manager singleton
_manager: DurableTaskManager | None = None


def get_task_manager() -> DurableTaskManager:
    """Return the currently registered DurableTaskManager.

    :raises RuntimeError: If no manager has been registered.
    :return: The active manager.
    :rtype: DurableTaskManager
    """
    manager = _manager
    if manager is not None:
        return manager
    raise RuntimeError(
        "DurableTaskManager not initialized. "
        "Ensure durable tasks are enabled on the AgentServerHost."
    )


def set_task_manager(manager: DurableTaskManager | None) -> None:
    """Register (or clear) the module-level DurableTaskManager.

    :param manager: The manager to register, or ``None`` to clear it.
    :type manager: DurableTaskManager | None
    """
    global _manager  # pylint: disable=global-statement
    _manager = manager


class _ActiveTask:  # pylint: disable=too-many-instance-attributes
    """Bookkeeping record for one task that is currently running in memory."""

    __slots__ = (
        "task_id",
        "fn_name",
        "context",
        "execution_task",
        "renewal_task",
        "renewal_cancel",
        "result_future",
        "terminate_event",
        "fn",
        "input_type",
        "opts",
        "retry",
    )

    def __init__(
        self,
        task_id: str,
        fn_name: str,
        context: TaskContext[Any],
        execution_task: asyncio.Task[Any],
        renewal_task: asyncio.Task[None] | None,
        renewal_cancel: asyncio.Event,
        result_future: asyncio.Future[Any],
        terminate_event: asyncio.Event | None = None,
        fn: Callable[..., Awaitable[Any]] | None = None,
        input_type: type[Any] | None = None,
        opts: DurableTaskOptions | None = None,
        retry: RetryPolicy | None = None,
    ) -> None:
        # A fresh event is created when the caller does not supply one.
        if not terminate_event:
            terminate_event = asyncio.Event()

        # Identity
        self.task_id = task_id
        self.fn_name = fn_name

        # Execution state
        self.context = context
        self.execution_task = execution_task
        self.result_future = result_future
        self.terminate_event = terminate_event

        # Lease renewal
        self.renewal_task = renewal_task
        self.renewal_cancel = renewal_cancel

        # Function and configuration it was started with
        self.fn = fn
        self.input_type = input_type
        self.opts = opts
        self.retry = retry
+ + Manages provider selection, task creation, lease management, + execution dispatch, crash recovery, and graceful shutdown. + + :param config: Resolved agent configuration. + :type config: AgentConfig + :param provider: Optional explicit provider (for testing). + :type provider: DurableTaskProvider | None + :param shutdown_event: Shared shutdown event from the host. + :type shutdown_event: asyncio.Event | None + :param shutdown_grace_seconds: Seconds to wait for tasks to checkpoint + before force-expiring leases during shutdown. Defaults to 25.0. + :type shutdown_grace_seconds: float + """ + + def __init__( + self, + config: AgentConfig, + *, + provider: DurableTaskProvider | None = None, + shutdown_event: asyncio.Event | None = None, + shutdown_grace_seconds: float = 25.0, + ) -> None: + self._config = config + self._provider = provider or self._create_provider(config) + self._active_tasks: dict[str, _ActiveTask] = {} + self._resume_callbacks: dict[str, Callable[..., Any]] = {} + self._resume_opts: dict[str, DurableTaskOptions] = {} + self._lease_owner = derive_lease_owner(config.session_id or "local") + self._instance_id = generate_instance_id() + self._shutdown_event = shutdown_event or asyncio.Event() + self._shutdown_grace_seconds = shutdown_grace_seconds + self._active_generation_future: dict[str, asyncio.Future[Any]] = {} + self._pending_steering_futures: dict[str, list[asyncio.Future[Any]]] = {} + + @staticmethod + def _build_source(fn_name: str) -> dict[str, str]: + """Build the framework-owned source stamp for a task. + + The ``fn_name`` is the developer-provided ``name`` from the decorator + (or ``fn.__qualname__`` when omitted). It serves as the **stable + identity anchor** — recovery routing matches ``source.name`` against + registered callbacks to dispatch recovered tasks back to the correct + function. + + :param fn_name: The task name (from ``@durable_task(name=...)``). + :type fn_name: str + :return: Source metadata dict. 
+ :rtype: dict[str, str] + """ + return { + "type": _SOURCE_TYPE, + "name": fn_name, + "server_version": _SOURCE_SERVER_VERSION, + } + + @staticmethod + def _create_provider(config: AgentConfig) -> DurableTaskProvider: + """Auto-select provider based on hosting environment. + + The Task Storage API is not yet generally available. To avoid + failures in hosted environments, the local file-based provider + is used by default even when ``FOUNDRY_HOSTING_ENVIRONMENT`` is + set. Set the ``FOUNDRY_TASK_API_ENABLED=1`` environment variable + to opt in to the HTTP-backed provider for testing once the APIs + are lit up. + + :param config: The agent configuration. + :type config: AgentConfig + :return: The storage provider instance. + :rtype: DurableTaskProvider + """ + import os # pylint: disable=import-outside-toplevel + + task_api_enabled = os.environ.get("FOUNDRY_TASK_API_ENABLED", "").strip() + + if config.is_hosted and task_api_enabled in ("1", "true", "yes"): + from ._client import ( # pylint: disable=import-outside-toplevel + HostedDurableTaskProvider, + ) + + try: + from azure.identity.aio import ( # type: ignore[import-untyped] + DefaultAzureCredential, + ) + except ImportError as exc: + raise ImportError( + "azure-identity is required for hosted mode. " + "Install with: pip install azure-ai-agentserver-core[hosted]" + ) from exc + + logger.info( + "Task Storage API enabled via FOUNDRY_TASK_API_ENABLED; " # pylint: disable=implicit-str-concat + "using HostedDurableTaskProvider" + ) + return HostedDurableTaskProvider( + project_endpoint=config.project_endpoint, + credential=DefaultAzureCredential(), + ) + + if config.is_hosted and not task_api_enabled: + logger.info( + "Hosted environment detected but Task Storage API not yet enabled. " + "Using local file provider. Set FOUNDRY_TASK_API_ENABLED=1 to use " + "the HTTP-backed provider when the APIs are available." 
+ ) + + from ._local_provider import ( # pylint: disable=import-outside-toplevel + LocalFileDurableTaskProvider, + ) + + return LocalFileDurableTaskProvider(base_dir=Path.home() / ".durable-tasks") + + @property + def provider(self) -> DurableTaskProvider: + """The storage provider. + + :return: The active provider. + :rtype: DurableTaskProvider + """ + return self._provider + + def register_resume_callback( + self, + fn_name: str, + fn: Callable[..., Any], + opts: DurableTaskOptions | None = None, + ) -> None: + """Register a function as a resume callback. + + :param fn_name: The durable task function name. + :type fn_name: str + :param fn: The async function to call on resume. + :type fn: Callable[..., Any] + :param opts: The task options (for stream_handler_factory etc.). + :type opts: DurableTaskOptions | None + """ + self._resume_callbacks[fn_name] = fn + if opts is not None: + self._resume_opts[fn_name] = opts + + async def list_tasks( + self, + *, + fn_name: str, + session_id: str | None = None, + status: TaskStatus | None = None, + ) -> list[TaskInfo]: + """List tasks scoped to a specific durable task function. + + Uses server-side filtering (``agent_name``, ``session_id``, + ``_durable_task_name`` tag, ``status``) and client-side filtering + (``source.type``) to return only tasks created by this framework + for the given function. + + :keyword fn_name: The task function name (stable identity anchor). + :paramtype fn_name: str + :keyword session_id: Session scope override. Defaults to config. + :paramtype session_id: str | None + :keyword status: Filter by task status. + :paramtype status: ~azure.ai.agentserver.core.durable.TaskStatus | None + :return: Matching task records. 
async def shutdown(self) -> None:
    """Signal shutdown on all active tasks and force-expire leases.

    Called by ``AgentServerHost`` during lifespan shutdown.

    Sequence:

    1. Set the manager-wide shutdown event and every active context's
       ``shutdown`` event so task functions can checkpoint.
    2. Wait for the active executions to finish, for at most
       ``shutdown_grace_seconds``. The wait returns as soon as every
       execution has finished instead of always sleeping the full grace
       period.
    3. Stop lease renewal for the tasks that are still registered, so a
       renewal cannot re-extend a lease that is about to be expired.
    4. Force-expire the remaining leases (``lease_duration_seconds=0``).
       Failures are logged and do not abort the shutdown.
    5. Cancel whatever is still running, clear the registry and
       unregister the manager via ``set_task_manager(None)``.
    """
    logger.info("DurableTaskManager shutting down")
    self._shutdown_event.set()

    # Snapshot the registry: executions remove themselves from
    # ``_active_tasks`` when they finish, which may happen while we await.
    signalled = list(self._active_tasks.values())
    for active in signalled:
        active.context.shutdown.set()

    # Give tasks a bounded chance to checkpoint. ``asyncio.wait`` does not
    # cancel anything on timeout; it only stops waiting.
    if signalled:
        await asyncio.wait(
            [active.execution_task for active in signalled],
            timeout=self._shutdown_grace_seconds,
        )

    remaining = list(self._active_tasks.values())

    # Stop renewal BEFORE expiring: the expiry loop below awaits, and a
    # renewal loop that is still running could otherwise extend a lease
    # right after it was set to zero.
    for active in remaining:
        active.renewal_cancel.set()
        if active.renewal_task and not active.renewal_task.done():
            active.renewal_task.cancel()

    # Force-expire all leases still held by this instance
    for active in remaining:
        try:
            await self._provider.update(
                active.task_id,
                TaskPatchRequest(
                    lease_owner=self._lease_owner,
                    lease_instance_id=self._instance_id,
                    lease_duration_seconds=0,
                ),
            )
        except Exception:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Failed to force-expire lease for task %s",
                active.task_id,
                exc_info=True,
            )

    # Cancel all renewal and execution tasks. Re-read the registry so tasks
    # registered while the loop above was awaiting are cancelled as well.
    for active in list(self._active_tasks.values()):
        active.renewal_cancel.set()
        if active.renewal_task and not active.renewal_task.done():
            active.renewal_task.cancel()
        if not active.execution_task.done():
            active.execution_task.cancel()

    self._active_tasks.clear()
    set_task_manager(None)
async def create_and_start(  # pylint: disable=too-many-locals
    self,
    *,
    fn: Callable[..., Awaitable[Any]],
    fn_name: str,
    task_id: str,
    input_val: Any,
    input_type: type[Any],
    session_id: str | None,
    title: str,
    tags: dict[str, str],
    description: str | None = None,
    opts: DurableTaskOptions,
    retry: RetryPolicy | None = None,
    entry_mode: EntryMode = "fresh",
    stream_handler: StreamHandler | None = None,
) -> TaskRun[Any]:
    """Create a task, start the function, and return a handle.

    The task record is created in ``in_progress`` state with a lease held
    by this manager instance. Source provenance is built by
    ``self._build_source(fn_name)`` and the reserved task-name tag is
    stamped onto a *copy* of ``tags``; the caller's dict is never mutated.
    Lease renewal and the task function are then started as background
    asyncio tasks and the run is registered in ``_active_tasks``.

    :keyword fn: The async task function.
    :paramtype fn: Callable[..., Awaitable[Any]]
    :keyword fn_name: Function name. Used for the reserved task-name tag,
        for the resume-callback registry and for logging.
    :paramtype fn_name: str
    :keyword task_id: The task identifier.
    :paramtype task_id: str
    :keyword input_val: The task input value.
    :paramtype input_val: Any
    :keyword input_type: Input type, recorded on the active-task entry so
        later steering inputs can be deserialized.
    :paramtype input_type: type[Any]
    :keyword session_id: Session scope identifier. Falls back to the
        configured session id, then to ``"local"``.
    :paramtype session_id: str | None
    :keyword title: Human-readable task title.
    :paramtype title: str
    :keyword tags: Merged decorator + call-site tags.
    :paramtype tags: dict[str, str]
    :keyword description: Optional task description.
    :paramtype description: str | None
    :keyword opts: Task options.
    :paramtype opts: DurableTaskOptions
    :keyword retry: Retry policy.
    :paramtype retry: RetryPolicy | None
    :keyword entry_mode: Why this execution is starting.
    :paramtype entry_mode: EntryMode
    :keyword stream_handler: Custom stream handler. If ``None``,
        ``opts.stream_handler_factory`` is used when set, otherwise a
        default :class:`QueueStreamHandler` is created.
    :paramtype stream_handler: StreamHandler | None
    :return: A ``TaskRun`` handle.
    :rtype: TaskRun
    """
    resolved_session = session_id or self._config.session_id or "local"
    agent_name = self._config.agent_name or "default"

    # Build payload
    payload: dict[str, Any] = {}
    if opts.store_input:
        payload["input"] = _serialize_input(input_val)
    payload["metadata"] = {}

    # Auto-stamp source provenance (framework-owned, not user-overridable)
    source = self._build_source(fn_name)

    # Auto-stamp task name tag for LIST filtering. Work on a copy so the
    # caller's dict (possibly shared decorator-level tags) is not modified.
    tags = {**(tags or {}), _TAG_TASK_NAME: fn_name}

    # Create task with lease
    task_info = await self._provider.create(
        TaskCreateRequest(
            id=task_id,
            agent_name=agent_name,
            session_id=resolved_session,
            status="in_progress",
            title=title,
            description=description,
            payload=payload,
            tags=tags,
            source=source,
            lease_owner=self._lease_owner,
            lease_instance_id=self._instance_id,
            lease_duration_seconds=opts.lease_duration_seconds,
        )
    )

    logger.info("Created durable task %s (%s)", task_id, fn_name)

    # Register resume callback
    self._resume_callbacks[fn_name] = fn
    self._resume_opts[fn_name] = opts

    # Build context
    cancel_event = asyncio.Event()
    # Resolve handler: call-site > factory > default
    if stream_handler is not None:
        handler = stream_handler
    elif opts.stream_handler_factory is not None:
        handler = opts.stream_handler_factory(task_id)
    else:
        handler = QueueStreamHandler()
    metadata = TaskMetadata(
        flush_callback=self._make_metadata_flush(task_id),
        flush_interval=5.0,
    )

    lease_gen = task_info.lease.generation if task_info.lease else 0

    ctx: TaskContext[Any] = TaskContext(
        task_id=task_id,
        title=title,
        description=description,
        session_id=resolved_session,
        agent_name=agent_name,
        tags=tags,
        input=input_val,
        metadata=metadata,
        run_attempt=0,
        lease_generation=lease_gen,
        cancel=cancel_event,
        shutdown=self._shutdown_event,
        stream_handler=handler,
        entry_mode=entry_mode,
        generation=0,
    )
    # We are inside a coroutine, so a running loop is guaranteed.
    loop = asyncio.get_running_loop()
    result_future: asyncio.Future[Any] = loop.create_future()

    # Start lease renewal
    renewal_cancel = asyncio.Event()

    # Build steering poll callback for steerable tasks
    steering_poll_cb_cs: Callable[[], Awaitable[None]] | None = None
    if opts.steerable:

        async def _steering_poll_cs() -> None:
            """Set the active context's cancel event when inputs are queued."""
            active = self._active_tasks.get(task_id)
            if active is None or active.context.cancel.is_set():
                return
            info = await self._provider.get(task_id)
            if info is None or not info.payload:
                return
            st = info.payload.get("_steering", {})
            if st.get("pending_inputs"):
                active.context.cancel.set()

        steering_poll_cb_cs = _steering_poll_cs

    renewal_task = asyncio.create_task(
        lease_renewal_loop(
            self._provider,
            task_id,
            lease_owner=self._lease_owner,
            lease_instance_id=self._instance_id,
            lease_duration_seconds=opts.lease_duration_seconds,
            cancel_event=renewal_cancel,
            on_cancel_callback=cancel_event,
            steering_poll_callback=steering_poll_cb_cs,
        )
    )

    # Start execution
    terminate_event = asyncio.Event()
    # Single-slot mutable holder shared by the handle and the executor.
    terminate_reason_ref: list[str | None] = [None]
    execution_task = asyncio.create_task(
        self._execute_task(
            fn=fn,
            ctx=ctx,
            task_id=task_id,
            opts=opts,
            result_future=result_future,
            renewal_cancel=renewal_cancel,
            retry=retry,
            terminate_event=terminate_event,
            terminate_reason_ref=terminate_reason_ref,
        )
    )

    # Track active task
    active = _ActiveTask(
        task_id=task_id,
        fn_name=fn_name,
        context=ctx,
        execution_task=execution_task,
        renewal_task=renewal_task,
        renewal_cancel=renewal_cancel,
        result_future=result_future,
        terminate_event=terminate_event,
        fn=fn,
        input_type=input_type,
        opts=opts,
        retry=retry,
    )
    self._active_tasks[task_id] = active

    # Start metadata auto-flush
    metadata.start_auto_flush()

    return TaskRun(
        task_id=task_id,
        provider=self._provider,
        result_future=result_future,
        metadata=metadata,
        cancel_event=cancel_event,
        stream_handler=handler,
        terminate_event=terminate_event,
        execution_task=execution_task,
        terminate_reason_ref=terminate_reason_ref,
    )
async def _start_existing_task(  # pylint: disable=too-many-locals,too-many-statements
    self,
    *,
    fn: Callable[..., Awaitable[Any]],
    fn_name: str,
    task_info: TaskInfo,
    entry_mode: EntryMode,
    input_val: Any | None = None,
    input_type: type[Any] | None = None,
    opts: DurableTaskOptions | None = None,
    retry: RetryPolicy | None = None,
    stream_handler: StreamHandler | None = None,
) -> TaskRun[Any]:
    """Transition an existing task to in_progress and execute it.

    Used by lifecycle-aware ``.run()``/``.start()`` for suspended,
    pending, and stale in_progress tasks.

    The task is patched to ``in_progress`` with a fresh lease owned by this
    instance and then re-fetched. The execution context is rebuilt from the
    stored record: persisted metadata seeds the new ``TaskMetadata``, and
    the ``_steering`` payload section supplies ``was_steered``,
    ``previous_input``, ``pending_inputs`` and ``generation``.

    :keyword fn: The durable task function.
    :paramtype fn: Callable[..., Awaitable[Any]]
    :keyword fn_name: Function name for logging. Also used as the name of
        the default options when ``opts`` is not provided.
    :paramtype fn_name: str
    :keyword task_info: The current task record.
    :paramtype task_info: TaskInfo
    :keyword entry_mode: Why this execution is starting.
    :paramtype entry_mode: EntryMode
    :keyword input_val: New input (overrides persisted input).
    :paramtype input_val: Any | None
    :keyword input_type: Type for deserializing persisted input.
    :paramtype input_type: type[Any] | None
    :keyword opts: Task options (uses defaults if not provided).
    :paramtype opts: DurableTaskOptions | None
    :keyword retry: Retry policy.
    :paramtype retry: RetryPolicy | None
    :keyword stream_handler: Custom stream handler. If ``None``, falls
        back to ``opts.stream_handler_factory`` or :class:`QueueStreamHandler`.
    :paramtype stream_handler: StreamHandler | None
    :return: A TaskRun handle.
    :rtype: TaskRun[Any]
    :raises TaskNotFound: If the task disappears between the status patch
        and the re-fetch.
    """
    task_id = task_info.id
    resolved_opts = opts or DurableTaskOptions(name=fn_name, ephemeral=False)
    lease_duration = resolved_opts.lease_duration_seconds

    # Transition to in_progress with new lease
    await self._provider.update(
        task_id,
        TaskPatchRequest(
            status="in_progress",
            lease_owner=self._lease_owner,
            lease_instance_id=self._instance_id,
            lease_duration_seconds=lease_duration,
        ),
    )

    # Re-fetch updated task
    updated_info: TaskInfo | None = await self._provider.get(task_id)
    if updated_info is None:
        raise TaskNotFound(task_id)
    task_info = updated_info

    # Resolve input: prefer caller-provided, fall back to persisted
    if input_val is not None:
        resolved_input = input_val
    elif task_info.payload and "input" in task_info.payload:
        raw_input = task_info.payload["input"]
        if input_type is not None:
            resolved_input = _deserialize_input(raw_input, input_type)
        else:
            resolved_input = raw_input
    else:
        resolved_input = None

    # Build context for execution
    cancel_event = asyncio.Event()
    # Resolve handler: call-site > factory > default
    if stream_handler is not None:
        handler = stream_handler
    elif resolved_opts.stream_handler_factory is not None:
        handler = resolved_opts.stream_handler_factory(task_id)
    else:
        handler = QueueStreamHandler()
    existing_metadata = (
        task_info.payload.get("metadata", {}) if task_info.payload else {}
    )
    metadata = TaskMetadata(
        initial=existing_metadata,
        flush_callback=self._make_metadata_flush(task_id),
        flush_interval=5.0,
    )

    lease_gen = task_info.lease.generation if task_info.lease else 0

    # Extract steering context from payload
    steering = (task_info.payload or {}).get("_steering", {})
    # Detect steering context from payload (covers recovered-mid-drain)
    was_steered = bool(
        steering.get("drain_in_progress")
        or steering.get("pending_inputs")
        or steering.get("generation", 0) > 0
    )

    # For steerable recovery with drain_in_progress, use active_input
    if (
        entry_mode == "recovered"
        and steering.get("drain_in_progress")
        and "active_input" in steering
    ):
        raw_active = steering["active_input"]
        if input_type is not None:
            resolved_input = _deserialize_input(raw_active, input_type)
        else:
            resolved_input = raw_active

    prev_input_raw = steering.get("previous_input")
    previous_input = None
    if prev_input_raw is not None and input_type is not None:
        previous_input = _deserialize_input(prev_input_raw, input_type)
    elif prev_input_raw is not None:
        previous_input = prev_input_raw
    pending_snapshot = tuple(steering.get("pending_inputs", ()))
    generation = steering.get("generation", 0)

    # Pre-set cancel if cancel_requested is True (steering short-circuit)
    if steering.get("cancel_requested"):
        cancel_event.set()

    ctx: TaskContext[Any] = TaskContext(
        task_id=task_id,
        title=task_info.title or "",
        description=task_info.description,
        session_id=task_info.session_id,
        agent_name=task_info.agent_name,
        tags=task_info.tags or {},
        input=resolved_input,
        metadata=metadata,
        run_attempt=0,
        lease_generation=lease_gen,
        cancel=cancel_event,
        shutdown=self._shutdown_event,
        stream_handler=handler,
        entry_mode=entry_mode,
        was_steered=was_steered,
        # Hand the persisted previous input to the function; it was
        # deserialized above but previously never reached the context.
        previous_input=previous_input,
        pending_inputs=pending_snapshot,
        generation=generation,
    )

    # We are inside a coroutine, so a running loop is guaranteed.
    loop = asyncio.get_running_loop()
    result_future: asyncio.Future[Any] = loop.create_future()

    renewal_cancel = asyncio.Event()

    # Build steering poll callback for steerable tasks
    steering_poll_cb: Callable[[], Awaitable[None]] | None = None
    if resolved_opts.steerable:

        async def _steering_poll() -> None:
            """Poll provider for new steering inputs and signal cancel."""
            active = self._active_tasks.get(task_id)
            if active is None or active.context.cancel.is_set():
                return
            info = await self._provider.get(task_id)
            if info is None or not info.payload:
                return
            st = info.payload.get("_steering", {})
            if st.get("pending_inputs"):
                active.context.cancel.set()

        steering_poll_cb = _steering_poll

    renewal_task = asyncio.create_task(
        lease_renewal_loop(
            self._provider,
            task_id,
            lease_owner=self._lease_owner,
            lease_instance_id=self._instance_id,
            lease_duration_seconds=lease_duration,
            cancel_event=renewal_cancel,
            on_cancel_callback=cancel_event,
            steering_poll_callback=steering_poll_cb,
        )
    )

    terminate_event = asyncio.Event()
    # Single-slot mutable holder shared by the handle and the executor.
    terminate_reason_ref: list[str | None] = [None]
    execution_task = asyncio.create_task(
        self._execute_task(
            fn=fn,
            ctx=ctx,
            task_id=task_id,
            opts=resolved_opts,
            result_future=result_future,
            renewal_cancel=renewal_cancel,
            retry=retry,
            terminate_event=terminate_event,
            terminate_reason_ref=terminate_reason_ref,
        )
    )

    active = _ActiveTask(
        task_id=task_id,
        fn_name=fn_name,
        context=ctx,
        execution_task=execution_task,
        renewal_task=renewal_task,
        renewal_cancel=renewal_cancel,
        result_future=result_future,
        terminate_event=terminate_event,
        fn=fn,
        input_type=input_type,
        opts=resolved_opts,
        retry=retry,
    )
    self._active_tasks[task_id] = active
    metadata.start_auto_flush()

    return TaskRun(
        task_id=task_id,
        provider=self._provider,
        result_future=result_future,
        metadata=metadata,
        cancel_event=cancel_event,
        stream_handler=handler,
        terminate_event=terminate_event,
        execution_task=execution_task,
        terminate_reason_ref=terminate_reason_ref,
        lease_expiry_count=task_info.lease.expiry_count if task_info.lease else 0,
    )
async def _execute_task(
    self,
    *,
    fn: Callable[..., Awaitable[Any]],
    ctx: TaskContext[Any],
    task_id: str,
    opts: DurableTaskOptions,
    result_future: asyncio.Future[Any],
    renewal_cancel: asyncio.Event,
    retry: RetryPolicy | None = None,
    terminate_event: asyncio.Event | None = None,
    terminate_reason_ref: list[str | None] | None = None,
) -> None:
    """Run the task function and handle completion/failure/suspend.

    This wrapper only manages the optional timeout watchdog; the actual
    run/retry/suspend handling is delegated to ``_execute_task_loop``.
    When ``opts.timeout`` is set, a background watchdog is started that
    sets ``ctx.cancel`` after the timeout elapses. The watchdog is always
    cancelled and awaited once the loop returns or raises.

    :keyword fn: The async task function.
    :paramtype fn: Callable[..., Awaitable[Any]]
    :keyword ctx: The task context.
    :paramtype ctx: TaskContext[Any]
    :keyword task_id: The task identifier.
    :paramtype task_id: str
    :keyword opts: The task options.
    :paramtype opts: DurableTaskOptions
    :keyword result_future: Future to resolve with the result.
    :paramtype result_future: asyncio.Future[Any]
    :keyword renewal_cancel: Event to cancel lease renewal.
    :paramtype renewal_cancel: asyncio.Event
    :keyword retry: Optional retry policy.
    :paramtype retry: RetryPolicy | None
    :keyword terminate_event: Optional terminate event. A fresh, unset
        event is substituted when omitted.
    :paramtype terminate_event: asyncio.Event | None
    :keyword terminate_reason_ref: Mutable ref for terminate reason.
    :paramtype terminate_reason_ref: list[str | None] | None
    """
    resolved_terminate = terminate_event or asyncio.Event()

    # Start timeout watchdog if configured
    watchdog_task: asyncio.Task[None] | None = None
    if opts.timeout is not None:
        watchdog_task = asyncio.create_task(
            self._timeout_watchdog(
                timeout_seconds=opts.timeout.total_seconds(),
                cancel_event=ctx.cancel,
            )
        )

    try:
        await self._execute_task_loop(
            fn=fn,
            ctx=ctx,
            task_id=task_id,
            opts=opts,
            result_future=result_future,
            renewal_cancel=renewal_cancel,
            retry=retry,
            terminate_event=resolved_terminate,
            terminate_reason_ref=terminate_reason_ref,
        )
    finally:
        # The watchdog must not outlive the execution it guards.
        if watchdog_task is not None and not watchdog_task.done():
            watchdog_task.cancel()
            try:
                await watchdog_task
            except asyncio.CancelledError:
                pass
+ :paramtype retry: RetryPolicy | None + :keyword terminate_event: Optional terminate event. + :paramtype terminate_event: asyncio.Event | None + :keyword terminate_reason_ref: Mutable ref for terminate reason. + :paramtype terminate_reason_ref: list[str | None] | None + """ + resolved_terminate = terminate_event or asyncio.Event() + reason_ref = ( + terminate_reason_ref if terminate_reason_ref is not None else [None] + ) + attempt = 0 + # Mutable ref: steering drain may swap the active result_future + current_result_future = result_future + while True: + ctx.run_attempt = attempt + try: + result = await fn(ctx) + + if isinstance(result, Suspended): + # STEERING: check for pending inputs BEFORE persisting suspend + if opts.steerable: + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + ) + if new_ctx is not None: + # Drain found pending input — loop with new context + ctx = new_ctx + attempt = 0 + # Update result future to the new generation's future + active = self._active_tasks.get(task_id) + if ( + active + and active.result_future is not current_result_future + ): + current_result_future = active.result_future + continue + + # No pending steering — normal suspend flow + renewal_cancel.set() + await ctx.metadata.stop_auto_flush() + await self._handle_suspend( + task_id=task_id, + reason=result.reason, + output=result.output, + metadata=ctx.metadata, + opts=opts, + ) + if not current_result_future.done(): + current_result_future.set_result( + TaskResult( + task_id=task_id, + output=result.output, + status="suspended", + suspension_reason=result.reason, + ) + ) + else: + # Guard: task functions must return raw output, not TaskResult + if isinstance(result, TaskResult): + raise TypeError( + "Task function returned TaskResult directly. " + "Return raw output instead — the framework wraps " + "it in TaskResult automatically." 
+ ) + + # STEERING: check for pending before completing + if opts.steerable: + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + partial_output=result, + ) + if new_ctx is not None: + ctx = new_ctx + attempt = 0 + active = self._active_tasks.get(task_id) + if ( + active + and active.result_future is not current_result_future + ): + current_result_future = active.result_future + continue + + # Success flow + renewal_cancel.set() + await ctx.metadata.stop_auto_flush() + completed = await self._handle_success( + task_id=task_id, + result=result, + metadata=ctx.metadata, + opts=opts, + ) + if not completed: + # Etag conflict on steerable completion — re-drain + renewal_cancel = asyncio.Event() # reset for next iteration + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + partial_output=result, + ) + if new_ctx is not None: + ctx = new_ctx + attempt = 0 + active = self._active_tasks.get(task_id) + if ( + active + and active.result_future is not current_result_future + ): + current_result_future = active.result_future + continue + # No pending found despite conflict — complete anyway + if not current_result_future.done(): + current_result_future.set_result( + TaskResult( + task_id=task_id, + output=result, + status="completed", + ) + ) + + break # exit retry loop on success or suspend + + except asyncio.CancelledError: + renewal_cancel.set() + await ctx.metadata.stop_auto_flush() + if resolved_terminate.is_set(): + # Forced termination (timeout or explicit terminate()) + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskTerminated, + ) + + await self._handle_failure( + task_id=task_id, + exc=TaskTerminated(task_id, reason=reason_ref[0]), + metadata=ctx.metadata, + opts=opts, + ) + if not current_result_future.done(): + current_result_future.set_exception( + TaskTerminated(task_id, 
reason=reason_ref[0]) + ) + else: + # Cooperative cancellation (suspend or caller cancel) + if not current_result_future.done(): + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + ) + + current_result_future.set_exception(TaskCancelled(task_id)) + break # cancellation is never retried + + except Exception as exc: # pylint: disable=broad-exception-caught + if retry and retry.should_retry(attempt, exc): + delay = retry.compute_delay(attempt) + logger.warning( + "Task %s attempt %d failed (%s: %s), retrying in %.1fs", + task_id, + attempt, + type(exc).__name__, + exc, + delay, + ) + # Update error field so observers see intermediate failures + try: + await self._provider.update( + task_id, + TaskPatchRequest( + error={ + "type": type(exc).__name__, + "message": str(exc), + "attempt": attempt, + } + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.debug( + "Failed to update error field for retry", exc_info=True + ) + await asyncio.sleep(delay) + attempt += 1 + continue + + # Exhausted or non-retryable — terminal failure + renewal_cancel.set() + await ctx.metadata.stop_auto_flush() + + if retry and attempt > 0: + # Retries were attempted but exhausted + error_dict: dict[str, Any] = { + "type": "exhausted_retries", + "attempts": attempt + 1, + "last_error": str(exc), + "last_error_type": type(exc).__name__, + "traceback": traceback.format_exc(), + } + else: + error_dict = { + "type": type(exc).__name__, + "message": str(exc), + "traceback": traceback.format_exc(), + } + + await self._handle_failure( + task_id=task_id, + exc=exc, + metadata=ctx.metadata, + opts=opts, + ) + if not current_result_future.done(): + current_result_future.set_exception(TaskFailed(task_id, error_dict)) + break + + self._active_tasks.pop(task_id, None) + # Signal end of streaming via handler.close() + if ctx._stream_handler is not None: # pylint: disable=protected-access + try: + await ctx._stream_handler.close() # pylint: 
async def _try_drain_steering(  # pylint: disable=too-many-branches
    self,
    *,
    task_id: str,
    ctx: TaskContext[Any],
    opts: DurableTaskOptions,
    result_future: asyncio.Future[Any],
    partial_output: Any | None = None,
) -> TaskContext[Any] | None:
    """Check for pending steering inputs and drain the next one.

    Called BEFORE persisting suspend/complete to avoid lease/status conflicts.
    Returns a new ``TaskContext`` if a drain occurred, or ``None`` if no
    pending inputs exist (or the task record can no longer be read).

    A drain performs, in order:

    1. Read the task and pop the first entry of
       ``payload["_steering"]["pending_inputs"]``.
    2. Persist the updated steering state (new ``active_input``, bumped
       ``generation``, ``drain_in_progress=True``) guarded by the record's
       etag, when the record exposes one.
    3. Bind the oldest registered steering future (if any) as the new
       generation's result future and resolve the superseded future with a
       ``"superseded"`` ``TaskResult``.
    4. Build a new context that reuses the current metadata, shutdown event
       and stream handler.
    5. Persist ``drain_in_progress=False`` (best effort).

    :keyword task_id: The task identifier.
    :paramtype task_id: str
    :keyword ctx: Current task context.
    :paramtype ctx: TaskContext[Any]
    :keyword opts: Task options. Only forwarded to the recursive retry.
    :paramtype opts: DurableTaskOptions
    :keyword result_future: The current generation's result future.
    :paramtype result_future: asyncio.Future[Any]
    :keyword partial_output: Output from the completed generation (for race recovery).
    :paramtype partial_output: Any | None
    :return: New context for the drained generation, or None.
    :rtype: TaskContext[Any] | None
    """
    task_info = await self._provider.get(task_id)
    if task_info is None:
        return None

    # Work on shallow copies so the fetched record is not mutated in place.
    payload = dict(task_info.payload) if task_info.payload else {}
    steering = dict(payload.get("_steering", {}))
    pending: list[Any] = list(steering.get("pending_inputs", []))

    if not pending:
        return None

    # Pop the next input from the queue
    next_input_raw = pending.pop(0)
    previous_input_raw = steering.get("active_input")

    # Update steering state
    steering["active_input"] = next_input_raw
    if previous_input_raw is not None:
        steering["previous_input"] = previous_input_raw
    steering["pending_inputs"] = pending
    old_generation = steering.get("generation", 0)
    steering["generation"] = old_generation + 1
    # More inputs still queued: the new generation starts already cancelled
    # (see the cancel_event handling below).
    steering["cancel_requested"] = len(pending) > 0
    steering["drain_in_progress"] = True

    # Save partial output if function completed (race recovery).
    # A ``None`` output is not recorded.
    if partial_output is not None:
        gen_results = dict(steering.get("generation_results", {}))
        gen_results[str(old_generation)] = _serialize_input(partial_output)
        steering["generation_results"] = gen_results

    payload["_steering"] = steering

    try:
        # Falsy etags (missing attribute, empty string) become None.
        etag = getattr(task_info, "etag", None) or None
        await self._provider.update(
            task_id,
            TaskPatchRequest(payload=payload, if_match=etag),
        )
    except ValueError:
        # Etag conflict — re-read and retry.
        # NOTE(review): the retry is a recursive call with no attempt
        # counter, so it repeats for as long as the update keeps raising
        # ValueError, not just once — confirm whether a bound is wanted.
        logger.warning(
            "Etag conflict during steering drain for %s, retrying", task_id
        )
        return await self._try_drain_steering(
            task_id=task_id,
            ctx=ctx,
            opts=opts,
            result_future=result_future,
            partial_output=partial_output,
        )

    # Pop and bind the next pending steering future (if any)
    new_future: asyncio.Future[Any] | None = None
    had_registered_future = False
    steering_futures = self._pending_steering_futures.get(task_id, [])
    if steering_futures:
        new_future = steering_futures.pop(0)
        had_registered_future = True

    # Resolve the superseded generation's future (only for external steer callers)
    if had_registered_future and not result_future.done():
        result_future.set_result(
            TaskResult(task_id=task_id, output=partial_output, status="superseded")
        )

    # Update active generation future
    if new_future is not None:
        self._active_generation_future[task_id] = new_future

    # Deserialize input; without an active-task entry the raw value is used.
    active_task = self._active_tasks.get(task_id)
    input_type = active_task.input_type if active_task else None
    if input_type is not None:
        resolved_input = _deserialize_input(next_input_raw, input_type)
    else:
        resolved_input = next_input_raw

    # Deserialize previous input
    previous_input = None
    if previous_input_raw is not None and input_type is not None:
        previous_input = _deserialize_input(previous_input_raw, input_type)
    elif previous_input_raw is not None:
        previous_input = previous_input_raw

    # Build new context, reusing metadata and shutdown event
    cancel_event = asyncio.Event()
    if steering["cancel_requested"]:
        cancel_event.set()

    new_ctx: TaskContext[Any] = TaskContext(
        task_id=task_id,
        title=ctx.title,
        description=ctx.description,
        session_id=ctx.session_id,
        agent_name=ctx.agent_name,
        tags=ctx.tags,
        input=resolved_input,
        metadata=ctx.metadata,
        run_attempt=0,
        lease_generation=ctx.lease_generation,
        cancel=cancel_event,
        shutdown=ctx.shutdown,
        stream_handler=ctx._stream_handler,  # pylint: disable=protected-access
        entry_mode="resumed",
        was_steered=True,
        previous_input=previous_input,
        pending_inputs=tuple(pending),
        generation=old_generation + 1,
    )

    # Update active task tracking
    if active_task is not None:
        active_task.context = new_ctx
        if new_future is not None:
            active_task.result_future = new_future

    # Clear drain_in_progress
    # NOTE(review): this second write sends the whole payload snapshot taken
    # above and carries no ``if_match``; if another writer appended to
    # ``pending_inputs`` after the first update, that change looks like it
    # would be overwritten — confirm against the provider's patch semantics.
    steering["drain_in_progress"] = False
    payload["_steering"] = steering
    try:
        await self._provider.update(
            task_id,
            TaskPatchRequest(payload=payload),
        )
    except Exception:  # pylint: disable=broad-exception-caught
        # Best effort: a failure here is only logged at debug level.
        logger.debug("Failed to clear drain_in_progress for %s", task_id)

    logger.info(
        "Steering drain: task %s generation %d → %d",
        task_id,
        old_generation,
        old_generation + 1,
    )
    return new_ctx
async def _handle_failure(
    self,
    *,
    task_id: str,
    exc: Exception,
    metadata: TaskMetadata,
    opts: DurableTaskOptions,
) -> None:
    """Handle task failure.

    Ephemeral tasks are force-deleted from the store.  All other tasks are
    patched to ``completed`` with a structured ``error`` dict and the final
    metadata snapshot.  Provider errors are logged and swallowed, so this
    method is best effort and does not raise on store failures.

    :keyword task_id: The task identifier.
    :paramtype task_id: str
    :keyword exc: The exception that caused the failure.
    :paramtype exc: Exception
    :keyword metadata: The task metadata.
    :paramtype metadata: TaskMetadata
    :keyword opts: The task options.
    :paramtype opts: DurableTaskOptions
    """
    # Format the traceback from the exception object itself.
    # ``traceback.format_exc()`` reads the exception *currently being
    # handled*; if this coroutine runs outside an ``except`` block (or
    # while a different exception is active) it would record
    # "NoneType: None" or an unrelated stack.
    error_dict = {
        "type": type(exc).__name__,
        "message": str(exc),
        "traceback": "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        ),
    }

    if opts.ephemeral:
        # Ephemeral tasks leave no record behind, even on failure.
        try:
            await self._provider.delete(task_id, force=True)
        except Exception:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Failed to delete failed ephemeral task %s",
                task_id,
                exc_info=True,
            )
    else:
        # Failure is recorded as status "completed" plus an ``error`` body.
        try:
            await self._provider.update(
                task_id,
                TaskPatchRequest(
                    status="completed",
                    error=error_dict,
                    payload={"metadata": metadata.to_dict()},
                ),
            )
        except Exception:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Failed to record error for task %s",
                task_id,
                exc_info=True,
            )

    logger.error("Task %s failed: %s", task_id, exc)
+ :paramtype opts: DurableTaskOptions + """ + payload_patch: dict[str, Any] = { + "metadata": metadata.to_dict(), + } + if output is not None: + payload_patch["output"] = _serialize_input(output) + + try: + await self._provider.update( + task_id, + TaskPatchRequest( + status="suspended", + suspension_reason=reason, + payload=payload_patch, + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to suspend task %s", task_id, exc_info=True) + + logger.info("Task %s suspended: %s", task_id, reason) + + async def _recover_stale_tasks(self) -> None: + """Recover stale in-progress tasks from previous instances.""" + agent_name = self._config.agent_name or "default" + session_id = self._config.session_id or "local" + + try: + stale_tasks = await self._provider.list( + agent_name=agent_name, + session_id=session_id, + status="in_progress", + lease_owner=self._lease_owner, + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to query stale tasks for recovery", exc_info=True) + return + + for task_info in stale_tasks: + # Skip if we're already tracking this task + if task_info.id in self._active_tasks: + continue + + # Reclaim the lease with our new instance ID + try: + await self._provider.update( + task_info.id, + TaskPatchRequest( + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=60, + ), + ) + logger.info( + "Reclaimed stale task %s (generation will increment)", + task_info.id, + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to reclaim task %s", task_info.id, exc_info=True) + continue + + # Find resume callback and dispatch + fn = self._find_resume_callback(task_info) + if fn is not None: + try: + # Look up stored opts for stream_handler_factory etc. 
+ fn_name = (task_info.source or {}).get("name", "") + opts = self._resume_opts.get(fn_name) + await self._start_existing_task( + fn=fn, + fn_name=task_info.agent_name, + task_info=task_info, + entry_mode="recovered", + opts=opts, + ) + logger.info("Recovered task %s is now active", task_info.id) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to resume recovered task %s", + task_info.id, + exc_info=True, + ) + + def _find_resume_callback(self, task_info: TaskInfo) -> Callable[..., Any] | None: + """Find a registered resume callback for a task. + + Matches by ``source.name`` (auto-stamped function name) first, + then falls back to title prefix match or single-callback default. + + :param task_info: The task record to match. + :type task_info: TaskInfo + :return: A matching resume callback, or None. + :rtype: Callable[..., Any] | None + """ + # Preferred: match by source.name (framework auto-stamped fn name) + if task_info.source and "name" in task_info.source: + source_name = task_info.source["name"] + if source_name in self._resume_callbacks: + return self._resume_callbacks[source_name] + + # Fallback: title prefix match + for name, fn in self._resume_callbacks.items(): + if task_info.title and task_info.title.startswith(name): + return fn + + # Last resort: single registered callback + if len(self._resume_callbacks) == 1: + return next(iter(self._resume_callbacks.values())) + return None + + def _make_metadata_flush( + self, task_id: str + ) -> Callable[[dict[str, Any]], Awaitable[None]]: + """Create a flush callback for metadata persistence. + + :param task_id: The task identifier. + :type task_id: str + :return: An async callback that flushes metadata. 
class TaskMetadata:
    """Mutable progress dict persisted to the task record's payload.

    Mutations only mark the metadata dirty; the data is handed to
    ``flush_callback`` by the background auto-flush loop, by an explicit
    :meth:`flush`, or by the final flush in :meth:`stop_auto_flush`.

    :param initial: Initial metadata values (from a recovered task).
    :type initial: dict[str, Any] | None
    :param flush_callback: Async callable that persists dirty metadata.
    :type flush_callback: Callable[[dict[str, Any]], Awaitable[None]] | None
    :param flush_interval: Seconds between automatic flushes (0 = disabled).
    :type flush_interval: float
    """

    def __init__(
        self,
        initial: dict[str, Any] | None = None,
        *,
        flush_callback: Any = None,
        flush_interval: float = 5.0,
    ) -> None:
        # Copy so later mutations never leak into the caller's dict.
        self._data: dict[str, Any] = dict(initial) if initial else {}
        self._dirty = False
        self._flush_callback = flush_callback
        self._flush_interval = flush_interval
        self._flush_task: asyncio.Task[None] | None = None
        # Serializes flushes issued by the loop, flush() and stop_auto_flush().
        self._lock = asyncio.Lock()

    def set(self, key: str, value: Any) -> None:
        """Set a key-value pair.

        :param key: Metadata key (must be a string).
        :type key: str
        :param value: Any JSON-serializable value.
        :type value: Any
        :raises TypeError: If key is not a string.
        """
        if not isinstance(key, str):
            raise TypeError(f"Metadata key must be a string, got {type(key).__name__}")
        self._data[key] = value
        self._mark_dirty()

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value by key.

        :param key: Metadata key.
        :type key: str
        :param default: Default value if key is absent.
        :type default: Any
        :return: The value, or *default*.
        :rtype: Any
        """
        return self._data.get(key, default)

    def increment(self, key: str, delta: int = 1) -> None:
        """Add *delta* to a numeric value, treating a missing key as 0.

        :param key: Metadata key.
        :type key: str
        :param delta: Amount to add (default 1).
        :type delta: int
        :raises TypeError: If *delta* or the existing value is not numeric.
        """
        if not isinstance(delta, (int, float)):
            raise TypeError(f"Delta must be numeric, got {type(delta).__name__}")
        current = self._data.get(key, 0)
        if not isinstance(current, (int, float)):
            raise TypeError(
                f"Cannot increment non-numeric value at key {key!r}: "
                f"{type(current).__name__}"
            )
        self._data[key] = current + delta
        self._mark_dirty()

    def append(self, key: str, value: Any) -> None:
        """Append a value to a list.

        Creates the list if the key is absent.

        :param key: Metadata key.
        :type key: str
        :param value: Value to append.
        :type value: Any
        :raises TypeError: If the existing value is not a list.
        """
        # A sentinel is used so that a stored ``None`` is reported as a
        # non-list value instead of being silently replaced.
        current = self._data.get(key, _NOT_SET)
        if current is _NOT_SET:
            self._data[key] = [value]
        elif isinstance(current, list):
            current.append(value)
        else:
            raise TypeError(
                f"Cannot append to non-list value at key {key!r}: "
                f"{type(current).__name__}"
            )
        self._mark_dirty()

    def to_dict(self) -> dict[str, Any]:
        """Return a snapshot of all metadata.

        :return: A shallow copy of the metadata dict.
        :rtype: dict[str, Any]
        """
        return dict(self._data)

    # -- Dict protocol (MutableMapping) ------------------------------------

    def __setitem__(self, key: str, value: Any) -> None:
        # Same validation and dirty-marking as the named method.
        self.set(key, value)

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]
        self._mark_dirty()

    def __contains__(self, key: object) -> bool:
        return key in self._data

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def keys(self) -> collections.abc.KeysView[str]:
        """Return a view of metadata keys.

        :return: A view of the metadata keys.
        :rtype: ~collections.abc.KeysView[str]
        """
        return self._data.keys()

    def values(self) -> collections.abc.ValuesView[Any]:
        """Return a view of metadata values.

        :return: A view of the metadata values.
        :rtype: ~collections.abc.ValuesView[Any]
        """
        return self._data.values()

    def items(self) -> collections.abc.ItemsView[str, Any]:
        """Return a view of metadata key-value pairs.

        :return: A view of the metadata key-value pairs.
        :rtype: ~collections.abc.ItemsView[str, Any]
        """
        return self._data.items()

    async def flush(self) -> None:
        """Force-flush pending metadata changes to the store.

        No-op if there are no pending changes or no flush callback.
        """
        async with self._lock:
            await self._do_flush()

    def start_auto_flush(self) -> None:
        """Start the background auto-flush loop.

        Called by the framework when the task starts executing. Should
        not be called by user code.  Does nothing when the interval is not
        positive, no callback is configured, or a loop is already running.
        """
        if (
            self._flush_interval > 0
            and self._flush_callback is not None
            and self._flush_task is None
        ):
            self._flush_task = asyncio.get_event_loop().create_task(
                self._auto_flush_loop()
            )

    async def stop_auto_flush(self) -> None:
        """Stop the auto-flush loop and perform a final flush."""
        if self._flush_task is not None:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None
        # Final flush
        async with self._lock:
            await self._do_flush()

    def _mark_dirty(self) -> None:
        self._dirty = True

    async def _do_flush(self) -> None:
        """Persist the current data if dirty.  Caller must hold ``_lock``."""
        if not self._dirty or self._flush_callback is None:
            return
        snapshot = dict(self._data)
        # Clear the flag *before* awaiting.  A mutation made while the
        # callback is in flight then re-marks the metadata dirty and is
        # picked up by the next flush; clearing afterwards would mark that
        # unsent mutation as clean and lose it.
        self._dirty = False
        try:
            await self._flush_callback(snapshot)
        except Exception:  # pylint: disable=broad-exception-caught
            self._dirty = True  # nothing was persisted; retry on next flush
            logger.warning("Failed to flush metadata", exc_info=True)
        except BaseException:
            # Cancellation mid-flush (e.g. from stop_auto_flush): keep the
            # data dirty so the final flush still sends it.
            self._dirty = True
            raise

    async def _auto_flush_loop(self) -> None:
        """Periodically flush dirty metadata."""
        while True:
            await asyncio.sleep(self._flush_interval)
            async with self._lock:
                await self._do_flush()
class LeaseInfo:
    """Lease details attached to a task record.

    :param owner: Stable lease owner (e.g. ``"session:session_abc"``).
    :type owner: str
    :param instance_id: Ephemeral per-process instance identifier.
    :type instance_id: str
    :param generation: Fencing token — increments on re-acquisition.
    :type generation: int
    :param expires_at: ISO 8601 expiry timestamp.
    :type expires_at: str
    :param expiry_count: Number of times ownership changed via expiry.
    :type expiry_count: int
    """

    __slots__ = ("owner", "instance_id", "generation", "expires_at", "expiry_count")

    def __init__(
        self,
        owner: str,
        instance_id: str,
        generation: int,
        expires_at: str,
        expiry_count: int = 0,
    ) -> None:
        self.owner = owner
        self.instance_id = instance_id
        self.generation = generation
        self.expires_at = expires_at
        self.expiry_count = expiry_count

    def __repr__(self) -> str:
        # Fields are rendered in slot order, which is the constructor order.
        rendered = ", ".join(
            f"{field}={getattr(self, field)!r}" for field in LeaseInfo.__slots__
        )
        return f"LeaseInfo({rendered})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, LeaseInfo):
            return NotImplemented
        # Field-by-field comparison, stopping at the first mismatch.
        return all(
            getattr(self, field) == getattr(other, field)
            for field in LeaseInfo.__slots__
        )


class TaskInfo:  # pylint: disable=too-many-instance-attributes
    """Internal representation of a task record from the store.

    :param id: Unique task identifier.
    :type id: str
    :param agent_name: Agent scope.
    :type agent_name: str
    :param session_id: Session scope.
    :type session_id: str
    :param status: Current task status.
    :type status: TaskStatus
    :param title: Human-readable title.
    :type title: str | None
    :param description: Optional description.
    :type description: str | None
    :param lease: Active lease details, or ``None``.
    :type lease: LeaseInfo | None
    :param payload: Arbitrary JSON payload (input, metadata, output buckets).
    :type payload: dict[str, Any] | None
    :param tags: Key-value tags.
    :type tags: dict[str, str] | None
    :param error: Structured error details on failure.
    :type error: dict[str, Any] | None
    :param suspension_reason: Reason for suspension.
    :type suspension_reason: str | None
    :param etag: Optimistic concurrency token.
    :type etag: str
    :param created_at: ISO 8601 creation timestamp.
    :type created_at: str
    :param updated_at: ISO 8601 last-update timestamp.
    :type updated_at: str
    :param started_at: ISO 8601 timestamp of first ``in_progress`` transition.
    :type started_at: str | None
    :param completed_at: ISO 8601 timestamp of ``completed`` transition.
    :type completed_at: str | None
    :param source: Source (provenance) dict stored with the task, or ``None``.
    :type source: dict[str, Any] | None
    """

    __slots__ = (
        "id",
        "agent_name",
        "session_id",
        "status",
        "title",
        "description",
        "lease",
        "payload",
        "tags",
        "error",
        "suspension_reason",
        "etag",
        "created_at",
        "updated_at",
        "started_at",
        "completed_at",
        "source",
    )

    def __init__(
        self,
        id: str,  # noqa: A002
        agent_name: str,
        session_id: str,
        status: TaskStatus,
        title: str | None = None,
        description: str | None = None,
        lease: LeaseInfo | None = None,
        payload: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        error: dict[str, Any] | None = None,
        suspension_reason: str | None = None,
        etag: str = "",
        created_at: str = "",
        updated_at: str = "",
        started_at: str | None = None,
        completed_at: str | None = None,
        source: dict[str, Any] | None = None,
    ) -> None:
        # Identity and scope
        self.id = id
        self.agent_name = agent_name
        self.session_id = session_id
        # State
        self.status = status
        self.lease = lease
        self.error = error
        self.suspension_reason = suspension_reason
        # Descriptive content
        self.title = title
        self.description = description
        self.payload = payload
        self.tags = tags
        self.source = source
        # Concurrency token and timestamps
        self.etag = etag
        self.created_at = created_at
        self.updated_at = updated_at
        self.started_at = started_at
        self.completed_at = completed_at

    def __repr__(self) -> str:
        return f"TaskInfo(id={self.id!r}, status={self.status!r}, agent_name={self.agent_name!r})"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TaskInfo:
        """Build a :class:`TaskInfo` from a JSON-decoded dict.

        ``id`` is required; a present, non-empty ``lease`` must carry
        ``owner`` and ``instance_id``.  Every other field falls back to a
        default (``""``, ``"pending"`` or ``None``).

        :param data: Dictionary as returned by the Task Storage API.
        :type data: dict[str, Any]
        :return: A populated TaskInfo instance.
        :rtype: TaskInfo
        :raises KeyError: If a required key is missing.
        """
        raw_lease = data.get("lease")
        lease: LeaseInfo | None = None
        if raw_lease:
            lease = LeaseInfo(
                owner=raw_lease["owner"],
                instance_id=raw_lease["instance_id"],
                generation=raw_lease.get("generation", 0),
                expires_at=raw_lease.get("expires_at", ""),
                expiry_count=raw_lease.get("expiry_count", 0),
            )

        fields: dict[str, Any] = {
            "id": data["id"],
            "status": data.get("status", "pending"),
            "lease": lease,
        }
        # String fields that default to the empty string.
        for key in ("agent_name", "session_id", "etag", "created_at", "updated_at"):
            fields[key] = data.get(key, "")
        # Optional fields that default to None.
        for key in (
            "title",
            "description",
            "payload",
            "tags",
            "error",
            "suspension_reason",
            "started_at",
            "completed_at",
            "source",
        ):
            fields[key] = data.get(key)
        return cls(**fields)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dictionary.

        ``lease``, ``etag`` and the four timestamps are always emitted; the
        remaining optional fields are emitted only when not ``None``.

        :return: Dictionary suitable for JSON serialization.
        :rtype: dict[str, Any]
        """
        out: dict[str, Any] = {
            "object": "task",
            "id": self.id,
            "agent_name": self.agent_name,
            "session_id": self.session_id,
            "status": self.status,
        }

        def _copy_if_set(*names: str) -> None:
            for name in names:
                value = getattr(self, name)
                if value is not None:
                    out[name] = value

        _copy_if_set("title", "description")

        lease = self.lease
        out["lease"] = (
            None
            if lease is None
            else {
                "owner": lease.owner,
                "instance_id": lease.instance_id,
                "generation": lease.generation,
                "expires_at": lease.expires_at,
                "expiry_count": lease.expiry_count,
            }
        )

        _copy_if_set("payload", "tags", "error", "suspension_reason", "source")

        for name in ("etag", "created_at", "updated_at", "started_at", "completed_at"):
            out[name] = getattr(self, name)
        return out
class TaskPatchRequest:
    """Request body for patching a task.

    Every field defaults to ``None``; per the provider contract only
    non-``None`` fields are included in the PATCH payload.

    :param status: New status.
    :type status: TaskStatus | None
    :param payload: Payload patch (shallow-merge semantics).
    :type payload: dict[str, Any] | None
    :param tags: Tags patch (null-as-delete merge).
    :type tags: dict[str, str] | None
    :param error: Structured error (on failure).
    :type error: dict[str, Any] | None
    :param suspension_reason: Reason for suspension.
    :type suspension_reason: str | None
    :param lease_owner: Lease owner for transitions.
    :type lease_owner: str | None
    :param lease_instance_id: Lease instance for transitions.
    :type lease_instance_id: str | None
    :param lease_duration_seconds: Lease TTL override.
    :type lease_duration_seconds: int | None
    :param if_match: ETag for optimistic concurrency.
    :type if_match: str | None
    """

    __slots__ = (
        "status",
        "payload",
        "tags",
        "error",
        "suspension_reason",
        "lease_owner",
        "lease_instance_id",
        "lease_duration_seconds",
        "if_match",
    )

    def __init__(
        self,
        status: TaskStatus | None = None,
        payload: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        error: dict[str, Any] | None = None,
        suspension_reason: str | None = None,
        lease_owner: str | None = None,
        lease_instance_id: str | None = None,
        lease_duration_seconds: int | None = None,
        if_match: str | None = None,
    ) -> None:
        # Optimistic-concurrency precondition
        self.if_match = if_match
        # Lease transition fields
        self.lease_owner = lease_owner
        self.lease_instance_id = lease_instance_id
        self.lease_duration_seconds = lease_duration_seconds
        # Status transition and its explanation
        self.status = status
        self.error = error
        self.suspension_reason = suspension_reason
        # Content patches
        self.payload = payload
        self.tags = tags
+""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from ._models import TaskCreateRequest, TaskInfo, TaskPatchRequest, TaskStatus + + +@runtime_checkable +class DurableTaskProvider(Protocol): + """Async storage backend for durable tasks. + + Both :class:`HostedDurableTaskProvider` (HTTP → Task Storage API) and + :class:`LocalFileDurableTaskProvider` (filesystem) implement this + protocol. + """ + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. + :rtype: TaskInfo + """ + ... + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a single task by ID. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + ... + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH semantics. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + ... + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, + ) -> None: + """Delete a task. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. + :paramtype cascade: bool + """ + ... + + async def list( + self, + *, + agent_name: str, + session_id: str, + status: TaskStatus | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + ) -> list[TaskInfo]: + """List tasks with filters. + + :keyword agent_name: Filter by agent name. 
Output = TypeVar("Output")


class TaskResult(Generic[Output]):
    """Outcome of one durable task execution.

    Carries the task id, a status of ``"completed"``, ``"suspended"`` or
    ``"superseded"``, an optional output value and an optional suspension
    reason.

    :param task_id: The task identifier.
    :type task_id: str
    :param output: The task output value (typed for completion, optional for suspension).
    :type output: Output | None
    :param status: Whether the task completed, suspended, or was superseded.
    :type status: ~typing.Literal["completed", "suspended", "superseded"]
    :param suspension_reason: Human-readable suspension reason, if suspended.
    :type suspension_reason: str | None
    """

    __slots__ = ("task_id", "output", "status", "suspension_reason")

    def __init__(
        self,
        *,
        task_id: str,
        output: Output | None = None,
        status: Literal["completed", "suspended", "superseded"],
        suspension_reason: str | None = None,
    ) -> None:
        self.task_id = task_id
        self.status: Literal["completed", "suspended", "superseded"] = status
        self.output = output
        self.suspension_reason = suspension_reason

    @property
    def is_completed(self) -> bool:
        """Whether the task completed successfully.

        :return: True if the status is ``"completed"``.
        :rtype: bool
        """
        return self.status == "completed"

    @property
    def is_suspended(self) -> bool:
        """Whether the task was suspended.

        :return: True if the status is ``"suspended"``.
        :rtype: bool
        """
        return self.status == "suspended"

    @property
    def is_superseded(self) -> bool:
        """Whether the generation was superseded by a steering input.

        :return: True if the status is ``"superseded"``.
        :rtype: bool
        """
        return self.status == "superseded"

    def __repr__(self) -> str:
        # Keep the rendered output to at most 60 characters.
        shown = repr(self.output)
        if len(shown) > 60:
            shown = f"{shown[:57]}..."
        reason = self.suspension_reason
        suffix = "" if reason is None else f", suspension_reason={reason!r}"
        return (
            f"TaskResult(task_id={self.task_id!r}, status={self.status!r}, "
            f"output={shown}{suffix})"
        )
logger = logging.getLogger("azure.ai.agentserver.durable")


async def _handle_resume_request(  # pylint: disable=too-many-return-statements
    request: Request,
) -> Response:
    """Handle POST /tasks/resume.

    Expects a JSON object body of the form ``{"task_id": "..."}`` and
    dispatches the resume to the DurableTaskManager.  Every response has an
    empty body; the outcome is carried by the status code:

    - 202: ``handle_resume`` returned without raising.
    - 400: the body is not valid JSON, is not a JSON object, or
      ``task_id`` is missing, empty, or not a string.
    - 503: no task manager is available (``get_task_manager`` raised
      ``RuntimeError``).
    - 404 / 409: ``handle_resume`` raised and the error message matched the
      "not found" / "not 'suspended'", "already", "conflict" markers.
    - 500: any other error raised by ``handle_resume``.

    :param request: The incoming HTTP request.
    :type request: Request
    :return: Empty-body response with status code.
    :rtype: Response
    """
    from ._manager import (  # pylint: disable=import-outside-toplevel
        get_task_manager,
    )

    try:
        body = await request.json()
    except (json.JSONDecodeError, ValueError):
        return Response(status_code=400)

    # Valid JSON is not necessarily an object: a list, string, number or
    # null has no ``.get`` and used to escape as an unhandled AttributeError.
    if not isinstance(body, dict):
        return Response(status_code=400)

    task_id = body.get("task_id")
    if not task_id or not isinstance(task_id, str):
        return Response(status_code=400)

    try:
        manager = get_task_manager()
    except RuntimeError:
        return Response(status_code=503)

    try:
        await manager.handle_resume(task_id)
        logger.info("Resume accepted for task %s", task_id)
        return Response(status_code=202)

    except Exception as exc:  # pylint: disable=broad-exception-caught
        # The status code is chosen by inspecting the error text.
        # NOTE(review): mapping on message substrings is brittle; prefer
        # dedicated exception types if handle_resume exposes them.
        msg = str(exc).lower()
        if "not found" in msg:
            return Response(status_code=404)
        if "not 'suspended'" in msg or "already" in msg or "conflict" in msg:
            return Response(status_code=409)
        logger.error("Resume failed for task %s: %s", task_id, exc, exc_info=True)
        return Response(status_code=500)


def create_resume_route() -> Route:
    """Create the Starlette Route for POST /tasks/resume.

    :return: A Starlette Route to be added to the host.
    :rtype: Route
    """
    return Route("/tasks/resume", _handle_resume_request, methods=["POST"])
"""Create the Starlette Route for POST /tasks/resume. + + :return: A Starlette Route to be added to the host. + :rtype: Route + """ + return Route("/tasks/resume", _handle_resume_request, methods=["POST"]) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py new file mode 100644 index 000000000000..aa56b3eb8e26 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_retry.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""RetryPolicy — configurable retry behaviour for durable tasks. + +Aligned with industry conventions (Temporal, Azure Durable Functions, Celery). +Delay formula: ``min(initial_delay * backoff_coefficient ** attempt, max_delay)`` +With jitter: ``delay * uniform(0.75, 1.25)`` +""" + +from __future__ import annotations + +import random +from datetime import timedelta + + +class RetryPolicy: + """Retry configuration for durable tasks. + + :param initial_delay: Base delay between retries. + :type initial_delay: ~datetime.timedelta + :param backoff_coefficient: Multiplier applied per attempt. + :type backoff_coefficient: float + :param max_delay: Upper bound on computed delay. + :type max_delay: ~datetime.timedelta + :param max_attempts: Total attempts (including the first try). + :type max_attempts: int + :param retry_on: Exception types that trigger retry. ``None`` means all. + :type retry_on: tuple[type[Exception], ...] | None + :param jitter: Whether to add ±25% randomization to delays. + :type jitter: bool + + .. 
versionadded:: 2.1.0 + """ + + __slots__ = ( + "initial_delay", + "backoff_coefficient", + "max_delay", + "max_attempts", + "retry_on", + "jitter", + "_linear", + ) + + def __init__( + self, + *, + initial_delay: timedelta = timedelta(seconds=1), + backoff_coefficient: float = 2.0, + max_delay: timedelta = timedelta(seconds=60), + max_attempts: int = 3, + retry_on: tuple[type[Exception], ...] | None = None, + jitter: bool = True, + _linear: bool = False, + ) -> None: + if initial_delay.total_seconds() < 0: + raise ValueError(f"initial_delay must be >= 0, got {initial_delay}") + if max_attempts < 1 and not ( + max_attempts == 1 and initial_delay == timedelta(0) + ): + pass # allow no_retry preset + if backoff_coefficient < 1.0: + raise ValueError( + f"backoff_coefficient must be >= 1.0, got {backoff_coefficient}" + ) + if max_delay < initial_delay: + raise ValueError( + f"max_delay ({max_delay}) must be >= initial_delay ({initial_delay})" + ) + if max_attempts < 1: + raise ValueError(f"max_attempts must be >= 1, got {max_attempts}") + if retry_on is not None: + for exc_type in retry_on: + if not isinstance(exc_type, type) or not issubclass( + exc_type, Exception + ): + raise TypeError( + f"retry_on entries must be Exception subclasses, got {exc_type!r}" + ) + + self.initial_delay = initial_delay + self.backoff_coefficient = backoff_coefficient + self.max_delay = max_delay + self.max_attempts = max_attempts + self.retry_on = retry_on + self.jitter = jitter + self._linear = _linear + + def compute_delay(self, attempt: int) -> float: + """Return the delay in seconds for the given attempt (0-indexed). + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :return: Delay in seconds before the next attempt. 
+ :rtype: float + """ + base_seconds = self.initial_delay.total_seconds() + if self._linear: + # Linear: delay = initial_delay * (attempt + 1) + raw = base_seconds * (attempt + 1) + else: + # Exponential: delay = initial_delay * coefficient ^ attempt + raw = base_seconds * (self.backoff_coefficient**attempt) + + capped = min(raw, self.max_delay.total_seconds()) + + if self.jitter: + capped *= random.uniform(0.75, 1.25) + + return max(0.0, capped) + + def should_retry(self, attempt: int, error: Exception) -> bool: + """Return whether the task should be retried. + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :param error: The exception that was raised. + :type error: Exception + :return: ``True`` if the task should be retried. + :rtype: bool + """ + # attempt is 0-indexed; max_attempts includes the first try + if attempt >= self.max_attempts - 1: + return False + if self.retry_on is None: + return True + return isinstance(error, self.retry_on) + + def __repr__(self) -> str: + return ( + f"RetryPolicy(initial_delay={self.initial_delay!r}, " + f"backoff_coefficient={self.backoff_coefficient}, " + f"max_delay={self.max_delay!r}, " + f"max_attempts={self.max_attempts}, " + f"retry_on={self.retry_on!r}, " + f"jitter={self.jitter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RetryPolicy): + return NotImplemented + return ( + self.initial_delay == other.initial_delay + and self.backoff_coefficient == other.backoff_coefficient + and self.max_delay == other.max_delay + and self.max_attempts == other.max_attempts + and self.retry_on == other.retry_on + and self.jitter == other.jitter + and self._linear == other._linear + ) + + # ------------------------------------------------------------------ + # Convenience presets + # ------------------------------------------------------------------ + + @classmethod + def exponential_backoff( + cls, + *, + max_attempts: int = 3, + initial_delay: timedelta = 
timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + jitter: bool = True, + ) -> RetryPolicy: + """Exponential backoff — the most common pattern. + + Delay doubles per attempt: 1 s → 2 s → 4 s → … capped at *max_delay*. + + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :keyword initial_delay: Base delay. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. + :paramtype max_delay: ~datetime.timedelta + :keyword jitter: Add ±25% randomization. + :paramtype jitter: bool + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=2.0, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=jitter, + ) + + @classmethod + def fixed_delay( + cls, + *, + delay: timedelta = timedelta(seconds=5), + max_attempts: int = 3, + ) -> RetryPolicy: + """Fixed delay — constant interval between retries. + + Useful for rate-limited APIs where you want to wait a fixed + amount of time between each attempt. + + :keyword delay: Constant delay between retries. + :paramtype delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=delay, + backoff_coefficient=1.0, + max_delay=delay, + max_attempts=max_attempts, + jitter=False, + ) + + @classmethod + def linear_backoff( + cls, + *, + initial_delay: timedelta = timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + max_attempts: int = 5, + ) -> RetryPolicy: + """Linear backoff — delay grows additively. + + Delay is ``initial_delay * (attempt + 1)``: 1 s → 2 s → 3 s → … + + :keyword initial_delay: Base delay unit. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. 
+ :paramtype max_delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=1.0, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=False, + _linear=True, + ) + + @classmethod + def no_retry(cls) -> RetryPolicy: + """No retry — the function runs once and fails on exception. + + Equivalent to not setting a retry policy at all. + + :return: A ``RetryPolicy`` that never retries. + :rtype: RetryPolicy + """ + return cls( + initial_delay=timedelta(0), + backoff_coefficient=1.0, + max_delay=timedelta(0), + max_attempts=1, + jitter=False, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py new file mode 100644 index 000000000000..267f8a06f400 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/durable/_run.py @@ -0,0 +1,242 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskRun handle and Suspended sentinel for the durable task subsystem.""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Generic, TypeVar + +from ._exceptions import ( + TaskNotFound, +) +from ._metadata import TaskMetadata +from ._models import TaskInfo, TaskStatus +from ._provider import DurableTaskProvider +from ._result import TaskResult +from ._stream import StreamHandler + +Output = TypeVar("Output") + + +class Suspended(Generic[Output]): + """Sentinel return value from :meth:`TaskContext.suspend`. + + Must be used as ``return await ctx.suspend(...)``. 
Output = TypeVar("Output")


class Suspended(Generic[Output]):
    """Sentinel return value from :meth:`TaskContext.suspend`.

    Must be used as ``return await ctx.suspend(...)``.  The framework
    interprets this on function return to transition the task.

    :param reason: Human-readable suspension reason.
    :type reason: str | None
    :param output: Optional snapshot for observers.
    :type output: Output | None
    """

    __slots__ = ("reason", "output")

    def __init__(
        self,
        reason: str | None = None,
        output: Output | None = None,
    ) -> None:
        self.reason = reason
        self.output = output

    def __repr__(self) -> str:
        # Only the reason is shown; the output snapshot is left out.
        return f"Suspended(reason={self.reason!r})"


class TaskRun(Generic[Output]):  # pylint: disable=too-many-instance-attributes
    """Handle to a running or completed durable task.

    Returned by :meth:`DurableTask.start`.  Provides external observation
    and control of the task lifecycle.

    :param task_id: The task identifier.
    :type task_id: str
    :keyword provider: Storage provider for refresh/delete operations.
    :paramtype provider: DurableTaskProvider
    :keyword result_future: Future that resolves with the task result wrapper.
    :paramtype result_future: asyncio.Future[TaskResult[Output]]
    :keyword metadata: The task's metadata instance.  When given, this exact
        object is kept (even if it is currently empty); a new
        ``TaskMetadata`` is created only when ``None`` is passed.
    :paramtype metadata: TaskMetadata | None
    :keyword cancel_event: Event to signal cancellation; created if ``None``.
    :paramtype cancel_event: asyncio.Event | None
    :keyword status: Initial task status.
    :paramtype status: TaskStatus
    :keyword stream_handler: Source of streamed items for ``async for``;
        ``None`` means the handle yields nothing.
    :paramtype stream_handler: StreamHandler | None
    :keyword terminate_event: Event set by :meth:`terminate`; created if ``None``.
    :paramtype terminate_event: asyncio.Event | None
    :keyword execution_task: The asyncio task that :meth:`terminate` cancels.
    :paramtype execution_task: asyncio.Task | None
    :keyword terminate_reason_ref: Single-element list whose first slot
        receives the termination reason; defaults to ``[None]``.
    :paramtype terminate_reason_ref: list[str | None] | None
    :keyword lease_expiry_count: Initial lease expiry count.
    :paramtype lease_expiry_count: int
    """

    __slots__ = (
        "task_id",
        "_provider",
        "_result_future",
        "_metadata",
        "_cancel_event",
        "_terminate_event",
        "_terminate_reason_ref",
        "_status",
        "_stream_handler",
        "_execution_task",
        "_lease_expiry_count",
    )

    def __init__(
        self,
        task_id: str,
        *,
        provider: DurableTaskProvider,
        result_future: asyncio.Future[TaskResult[Output]],
        metadata: TaskMetadata | None = None,
        cancel_event: asyncio.Event | None = None,
        status: TaskStatus = "in_progress",
        stream_handler: StreamHandler | None = None,
        terminate_event: asyncio.Event | None = None,
        execution_task: asyncio.Task[Any] | None = None,
        terminate_reason_ref: list[str | None] | None = None,
        lease_expiry_count: int = 0,
    ) -> None:
        self.task_id = task_id
        self._provider = provider
        self._result_future = result_future
        # Compare against None rather than relying on truthiness: a
        # dict-like metadata object that is still empty is falsy, and
        # ``metadata or TaskMetadata()`` would silently swap the caller's
        # live instance for a disconnected one.
        self._metadata = metadata if metadata is not None else TaskMetadata()
        self._cancel_event = (
            cancel_event if cancel_event is not None else asyncio.Event()
        )
        self._terminate_event = (
            terminate_event if terminate_event is not None else asyncio.Event()
        )
        self._terminate_reason_ref: list[str | None] = (
            terminate_reason_ref if terminate_reason_ref is not None else [None]
        )
        self._status = status
        self._stream_handler: StreamHandler | None = stream_handler
        self._execution_task: asyncio.Task[Any] | None = execution_task
        self._lease_expiry_count = lease_expiry_count

    @property
    def status(self) -> TaskStatus:
        """Current task status (may be stale — call :meth:`refresh` to update).

        :return: The task status.
        :rtype: TaskStatus
        """
        return self._status

    @property
    def metadata(self) -> TaskMetadata:
        """The task's metadata.

        For in-process handles, this is the live metadata reference.  For
        remote observation, call :meth:`refresh` first.

        :return: The task metadata instance.
        :rtype: TaskMetadata
        """
        return self._metadata

    @property
    def lease_expiry_count(self) -> int:
        """Number of times the lease expired and ownership changed.

        Useful for dashboards to detect ownership churn.  Call
        :meth:`refresh` to get the latest value.

        :return: The lease expiry count.
        :rtype: int
        """
        return self._lease_expiry_count

    async def result(self) -> TaskResult[Output]:
        """Await task completion and return the result.

        Returns a :class:`TaskResult` that wraps both completion and
        suspension outcomes.  Failures, cancellation, and termination are
        still raised as exceptions.

        :return: The task result wrapper.
        :rtype: TaskResult[Output]
        :raises TaskFailed: If the function raised an exception.
        :raises TaskCancelled: If the task was cancelled.
        :raises TaskTerminated: If the task was terminated.
        :raises TaskNotFound: If the task was deleted externally.
        """
        return await self._result_future

    async def cancel(self) -> None:
        """Signal cancellation to the running task.

        Sets the ``cancel`` event on the task context.  The function
        should check ``ctx.cancel.is_set()`` and exit cleanly.
        """
        self._cancel_event.set()

    async def terminate(self, *, reason: str | None = None) -> None:
        """Forcefully terminate the task.

        Unlike :meth:`cancel`, terminated tasks go through the failure path
        and do NOT stay ``in_progress`` for recovery.

        :keyword reason: Optional human-readable termination reason.
        :paramtype reason: str | None
        """
        # Record the reason before raising the events so that anything
        # reacting to them already sees it.
        self._terminate_reason_ref[0] = reason
        self._terminate_event.set()
        self._cancel_event.set()
        if self._execution_task is not None and not self._execution_task.done():
            self._execution_task.cancel()

    async def delete(self) -> None:
        """Delete the task record from the store.

        :raises TaskNotFound: If the provider's error message contains
            "not found" (case-insensitive).  Any other provider error is
            re-raised unchanged.
        """
        try:
            await self._provider.delete(self.task_id, force=True)
        except Exception as exc:
            if "not found" in str(exc).lower():
                raise TaskNotFound(self.task_id) from exc
            raise

    async def refresh(self) -> None:
        """Re-fetch task state from the store.

        Updates :attr:`status`, :attr:`lease_expiry_count` and
        :attr:`metadata` from the current task record.  Metadata keys found
        in the record are set on the local instance; keys that exist only
        locally are left as they are.

        :raises TaskNotFound: If the store has no record for this task.
        """
        task_info: TaskInfo | None = await self._provider.get(self.task_id)
        if task_info is None:
            raise TaskNotFound(self.task_id)
        self._status = task_info.status
        # Update lease expiry count
        if task_info.lease is not None:
            self._lease_expiry_count = task_info.lease.expiry_count
        # Update metadata from payload
        if task_info.payload and "metadata" in task_info.payload:
            meta_data: dict[str, Any] = task_info.payload["metadata"]
            for key, value in meta_data.items():
                self._metadata.set(key, value)

    def __aiter__(self) -> TaskRun[Output]:
        """Return self as an async iterator over streamed items.

        Usage::

            async for chunk in task_run:
                print(chunk)

        :return: Self.
        :rtype: TaskRun
        """
        return self

    async def __anext__(self) -> Any:
        """Yield the next streamed item, or raise ``StopAsyncIteration``.

        If no stream handler was provided, raises ``StopAsyncIteration``
        immediately (the task does not stream).  When the stream is
        closed, ``handler.get()`` raises ``StopAsyncIteration`` which
        propagates naturally.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When streaming ends.
        """
        if self._stream_handler is None:
            raise StopAsyncIteration
        return await self._stream_handler.get()
@runtime_checkable
class StreamHandler(Protocol):
    """Protocol for pluggable stream transports.

    Implementations control how stream items move between the task
    function (producer) and any number of consumers.  The framework
    calls :meth:`put` from ``ctx.stream()``, consumers call
    :meth:`get` via ``async for chunk in run``, and the framework
    calls :meth:`close` when the task finishes.

    All three methods are required.
    """

    async def put(self, item: Any) -> None:
        """Accept a stream item from the task function.

        :param item: The value to stream.
        :type item: Any
        """
        ...

    async def get(self) -> Any:
        """Return the next stream item, blocking until one is available.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When the stream has been closed.
        """
        ...

    async def close(self) -> None:
        """Signal end-of-stream.

        After this call, :meth:`get` must raise
        :class:`StopAsyncIteration`.  Called by the framework when the
        task finishes — both on success and on failure.
        """
        ...


class QueueStreamHandler:
    """Default stream handler wrapping :class:`asyncio.Queue`.

    Single-consumer, in-memory, unbounded.  Items enqueued before
    :meth:`close` are delivered in order; once the end-of-stream marker has
    been reached, every further :meth:`get` raises
    :class:`StopAsyncIteration` instead of blocking.

    .. versionadded:: 2.1.0
    """

    # Internal sentinel placed in the queue by close().
    _SENTINEL: object = object()

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        # Becomes True once a get() call has reached the sentinel.
        self._ended: bool = False

    async def put(self, item: Any) -> None:
        """Enqueue a stream item.

        :param item: The value to stream.
        :type item: Any
        """
        await self._queue.put(item)

    async def get(self) -> Any:
        """Dequeue the next stream item.

        Blocks until an item is available.  Raises
        :class:`StopAsyncIteration` when the stream has been closed and all
        earlier items were consumed, and keeps raising it on every
        subsequent call.

        :return: The next streamed item.
        :rtype: Any
        :raises StopAsyncIteration: When the stream has been closed.
        """
        if self._ended:
            # End-of-stream was already observed; waiting on the queue
            # again would block forever because the sentinel is gone.
            raise StopAsyncIteration
        item = await self._queue.get()
        if item is self._SENTINEL:
            self._ended = True
            # Put the marker back so a coroutine that was already parked in
            # get() wakes up and terminates as well.  The queue is
            # unbounded, so put_nowait cannot raise QueueFull.
            self._queue.put_nowait(item)
            raise StopAsyncIteration
        return item

    async def close(self) -> None:
        """Signal end-of-stream by placing the sentinel in the queue.

        Subsequent :meth:`get` calls will raise
        :class:`StopAsyncIteration` once the items queued before the
        sentinel have been consumed.
        """
        await self._queue.put(self._SENTINEL)


#: Type alias for a factory that creates a :class:`StreamHandler` from a
#: ``task_id``.  Used on the decorator to ensure crash-recovery and resume
#: paths construct the correct handler instead of defaulting to
#: :class:`QueueStreamHandler`.
StreamHandlerFactory = Callable[[str], StreamHandler]
+ +--- + +## Table of Contents + +- [Overview](#overview) + - [Why This Exists](#why-this-exists) + - [What You Get](#what-you-get) +- [Getting Started](#getting-started) +- [Lifecycle Automation](#lifecycle-automation) + - [State Diagram](#state-diagram) + - [Entry Mode Decision Table](#entry-mode-decision-table) + - [.run() vs .start() vs .get() vs .list() vs .get_active_run()](#run-vs-start-vs-get-vs-list-vs-get_active_run) +- [TaskContext](#taskcontext) + - [Properties Reference](#properties-reference) + - [Branching on Entry Mode](#branching-on-entry-mode) +- [Suspend & Resume](#suspend--resume) + - [Multi-Turn Conversations](#multi-turn-conversations) +- [Steering](#steering) + - [What Steering Solves](#what-steering-solves) + - [Generation Model](#generation-model) + - [Enabling Steering](#enabling-steering) + - [The Three-Phase Cancel Pattern](#the-three-phase-cancel-pattern) + - [Steering Flow Diagram](#steering-flow-diagram) + - [What Happens to Each Generation](#what-happens-to-each-generation) + - [Rapid-Fire Steering](#rapid-fire-steering) + - [Preserving Fidelity with External SDKs](#preserving-fidelity-with-external-sdks) + - [Steering Recovery](#steering-recovery) + - [Complete Steering Example](#complete-steering-example) +- [Streaming](#streaming) + - [Custom Stream Handlers](#custom-stream-handlers) + - [Late-Join Consumers](#late-join-consumers) +- [Persistence](#persistence) + - [Responsibility Matrix](#responsibility-matrix) + - [The Durable Boundary Rule](#the-durable-boundary-rule) +- [The Invocation Store Pattern](#the-invocation-store-pattern) +- [RetryPolicy](#retrypolicy) +- [Decorator Options](#decorator-options) +- [Error Handling](#error-handling) +- [Best Practices](#best-practices) +- [Common Mistakes](#common-mistakes) + +--- + +## Overview + +### Why This Exists + +Azure AI Foundry Hosted Agents run your code in platform-managed containers. 
+Those containers can be killed at any time — OOM kills, node preemptions, +rolling deployments, or unexpected crashes. Without durability, any in-flight +work is lost and the agent starts from scratch. + +Agent frameworks fall into two camps: + +| Category | Examples | What they need | +|----------|----------|----------------| +| **Externally stateful** — the framework owns durability | Temporal, Durable Functions, Orleans | Platform visibility: lifecycle tracking, lease-based liveness, status reporting on top of the framework's own durability | +| **Locally stateful** — the container holds state | LangGraph (SQLite checkpointer), Claude SDK tool loops, hand-written agents | A crash-safe entry point: lease-based liveness so the platform knows when to restart, plus run / resume / progress / suspend primitives the developer would otherwise hand-roll | + +`@durable_task` serves both camps. It is **not** a replacement for Temporal or +Durable Functions — it is the thin durable wrapper around the boundary between +the platform and your code. It does not make your function deterministic or +replayable. It turns `run(input) → output` into a unit of work that survives +a container crash, a deployment, or an idle-deactivation — with hooks for +progress, suspension, cancellation, and steering that compose with whatever +framework you use underneath. + +### What You Get + +Decorate your async function, and the framework guarantees it runs to completion +— even if the container restarts mid-execution. On recovery, your function is +re-invoked with the same input and last-saved metadata, so it can pick up where +it left off. + +**Your contract:** + +- Write a normal `async` function that takes a `TaskContext` +- Use `ctx.metadata` to record lightweight progress (e.g. 
current phase, step count) +- Check `ctx.entry_mode` if you need to distinguish fresh runs from recoveries +- Return a result, or `await ctx.suspend()` for multi-turn patterns + +**What you get:** + +- Automatic crash recovery — your function re-runs without any caller intervention +- Input and metadata persistence across restarts +- Retry with configurable backoff on failures +- Cooperative cancellation and timeout +- Streaming incremental output to observers +- Suspend/resume for multi-turn conversational agents +- Steering — submit a new input to a running task without cancel/wait/restart + +### What durable tasks are NOT + +- **Not a checkpoint/replay engine.** This is not Temporal or Durable Functions. + Your function is re-executed from the top on recovery, not replayed from a + deterministic log. If your function calls an LLM twice, it will call it again + on recovery. +- **Not a result store.** Task output and metadata exist only while the task is + alive. Once the task is deleted, they are gone. If you need results to outlive + the task, persist them in your own store (database, blob storage, etc.). +- **Not a stream log.** Streamed chunks are relayed to live observers in real + time but are not recorded. If a consumer connects after streaming ends, + the chunks are gone. +- **Not application-level persistence.** The framework manages *task lifecycle* + state (status, input, metadata, lease). Your application data — conversation + history, invocation results, user-facing state — is your responsibility. + See [Persistence](#persistence). +- **Not unbounded storage.** `ctx.metadata` is for small progress signals + (current phase, retry count, step index), not for accumulating large data. + The task payload has a 1 MB cap. Write large or growing data to your own store. 
+ +--- + +## Getting Started + +A minimal durable task in 15 lines: + +```python +from azure.ai.agentserver.core.durable import durable_task, TaskContext + +@durable_task +async def greet(ctx: TaskContext[str]) -> str: + """A simple durable task that greets the user.""" + name = ctx.input + return f"Hello, {name}!" + +# Run it — lifecycle-aware: creates if new, recovers if stale +result = await greet.run(task_id="greet-alice", input="Alice") +print(result.output) # "Hello, Alice!" +``` + +That's it. The decorator transforms your function into a `DurableTask` with `.run()`, +`.start()`, `.get()`, and `.list()` methods. The function itself takes a single `TaskContext` +parameter. + +If the container crashes mid-execution, the framework automatically recovers the +task on restart — before any HTTP handlers go live. Your function is re-invoked +with `ctx.entry_mode = "recovered"` and the same input. No caller action is needed. + +If a caller calls `.run()` with a `task_id` that is already in progress, +the framework raises `TaskConflictError` — it does not create a duplicate. + +--- + +## Lifecycle Automation + +Every call to `.run()` or `.start()` follows the same state machine. You never +manually check task state or call resume — the framework does it for you. + +### State Diagram + +What the framework does when you call `.run()` or `.start()`: + +``` + .run() / .start() + │ + ▼ + ┌───── task exists? ─────┐ + │ │ + No Yes + │ │ + ▼ ▼ + ┌──────────┐ ┌──── status? ──────────────────────────┐ + │ Create │ │ │ │ │ + │ & Start │ pending suspended in_progress completed + └──────────┘ │ │ │ │ + │ ▼ ▼ ▼ ▼ + fresh fresh resumed stale? ephemeral? + │ │ + ┌─────┴─────┐ ┌───┴───┐ + Yes No Yes No + │ │ │ │ + ▼ ▼ ▼ ▼ + recovered steerable? 
fresh¹ TaskConflict + │ Error + ┌─────┴─────┐ + Yes No + │ │ + ▼ ▼ + Queue input TaskConflict + + cancel → Error + drain resumes + ("resumed", + was_steered) +``` + +¹ Ephemeral completed tasks were auto-deleted on completion, so they appear as +"no task exists" and a fresh task is created transparently. + +### Entry Mode Decision Table + +| Current State | Action | `ctx.entry_mode` | `ctx.was_steered` | +|---|---|---|---| +| No task exists | Create and start | `"fresh"` | `False` | +| `pending` | Start | `"fresh"` | `False` | +| `suspended` | Resume with new input | `"resumed"` | `False` | +| `in_progress` (stale) | Recover | `"recovered"` | `True` if steering state exists ¹ | +| `in_progress` (not stale, **steerable**) | Queue input, signal cancel → drain resumes | `"resumed"` | `True` | +| `in_progress` (not stale, not steerable) | **Raises `TaskConflictError`** | — | — | +| `completed` (ephemeral) | Task was auto-deleted → create fresh | `"fresh"` | `False` | +| `completed` (non-ephemeral) | **Raises `TaskConflictError`** | — | — | + +¹ When recovering a steerable task that crashed mid-drain, the initial recovery +enters with `"recovered"` and `was_steered=True`. The framework then drains +the pending queue, re-entering the function with `entry_mode="resumed"` and +`was_steered=True` for each queued input. See [Steering Recovery](#steering-recovery). + +A task is considered **stale** when its last update is older than `stale_timeout` +(default: 300 seconds). This means the previous execution likely crashed. + +### .run() vs .start() vs .get() vs .list() vs .get_active_run() + +| Method | Blocks? 
| Returns | Use When | +|--------|---------|---------|----------| +| `.run()` | Yes — awaits completion | `TaskResult[Output]` | You want the result inline | +| `.start()` | No — returns immediately | `TaskRun[Output]` | You want a handle for polling/streaming | +| `.get()` | No — reads from store | `TaskInfo \| None` | You want to query task state without executing | +| `.list()` | No — reads from store | `list[TaskInfo]` | You want all tasks for this function | +| `.get_active_run()` | No — in-memory lookup | `TaskRun[Output] \| None` | You want a stream handle for a task already running in this process | + +`.run()` and `.start()` follow the same lifecycle rules. The only difference is +whether you wait for the result or get a handle back. + +```python +# .start() returns immediately with a handle +task_run = await greet.start(task_id="greet-bob", input="Bob") + +# Use the handle to await the result later +result = await task_run.result() + +# Or stream incremental output (if the task uses ctx.stream()) +async for chunk in task_run: + print(chunk) +``` + +`.get()` does not execute the task. It reads whatever is persisted: + +```python +info = await greet.get("greet-bob") +if info is not None: + print(info.status) # "completed", "suspended", "in_progress", etc. + print(info.payload) # Contains input, metadata, output buckets +``` + +`.list()` returns all tasks created by this decorated function. It is automatically +scoped — each function only sees its own tasks: + +```python +# List all suspended tasks for this function +suspended = await greet.list(status="suspended") +for t in suspended: + print(t.id, t.status, t.created_at) + +# List all tasks (any status) +all_tasks = await greet.list() +``` + +> `.list()` is automatically scoped — each decorated function only sees tasks it +> created. The `name` option on `@durable_task` is the key that determines which +> tasks belong to this function. 
+ +`.get_active_run()` is an in-memory lookup — it returns a `TaskRun` handle for +a task that is currently executing in this process. Unlike `.get()`, it does not +read from the store and only works for tasks active in the current process: + +```python +run = greet.get_active_run("greet-bob") +if run is not None: + async for chunk in run: + print(chunk, end="") +``` + +--- + +## TaskContext + +Every durable task function receives exactly one parameter: a `TaskContext[Input]` +where `Input` is your typed input type. + +### Properties Reference + +| Property | Type | Description | +|----------|------|-------------| +| `ctx.input` | `Input` | The typed input value passed to `.run()` / `.start()` | +| `ctx.entry_mode` | `EntryMode` | Why the function was entered: `"fresh"`, `"resumed"`, or `"recovered"` | +| `ctx.task_id` | `str` | The task's unique identifier | +| `ctx.session_id` | `str` | Session scope identifier | +| `ctx.metadata` | `TaskMetadata` | Mutable progress metadata (persisted automatically) | +| `ctx.agent_name` | `str` | Agent name from platform configuration | +| `ctx.lease_generation` | `int` | Lease generation counter (increments on recovery) | +| `ctx.cancel` | `asyncio.Event` | Set when cancellation is requested (including steering cancel) | +| `ctx.shutdown` | `asyncio.Event` | Set when the container is shutting down | +| `ctx.run_attempt` | `int` | Framework retry attempt counter (0-indexed) | +| `ctx.title` | `str` | Human-readable task title | +| `ctx.tags` | `dict[str, str]` | Merged decorator + call-site tags | +| `ctx.description` | `str \| None` | Task description (from decorator or call-site) | +| `ctx.generation` | `int` | Steering generation counter (0 for first run, increments on each steer) | +| `ctx.previous_input` | `Input \| None` | The superseded generation's input (set when steering state is present) | +| `ctx.pending_inputs` | `Sequence[Any]` | Read-only snapshot of queued steering inputs at function entry | +| `ctx.was_steered` 
| `bool` | `True` when this entry involves steering — the function is being re-entered with a new input from the steering queue. Always check this to detect steering; `entry_mode` will be `"resumed"` for normal steering drains or `"recovered"` for crash recovery of a mid-drain | + +### Branching on Entry Mode + +Use `ctx.entry_mode` to handle different execution scenarios: + +```python +from azure.ai.agentserver.core.durable import durable_task, TaskContext, EntryMode + +@durable_task(name="process_order") +async def process_order(ctx: TaskContext[dict]) -> dict: + order = ctx.input + + if ctx.entry_mode == "fresh": + # First time — validate and begin processing + ctx.metadata["step"] = "validating" + + elif ctx.entry_mode == "recovered": + # Crashed mid-execution — check what was already done + step = ctx.metadata.get("step", "validating") + if step == "charged": + # Payment already taken — skip to fulfillment + return await fulfill(order) + + elif ctx.entry_mode == "resumed": + # Resumed after suspension — ctx.input has new data + # For steerable tasks, check ctx.was_steered for steering context + if ctx.was_steered: + # This resume was triggered by steering — ctx.previous_input + # has the superseded generation's input + pass + + # ... do work ... + ctx.metadata["step"] = "charged" + return {"status": "completed", "order_id": order["id"]} +``` + +**`TaskMetadata`** is automatically persisted to the task store. 
Use it to track +progress so that recovered tasks can skip completed steps: + +```python +# Dict-style access (recommended) +ctx.metadata["progress"] = 50 # set a value +ctx.metadata["phase"] = "summarizing" # set another +progress = ctx.metadata["progress"] # read (raises KeyError if missing) +if "phase" in ctx.metadata: # containment check + print(f"Phase: {ctx.metadata['phase']}") +for key in ctx.metadata: # iterate keys + print(f"{key}: {ctx.metadata[key]}") + +# Convenience methods for special operations +ctx.metadata.increment("items_processed") # atomic increment +ctx.metadata.append("logs", "step 3 done") # append to list +progress = ctx.metadata.get("progress") # read with default (no KeyError) +snapshot = ctx.metadata.to_dict() # full snapshot copy +``` + +All mutations (including `[]` assignment and `del`) are automatically tracked +and flushed to the task store on a 5-second debounce interval. + +--- + +## Suspend & Resume + +Use `ctx.suspend()` to pause execution and release the task lease. The task +transitions to `suspended` status. A subsequent `.run()` or `.start()` call +resumes it with `entry_mode="resumed"` and new input. + +> **Critical**: Always use `return await ctx.suspend(...)`. Forgetting `return` +> or `await` silently breaks the suspension mechanism. + +```python +@durable_task(name="approval_flow") +async def approval_flow(ctx: TaskContext[dict]) -> dict: + request = ctx.input + + if ctx.entry_mode == "fresh": + # Submit for approval, then suspend + return await ctx.suspend(output={"status": "awaiting_approval", "request": request}) + + elif ctx.entry_mode == "resumed": + # Manager responded — ctx.input has the approval decision + decision = ctx.input + if decision.get("approved"): + return {"status": "approved", "approved_by": decision["manager"]} + return {"status": "rejected", "reason": decision.get("reason")} +``` + +The `output` parameter on `ctx.suspend()` is optional. 
It provides a snapshot +that observers can read while the task is suspended (via `.get()` or the +`TaskResult`'s `.output` attribute). + +### Multi-Turn Conversations + +The suspend/resume pattern is ideal for multi-turn agents where each turn is +one user ↔ agent interaction: + +```python +@durable_task(name="chat_session") +async def chat_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + + if ctx.entry_mode == "fresh": + history = [] + else: # "resumed" or "recovered" — reload the persisted history + history = ctx.metadata.get("history", []) + + # Generate response (your LLM call, graph execution, etc.) + reply = await generate_reply(message, history) + + # Track conversation history in metadata + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history + + # Suspend — waiting for the next user message + return await ctx.suspend(output={"reply": reply}) +``` + +Each call to `.run(task_id=session_id, input={"message": "..."})` or +`.start(task_id=session_id, input={"message": "..."})` resumes the +same task with the new message. The framework handles the transition +from `suspended` to `in_progress` automatically. + +--- + +## Steering + +Steering extends the suspend/resume pattern for scenarios where a user sends +a new message while the agent is still processing the previous one. Without +steering, a `.start()` on an `in_progress` task raises `TaskConflictError` — +the caller must cancel, wait for the function to exit, and then start again. +With steering, the framework handles this automatically. + +### What Steering Solves + +Consider a chat UI. The user sends "Tell me about Python", then immediately +types "Actually, tell me about Rust" before the first reply finishes. Without +steering: + +1. The caller sees `TaskConflictError` on the second `.start()` +2. The caller must call `run.cancel()` and wait for the function to exit +3. 
Then call `.start()` again with the new input +4. Race conditions abound — what if another message arrives during step 2? + +With steering, the caller just calls `.start()` again. The framework queues +the new input, signals the running function to cancel, and re-enters the +function with the new input once the current generation exits. No manual +cancel/wait/restart dance. + +### Generation Model + +Each time the framework enters the durable function, it increments a +**generation** counter. This gives each invocation a stable identity: + +``` +Generation 0: fresh start with input A → entry_mode="fresh", was_steered=False +Generation 1: steered — input B replaced input A → entry_mode="resumed", was_steered=True +Generation 2: steered — input C (short-circuited) → entry_mode="resumed", was_steered=True +Generation 3: normal resume — user sends input D → entry_mode="resumed", was_steered=False +``` + +Generations are persisted in the task payload. Each `TaskRun` returned to a +caller is bound to a specific generation, so there is no ambiguity about which +invocation a caller is observing. + +### Enabling Steering + +Add `steerable=True` to the decorator: + +```python +@durable_task(name="chat_session", steerable=True) +async def chat_session(ctx: TaskContext[dict]) -> dict: + ... +``` + +| Decorator Option | Type | Default | Description | +|------------------|------|---------|-------------| +| `steerable` | `bool` | `False` | Enable steering support | +| `max_pending` | `int` | `10` | Maximum queued inputs. Excess raises `SteeringQueueFull` | + +When `steerable=False` (default), behavior is unchanged — `.start()` on an +`in_progress` task raises `TaskConflictError`. + +### The Three-Phase Cancel Pattern + +When a steering input arrives, the framework sets `ctx.cancel` on the running +function. But cancel can arrive at three different points. 
Your function must +handle all three: + +```python +@durable_task(name="agent_session", steerable=True) +async def agent_session(ctx: TaskContext[dict]) -> dict: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────── + # Cancel was already set before the function body runs. + # This happens in rapid-fire scenarios where multiple inputs + # queue up faster than the function can start. + if ctx.cancel.is_set(): + invocation_store.save(invocation_id, { + "status": "cancelled", "reason": "steered", + }) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Mid-stream cancel ────────────────────────────── + # Check cancel between each chunk of work. This is where most + # steering cancels land in practice. + reply = "" + async for token in call_llm_streaming(message): + reply += token + if ctx.cancel.is_set(): + break # Stop producing — save what we have + + # ── Phase 3: Post-completion cancel ───────────────────────── + # Cancel arrived after the LLM finished but before we returned. + # The reply is complete, but it will be superseded by the next + # generation. Save the result so it is not lost. + was_steered = ctx.cancel.is_set() + + result = {"reply": reply, "partial": was_steered} + if was_steered: + invocation_store.save(invocation_id, { + "status": "superseded", "output": result, + }) + return await ctx.suspend(reason="steered") + + invocation_store.save(invocation_id, { + "status": "completed", "output": result, + }) + return await ctx.suspend(reason="awaiting_user_input", output=result) +``` + +**Key rule**: Always save your work before returning, even when cancelled. +The user's message was received and should be preserved (appended to +conversation history, written to your store, etc.). Only the *reply generation* +is interrupted — not the input recording. 
+ +> **⚠️ Steerable tasks MUST suspend when steered — never return normally or +> raise.** When `ctx.cancel.is_set()` due to steering, always exit with +> `return await ctx.suspend(reason="steered")`. This keeps the task alive so +> the framework can drain the pending queue and resume with the next input. +> +> - **Normal return** → task completes → next `.start()` creates a fresh task +> → conversation continuity broken +> - **Raise exception** → task enters failure/retry path → wrong lifecycle +> - **Suspend** ✅ → task stays alive → framework resumes with next queued input + +### Steering Flow Diagram + +``` + Caller A: .start(input=A) Caller B: .start(input=B) + │ │ + ▼ │ + ┌──────────────┐ │ + │ Gen 0: fresh │ ◄── function starts │ + │ processing │ │ + │ ... │ ◄── ctx.cancel.set() ◄──────┤ input B queued + │ (checks │ │ + │ cancel) │ │ + │ break │ │ + └──────┬───────┘ │ + │ returns via suspend(reason="steered") + ▼ │ + ┌──────────────────┐ │ + │ Framework drains │ │ + │ pending queue │ │ + │ (pops input B) │ │ + └──────┬───────────┘ │ + │ │ + ▼ │ + ┌──────────────┐ │ + │ Gen 1:resumed│ ◄── function re-entered │ + │ was_steered │ ctx.previous_input = A │ + │ ctx.input = B│ │ + │ processing │ │ + │ ... 
│ │ + │ (completes) │ │ + └──────┬───────┘ │ + │ returns via suspend() │ + ▼ │ + Caller B's TaskRun Caller A's TaskRun + resolves with result resolved earlier with + "superseded" status +``` + +### What Happens to Each Generation + +| Scenario | Status Written to Store | `TaskRun` Resolution | +|----------|------------------------|----------------------| +| Pre-entry cancel (Phase 1) | `"cancelled"` — input preserved, no reply attempted | Superseded | +| Mid-stream cancel (Phase 2) | `"superseded"` — partial reply saved | Superseded | +| Post-completion cancel (Phase 3) | `"superseded"` — full reply saved | Superseded | +| Normal completion | `"completed"` — full reply | Completed | + +Superseded `TaskRun` handles resolve when the framework drains the queue and +starts the next generation. Callers polling these handles see the result of +their specific generation. + +### Rapid-Fire Steering + +When multiple inputs arrive in quick succession: + +``` +User types: "What is Python?" → "Actually, Rust" → "No wait, Go" +``` + +The framework queues all of them. Only the last one (Go) runs to completion: + +``` +Gen 0: "What is Python?" → cancel pre-set → Phase 1 short-circuit +Gen 1: "Actually, Rust" → cancel pre-set → Phase 1 short-circuit +Gen 2: "No wait, Go" → queue empty → full execution +``` + +**Important**: Each short-circuited generation still enters the function. +This is by design — it gives the developer a chance to: + +- Record the user's message in conversation history +- Write a `"cancelled"` status to the invocation store +- Perform any other bookkeeping + +The framework does NOT silently discard queued inputs. Every input gets a +function invocation, even if that invocation immediately short-circuits. + +### Preserving Fidelity with External SDKs + +When wrapping external LLM SDKs (Claude, Copilot, LangGraph), steering adds +a layer on top of the SDK's own interruption model. 
Be aware of how each SDK +handles cancellation: + +**Streaming SDKs (Claude, OpenAI)**: These use `async for token in stream`. +Breaking out of the loop is clean — the SDK handles connection cleanup. Check +`ctx.cancel.is_set()` between chunks: + +```python +async with client.messages.stream(...) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + break # SDK cleans up the stream +``` + +**Event-based SDKs (Copilot)**: These deliver results via callbacks. Use +`session.abort()` to stop event delivery, then let the handler drain: + +```python +session.on(handler) # Register callback +session.send(message) # Start generation (non-blocking) +# Wait for either completion or cancel: +done, _ = await asyncio.wait( + [asyncio.ensure_future(completion_event.wait()), asyncio.ensure_future(cancel_wait())], + return_when=asyncio.FIRST_COMPLETED, +) +if ctx.cancel.is_set(): + session.abort() # Stop further events +``` + +**Graph SDKs (LangGraph)**: These run a graph to completion. Use checkpoint +IDs to fork from a known state rather than replaying the full graph: + +```python +if ctx.was_steered and ctx.previous_input: + # Fork from the checkpoint before the superseded run + checkpoint_id = ctx.metadata.get("stable_checkpoint_id") + config = {"configurable": {"thread_id": ..., "checkpoint_id": checkpoint_id}} +``` + +### Steering Recovery + +If the container crashes while a steered task is processing: + +1. The task is `in_progress` with steering state in the payload +2. On container restart, the framework detects the stale task +3. If there are pending inputs in the queue, the framework recovers with + `entry_mode="recovered"` and `was_steered=True`, then drains the queue +4. If a drain was in progress when the crash occurred, the framework + resumes the drain from the persisted `active_input` +5. 
Each drained input re-enters the function with `entry_mode="resumed"` + and `was_steered=True` + +No data is lost — the pending queue and generation counter are persisted in +the task payload. + +> **How to detect steering**: Always use `ctx.was_steered` — never check +> `entry_mode` for steering. Steering re-entries arrive as `"resumed"` +> (because the task suspended and is being resumed with a new input). The +> `was_steered` flag tells you whether steering context (`previous_input`, +> `generation`, `pending_inputs`) is meaningful. + +### Complete Steering Example + +A full steerable chat session combining all patterns: + +```python +from azure.ai.agentserver.core.durable import TaskContext, durable_task +from my_app.store import FileStore + +invocation_store = FileStore("./invocations") +conversation_store = FileStore("./conversations") + + +@durable_task(name="chat_session", steerable=True) +async def chat_session(ctx: TaskContext[dict]) -> dict: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + + # Mark invocation as running (inside the durable boundary) + invocation_store.save(invocation_id, {"status": "running"}) + + # Load conversation history from external store (not task metadata) + history = conversation_store.load(session_id) or [] + history.append({"role": "user", "content": message}) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────── + if ctx.cancel.is_set(): + conversation_store.save(session_id, history) + invocation_store.save(invocation_id, { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + }) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Stream response, checking cancel ─────────────── + reply = "" + was_aborted = False + async for token in call_llm_streaming(message, history): + reply += token + if ctx.cancel.is_set(): + was_aborted = True + break + + # ── Phase 3: Save result ──────────────────────────────────── + 
if reply: + history.append({"role": "assistant", "content": reply}) + conversation_store.save(session_id, history) + + output = {"reply": reply, "partial": was_aborted} + + if was_aborted or ctx.cancel.is_set(): + invocation_store.save(invocation_id, { + "status": "superseded", "output": output, + }) + return await ctx.suspend(reason="steered") + + # Normal completion — suspend awaiting next user message + invocation_store.save(invocation_id, { + "status": "completed", "output": output, + }) + return await ctx.suspend(reason="awaiting_user_input", output=output) +``` + +The HTTP layer remains unchanged — callers call `POST /invoke` and poll +`GET /invocations/{id}`. The steering happens transparently inside the +durable task boundary. + +--- + +## Streaming + +Use `ctx.stream()` to emit incremental output and `async for` on the `TaskRun` +handle to consume it: + +```python +@durable_task(name="generate_report") +async def generate_report(ctx: TaskContext[str]) -> str: + topic = ctx.input + chunks = [] + async for token in call_llm_streaming(topic): + await ctx.stream(token) # Emit to observers + chunks.append(token) + return "".join(chunks) + +# Consumer side +task_run = await generate_report.start(task_id="report-1", input="Q3 Results") +async for chunk in task_run: + print(chunk, end="") + +# After streaming completes, get the full result +final = await task_run.result() +``` + +`ctx.stream()` accepts any Python object — the framework simply passes it +through the stream handler with no serialization or transformation. + +> **Important**: The default `QueueStreamHandler` holds items in an in-memory +> `asyncio.Queue`. They are **not persisted** and are **lost on crash**. If the +> process restarts mid-stream, the recovered task starts from scratch. If you +> need durable incremental output, implement a custom `StreamHandler` or write +> to your own store inside the task function alongside `ctx.stream()`. 
+ +### Custom Stream Handlers + +The streaming path is pluggable via the `StreamHandler` protocol. Implement +`put()`, `get()`, and `close()` to control how stream items are buffered, +transported, or persisted: + +```python +from azure.ai.agentserver.core.durable import StreamHandler + +class RedisStreamHandler: + """Example: fan-out streams via Redis.""" + + def __init__(self, redis_client, channel: str): + self._redis = redis_client + self._channel = channel + + async def put(self, item): + await self._redis.publish(self._channel, serialize(item)) + + async def get(self): + msg = await self._redis.subscribe_next(self._channel) + if msg is None: + raise StopAsyncIteration + return deserialize(msg) + + async def close(self): + await self._redis.publish(self._channel, "__CLOSED__") +``` + +Pass the handler at the call site — no decorator changes needed: + +```python +handler = RedisStreamHandler(redis, channel="report-1") +task_run = await generate_report.start( + task_id="report-1", + input="Q3 Results", + stream_handler=handler, +) +async for chunk in task_run: + print(chunk, end="") +``` + +#### stream_handler_factory (recommended for production) + +A call-site `stream_handler=` only covers the call that passed it. If the task +crashes and is recovered (or resumes from suspension), there is no caller — the +framework recovers automatically. Without a factory, recovery falls back to the +default `QueueStreamHandler`, silently losing your custom transport. + +Set `stream_handler_factory` on the decorator to ensure recovery and resume +paths always construct the correct handler: + +```python +@durable_task( + name="generate_report", + stream_handler_factory=lambda task_id: RedisStreamHandler( + redis, channel=f"stream:{task_id}" + ), +) +async def generate_report(ctx: TaskContext[str]) -> str: + ... +``` + +The factory receives the `task_id` so it can create a handler scoped to the +task. Resolution order: + +1. 
**Call-site `stream_handler=`** — highest priority (one-off override) +2. **Decorator `stream_handler_factory`** — used for fresh starts, recovery, + and resume when no call-site handler is provided +3. **Default `QueueStreamHandler`** — when neither is set + +**Key rules:** + +- `get()` must raise `StopAsyncIteration` after `close()` is called and all + buffered items are drained. This is Python's native iterator exhaustion signal. +- `close()` is always called by the framework when the task finishes — whether + it succeeds, fails, or is cancelled. +- If no `stream_handler` is provided, the framework uses `QueueStreamHandler` + (in-memory `asyncio.Queue`) as the default. +- The handler instance survives steering restarts — items streamed before and + after a steering cycle flow through the same handler. + +### Late-Join Consumers + +Any code in the same process can get a `TaskRun` handle for an active task +using `get_active_run()`, even if it wasn't the original caller of `start()`: + +```python +# In another coroutine or request handler: +run = generate_report.get_active_run("report-1") +if run is not None: + async for chunk in run: + print(chunk, end="") + result = await run.result() +``` + +Returns `None` if the task is not currently active in this process. + +--- + +## Persistence + +Understanding what is and isn't persisted is the most important concept in this +guide. 
+ +### Responsibility Matrix + +| Data | Who Persists | Where | +|------|-------------|-------| +| Task status — `TaskStatus`: `"pending"`, `"in_progress"`, `"suspended"`, `"completed"` | **Framework** | Task store | +| Task input (the value passed to `.run()`/`.start()`) | **Framework** | Task store payload | +| Task metadata (`ctx.metadata`) | **Framework** | Task store payload | +| Task output (return value) | **Framework** | Task store payload | +| Task error (on failure) | **Framework** | Task store | +| Invocation results (what your API returns to callers) | **You** | Your store | +| Conversation history / checkpoints | **You** | Your store | +| Streaming items | **Nobody** (default) | In-memory; pluggable via `StreamHandler` | + +The task store powers lifecycle and recovery. **It is NOT your application +database.** You read from it via `.get()` to inspect task state, but you should +not depend on it as the persistence layer for your API responses. + +### The Durable Boundary Rule + +> **Everything that must survive a crash must happen inside the durable task function.** + +The durable task function is the crash-recovery boundary. If the process dies, +the framework automatically re-invokes your function on container restart. +Additionally, a subsequent `.run()` / `.start()` call with the same `task_id` +will detect the stale task and recover it. Any work done *outside* the function +(e.g., in an HTTP handler, in an `asyncio.create_task` callback) is lost. + +--- + +## The Invocation Store Pattern + +When building an HTTP API that fronts durable tasks (the 202 + poll pattern), +you need to persist invocation results so that clients can retrieve them. The +correct pattern: write results **inside** the durable task function. 
+ +```python +# Your persistence layer (file store, Redis, database — your choice) +invocation_store = FileStore("./invocations") + +@durable_task(name="agent_session") +async def agent_session(ctx: TaskContext[dict]) -> dict: + invocation_id = ctx.input["invocation_id"] + message = ctx.input["message"] + + # Mark invocation as running (inside the durable boundary) + invocation_store.save(invocation_id, {"status": "running"}) + + # Do work + reply = await generate_reply(message) + result = {"status": "completed", "reply": reply} + + # Persist result (inside the durable boundary) + invocation_store.save(invocation_id, result) + + # Suspend — waiting for next turn + return await ctx.suspend(output=result) +``` + +The HTTP layer is minimal: + +```python +# POST /invoke — start or resume the task +async def invoke(request): + invocation_id = generate_id() + try: + await agent_session.start( + task_id=session_id, + input={"invocation_id": invocation_id, "message": message}, + ) + except TaskConflictError: + return JSONResponse({"error": "Task already running"}, status_code=409) + return JSONResponse({"invocation_id": invocation_id}, status_code=202) + +# GET /invocations/{id} — read from YOUR store, not the task store +async def get_invocation(request): + result = invocation_store.load(invocation_id) + if result is None: + return JSONResponse({"error": "Not found"}, status_code=404) + return JSONResponse(result) +``` + +Why this works: if the process crashes after `invocation_store.save(..., "running")` +but before the result write, the framework recovers the task, re-enters the function +with `entry_mode="recovered"`, and the result eventually gets written. The client +polls `GET /invocations/{id}` until it sees `"completed"`. + +--- + +## RetryPolicy + +Configure automatic retries on failure. 
Four factory presets cover most use cases: + +```python +from datetime import timedelta +from azure.ai.agentserver.core.durable import durable_task, RetryPolicy, TaskContext + +# Exponential backoff (default: 1s → 2s → 4s, 3 attempts) +@durable_task(name="call_api", retry=RetryPolicy.exponential_backoff()) +async def call_api(ctx: TaskContext[str]) -> dict: ... + +# Fixed delay (5s between each retry, 3 attempts) +@durable_task(name="poll_status", retry=RetryPolicy.fixed_delay(delay=timedelta(seconds=5))) +async def poll_status(ctx: TaskContext[str]) -> dict: ... + +# Linear backoff (1s → 2s → 3s → 4s → 5s, 5 attempts) +@durable_task(name="batch_job", retry=RetryPolicy.linear_backoff(max_attempts=5)) +async def batch_job(ctx: TaskContext[str]) -> dict: ... + +# No retry — fail immediately +@durable_task(name="one_shot", retry=RetryPolicy.no_retry()) +async def one_shot(ctx: TaskContext[str]) -> dict: ... +``` + +Customize any preset: + +```python +RetryPolicy.exponential_backoff( + max_attempts=5, + initial_delay=timedelta(seconds=2), + max_delay=timedelta(seconds=120), + jitter=True, # ±25% randomization (default) +) +``` + +Retry can also be set per-call, overriding the decorator: + +```python +result = await call_api.run( + task_id="api-1", + input="https://example.com", + retry=RetryPolicy.fixed_delay(max_attempts=10), +) +``` + +--- + +## Decorator Options + +The `@durable_task` decorator accepts these options (defined in `DurableTaskOptions`): + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `name` | `str` | Function `__qualname__` | **Stable task identity anchor.** Used for crash recovery routing and source stamping. If you ever rename your function, existing in-flight tasks are still recovered correctly because the framework matches on this name, not the Python function name. **Always provide an explicit name for production tasks.** | +| `retry` | `RetryPolicy \| None` | `None` | Retry policy on failure. 
See [RetryPolicy](#retrypolicy). | +| `ephemeral` | `bool` | `True` | Auto-delete task record on completion. | +| `tags` | `dict[str, str] \| Callable[[Any, str], dict[str, str]] \| None` | `{}` | Default tags (static or callable factory receiving `(input, task_id)`). | +| `title` | `str \| Callable[[Any, str], str] \| None` | `None` | Human-readable title or title factory. Defaults to `"{name}:{task_id[:8]}"` when not provided. | +| `description` | `str \| Callable[[Any, str], str \| None] \| None` | `None` | Task description (static or callable factory receiving `(input, task_id)`). | +| `store_input` | `bool` | `True` | Whether to persist input on the task record. | +| `timeout` | `timedelta \| None` | `None` | Execution timeout. When elapsed, `ctx.cancel` is set cooperatively. If the function does not exit, the lease eventually expires and the task is recovered. | +| `steerable` | `bool` | `False` | Enable steering. When True, `.start()` on an `in_progress` task queues the input instead of raising `TaskConflictError`. See [Steering](#steering). | +| `max_pending` | `int` | `10` | Maximum queued steering inputs. Excess raises `SteeringQueueFull`. | +| `stream_handler_factory` | `Callable[[str], StreamHandler] \| None` | `None` | Factory that receives `task_id` and returns a `StreamHandler`. Ensures crash-recovery and resume paths use the correct handler instead of defaulting to `QueueStreamHandler`. See [Custom Stream Handlers](#custom-stream-handlers). | + +> **Source provenance** is auto-stamped by the framework on every task. It records +> which function created the task and the SDK version. Source is not user-overridable; +> use `tags` for custom metadata. +> +> **Reserved tags**: The framework stamps internal tags (prefixed with `_durable_task_`) +> on every task for scoping and recovery. Any tags you provide with this prefix are +> silently stripped. Use unprefixed tag keys for your own metadata. 
+ +```python +@durable_task( + name="analyze_document", + ephemeral=False, # Keep task record after completion + tags={"team": "platform", "model": "gpt-4o"}, + title="Document Analysis", +) +async def analyze_document(ctx: TaskContext[dict]) -> dict: ... +``` + +**`ephemeral`** controls what happens when a completed task's `.start()` / `.run()` +is called again: +- `ephemeral=True` (default): The completed task was auto-deleted, so a fresh task + is created. +- `ephemeral=False`: The completed task still exists, so `TaskConflictError` is raised. + +Use the `.options()` method for per-call overrides without modifying the decorator: + +```python +# Override tags for this specific call +result = await analyze_document.options( + tags={"model": "gpt-4o-mini"}, +).run(task_id="doc-1", input={"url": "..."}) +``` + +### Callable Factories (`tags`, `title`, `description`) + +When `tags`, `title`, or `description` is a callable, it receives `(input, task_id)` and +is invoked at **task creation time** — before the task function runs. + +```python +@durable_task( + tags=lambda input, task_id: {"user": input["user_id"], "run": task_id[:8]}, + description=lambda input, task_id: f"Processing {input['filename']}", +) +async def process_file(ctx: TaskContext[dict]) -> str: ... +``` + +**Error behaviour:** + +- If a factory **raises an exception**, it propagates directly to the `.run()` / `.start()` + caller. The task is never created. +- If a factory **returns the wrong type** (e.g., `tags` callable returns a list instead of + a dict), a `TypeError` is raised immediately at task creation time. +- Mixing a callable `tags` on the decorator with a static dict via `.options(tags={...})` + raises `TypeError` — use one style consistently. 
+ +--- + +## Error Handling + +| Exception | Raised By | When | +|-----------|-----------|------| +| `TaskConflictError` | `.run()`, `.start()` | Task is `in_progress` (non-stale, non-steerable) or `completed` (non-ephemeral) | +| `TaskFailed` | `.run()`, `task_run.result()` | Unhandled exception in the task function | +| `TaskCancelled` | `.run()`, `task_run.result()` | Task was cancelled via `task_run.cancel()` | +| `TaskTerminated` | `.run()`, `task_run.result()` | Task was forcefully terminated (timeout or `task_run.terminate()`) | +| `TaskNotFound` | `task_run.refresh()`, `task_run.delete()` | Task record does not exist in the store | +| `SteeringQueueFull` | `.start()` | Steering queue has `max_pending` items. Caller should retry or back off | + +> **Note**: Suspension is no longer an exception. When a task suspends, `.run()` and +> `task_run.result()` return a `TaskResult` with `is_suspended == True`. Check +> `result.is_suspended` or `result.is_completed` to distinguish outcomes. 
+Inspect the returned `TaskResult` in your application code: + +```python +from azure.ai.agentserver.core.durable import ( + TaskConflictError, + TaskFailed, + TaskTerminated, +) + +result = await my_task.run(task_id="t1", input="hello") + +if result.is_suspended: + # Task paused — result.output has the snapshot + print(f"Suspended: {result.output}") +elif result.is_completed: + print(f"Done: {result.output}") +``` + +Exceptions are raised for true error conditions: + +```python +try: + result = await my_task.run(task_id="t1", input="hello") +except TaskConflictError: + # Task already running or completed + info = await my_task.get("t1") + print(f"Task is {info.status}") +except TaskFailed as exc: + # Task function raised an exception + print(f"Failed: {exc.error}") +except TaskTerminated: + # Task was forcefully terminated (timeout or explicit terminate) + print("Task was terminated") +``` + +`TaskSuspended` is retained for backward compatibility but is no longer raised +by `.run()` or `task_run.result()`. Suspension is now a return value — check +`result.is_suspended` on the returned `TaskResult`. + +--- + +## Cancellation, Timeout, and Termination + +Durable tasks support three levels of stopping execution: + +### Cooperative Cancellation + +Call `run.cancel()` to set `ctx.cancel`, signalling the task function to exit gracefully. The task +function must check this event and respond: + +```python +run = await my_task.start(task_id="t1", input=data) +await run.cancel() # Sets ctx.cancel — task should check and exit +``` + +Inside the task function: + +```python +@durable_task +async def my_task(ctx: TaskContext[Input]) -> Output: + for item in items: + if ctx.cancel.is_set(): + return partial_result # Exit cleanly + await process(item) + return full_result +``` + +Cooperative cancel sets `ctx.cancel`. If the function checks this event and +**returns normally**, the task completes as a success — not as cancelled. The +function decides its own outcome. 
`TaskCancelled` is only raised when the +function does not handle the cancel and the asyncio task is cancelled. + +### Execution Timeout + +Set a `timeout` to automatically cancel tasks that run too long. When the +timeout elapses, `ctx.cancel` is set cooperatively — the same signal used +by `handle.cancel()` and steering. If the function does not exit, the lease +eventually expires and the task is recovered on the next heartbeat. + +```python +from datetime import timedelta + +@durable_task( + timeout=timedelta(minutes=5), +) +async def analyze(ctx: TaskContext[dict]) -> dict: + while not ctx.cancel.is_set(): + chunk = await process_next() + if chunk is None: + break + return {"status": "done"} +``` + +### Forced Termination + +`terminate()` immediately kills the task via the failure path. Unlike +cooperative cancel, terminated tasks are stored as failed and are **not** +eligible for recovery: + +```python +run = await my_task.start(task_id="t1", input=data) +await run.terminate(reason="User requested abort") + +try: + await run.result() +except TaskTerminated: + print("Task was terminated") +``` + +### Cancel vs Terminate Summary + +| Method | `ctx.cancel` set? | Hard cancel? | Outcome | Recoverable? | +|--------|-------------------|--------------|---------|--------------| +| `run.cancel()` | ✅ | ❌ | Success if function returns normally; `TaskCancelled` if unhandled | Yes (stays in_progress until function exits) | +| `run.terminate()` | ✅ | ✅ | `TaskTerminated` | No (goes to failed) | +| Timeout expired | ✅ then ✅ | After grace | `TaskTerminated` | No (goes to failed) | + +--- + +## Best Practices + +1. **Keep tasks idempotent for recovery.** When `entry_mode="recovered"`, the + function re-runs from the top. Use `ctx.metadata` to track completed steps + and skip them on re-entry. + +2. **Branch on `entry_mode`.** Always handle at least `"fresh"` and `"recovered"`. + For suspend/resume tasks, handle `"resumed"` as well. 
For steerable tasks, + check `ctx.was_steered` inside the `"resumed"` branch. + +3. **Persist results inside the durable boundary.** Any write that must survive + a crash belongs inside the task function, not in the HTTP handler or a + background `asyncio.create_task`. + +4. **Use `ephemeral=True` for one-shot tasks.** If the task doesn't need to be + queried after completion, let the framework auto-delete it. This keeps the + task store clean. + +5. **Keep task functions focused.** A task should do one logical unit of work. + Compose multiple tasks rather than building monolithic functions. + +6. **Check cancellation cooperatively.** Poll `ctx.cancel.is_set()` in long loops + and exit cleanly when set. For steerable tasks, this is what enables the + framework to drain the queue and start the next generation. + +7. **Use `ctx.metadata` for progress, not for large data.** Metadata is flushed + periodically to the task store. Keep values small and JSON-serializable. + The task payload has a 1 MB cap — write conversation history, results, and + growing data to your own store (database, blob, Redis). + +8. **Always preserve user input on cancel.** When `ctx.cancel.is_set()` in a + steerable task, save the user's message to your conversation store before + returning. The *reply* is interrupted, not the *input recording*. + +9. **Use the three-phase cancel pattern.** Check `ctx.cancel` at three points: + before the LLM call (Phase 1), between chunks (Phase 2), and after + completion (Phase 3). This covers all timing scenarios. + +10. **Store conversation history externally.** Don't put growing data in + `ctx.metadata`. Use an external store keyed by `session_id`. The task + metadata is for lightweight progress signals only. + +11. **Steerable tasks MUST suspend on cancel, not return normally or raise.** + When `ctx.cancel.is_set()` due to steering, always `return await + ctx.suspend(reason="steered")`. 
This keeps the task alive in `suspended` + state so the framework can drain the pending queue and resume with the + next input. If you return a normal value, the task completes — the next + `.start()` creates a fresh task, breaking conversation continuity. If you + raise an exception, the task enters the failure/retry path, which is also + wrong. Suspend is the only correct exit for a steered cancel. + +--- + +## Common Mistakes + +### ❌ Missing `return await` on suspend + +```python +# ❌ BAD — suspend() returns a sentinel, but it's discarded +async def my_task(ctx: TaskContext[str]) -> str: + await ctx.suspend(output="paused") + return "done" # This runs immediately — task never actually suspends + +# ✅ GOOD — return the sentinel so the framework sees it +async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="paused") +``` + +### ❌ Persisting results outside the durable boundary + +```python +# ❌ BAD — if the process crashes, the result is never written +async def invoke(request): + task_run = await my_task.start(task_id="t1", input="hello") + asyncio.create_task(save_result_when_done(task_run)) # LOST ON CRASH + return JSONResponse({"id": "inv-1"}, status_code=202) + +# ✅ GOOD — write results inside the task function itself +@durable_task(name="my_task") +async def my_task(ctx: TaskContext[dict]) -> dict: + invocation_id = ctx.input["invocation_id"] + result = await do_work() + invocation_store.save(invocation_id, result) # DURABLE + return result +``` + +### ❌ Leaking task_id to API callers + +```python +# ❌ BAD — task_id is an internal lifecycle identifier +return JSONResponse({"task_id": task_id}, status_code=202) + +# ✅ GOOD — expose your own identifier (invocation_id, session_id, etc.) 
+return JSONResponse({"invocation_id": invocation_id}, status_code=202) +``` + +### ❌ Assuming streaming survives crashes + +```python +# ❌ BAD — default QueueStreamHandler is in-memory only +@durable_task(name="stream_report") +async def stream_report(ctx: TaskContext[str]) -> str: + for chunk in generate_chunks(): + await ctx.stream(chunk) # Lost if process crashes here + return "done" + +# ✅ GOOD — also persist to your store if durability matters +@durable_task(name="stream_report") +async def stream_report(ctx: TaskContext[str]) -> str: + for chunk in generate_chunks(): + await ctx.stream(chunk) + append_to_store(ctx.task_id, chunk) # Durable fallback + return "done" + +# ✅ ALSO GOOD — use a custom StreamHandler that persists +handler = DurableStreamHandler(store, "r1") +run = await stream_report.start( + task_id="r1", input="...", stream_handler=handler, +) +``` + +### ❌ Storing conversation history in task metadata + +```python +# ❌ BAD — metadata has a 1 MB cap and is not designed for growing data +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + history = ctx.metadata.get("history", []) + history.append({"role": "user", "content": ctx.input["message"]}) + reply = await call_llm(history) + history.append({"role": "assistant", "content": reply}) + ctx.metadata["history"] = history # GROWS UNBOUNDED + +# ✅ GOOD — use an external store, reference by session_id +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + session_id = ctx.input["session_id"] + history = conversation_store.load(session_id) or [] + history.append({"role": "user", "content": ctx.input["message"]}) + reply = await call_llm(history) + history.append({"role": "assistant", "content": reply}) + conversation_store.save(session_id, history) # EXTERNAL STORE +``` + +### ❌ Discarding input on steering cancel + +```python +# ❌ BAD — user's message is lost when cancel fires +@durable_task(name="chat", 
steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") # Message never saved! + +# ✅ GOOD — always save the user's message before returning +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + history = load_history(ctx.input["session_id"]) + history.append({"role": "user", "content": ctx.input["message"]}) + if ctx.cancel.is_set(): + save_history(ctx.input["session_id"], history) # PRESERVE INPUT + return await ctx.suspend(reason="steered") +``` + +### ❌ Skipping Phase 1 cancel check + +```python +# ❌ BAD — starts an expensive LLM call even when cancel is already set +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + # Missing Phase 1 check! + reply = "" + async for token in call_llm_streaming(ctx.input["message"]): + reply += token + if ctx.cancel.is_set(): + break + ... + +# ✅ GOOD — short-circuit before the LLM call +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): # Phase 1: pre-entry + save_input_and_return(ctx) + return await ctx.suspend(reason="steered") + reply = "" + async for token in call_llm_streaming(ctx.input["message"]): + reply += token + if ctx.cancel.is_set(): # Phase 2: mid-stream + break + ... +``` + +### ❌ Using `steerable=True` without `suspend()` + +Steerable tasks **must** suspend on every exit — both on normal completion +(awaiting next user input) and on steering cancel. If the function returns +normally, the task completes and the framework has nowhere to drain the +pending queue. If it raises, the task enters the failure/retry path. 
+ +```python +# ❌ BAD — task completes, can't accept next turn or drain queue +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + reply = await call_llm(ctx.input["message"]) + return {"reply": reply} # Task completes → next .start() creates fresh task + +# ❌ BAD — raising on cancel enters the failure path +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + raise RuntimeError("Cancelled") # Wrong! Enters retry/failure path + +# ✅ GOOD — always suspend: on cancel AND on normal completion +@durable_task(name="chat", steerable=True) +async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") # Keep alive for drain + reply = await call_llm(ctx.input["message"]) + return await ctx.suspend(reason="awaiting_user_input", output={"reply": reply}) +``` diff --git a/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-overview.md b/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-overview.md new file mode 100644 index 000000000000..cabd878d9777 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/docs/durable-task-overview.md @@ -0,0 +1,366 @@ +# Durable Tasks: Crash-Resilient Long-Running Agents + +## The Problem + +Agent workloads run for minutes to hours — multi-step reasoning, tool loops, batch processing, multi-turn conversations with human-in-the-loop pauses. The sandbox hosting that work can crash, be OOM-killed, redeployed, or idle-deactivated at any time — and most failure modes are unannounced. + +``` +┌─────────────────────────────────────────────────────────┐ +│ Agent starts a 45-minute research task... │ +│ │ +│ ██████████████░░░░░░░░░░░░░░░░ 35% complete │ +│ ▲ │ +│ │ 💥 Sandbox crash / OOM / redeploy │ +│ │ +│ Result: All progress lost. User must restart. 
│ +└─────────────────────────────────────────────────────────┘ +``` + +Without a contract for *what's running and where*, the platform can't restart the right thing, and the developer can't recover their work. + +--- + +## What We're Solving + +Most agent frameworks already provide durability for state *between* turns (LangGraph checkpointers, Semantic Kernel, Temporal, etc.). What **none** of them solve is the **entrypoint**: + +- Who calls the framework when the sandbox starts back up after a crash? +- Who knows there *was* a crash? +- Who tells the platform a unit of work is still in flight so the sandbox doesn't get killed? + +**That's the gap `@durable_task` closes.** It wraps a durable boundary around the developer's agent function — a unit of work the platform can see, lease, restart, and resume — so whatever framework is underneath has somewhere to plug in. + +--- + +## The Solution: One Decorator + +```python +from azure.ai.agentserver.core.durable import durable_task + +@durable_task(name="research-agent") +async def research(ctx, query: str) -> str: + # This function survives crashes. The framework handles everything. + results = await do_research(query) + await ctx.stream(f"Found {len(results)} sources") + report = await synthesize(results) + return report +``` + +The framework handles persistence, crash recovery, streaming, and lifecycle — your code focuses purely on agent logic. 
+ +--- + +## Architecture + +### System Layers + +The system is structured in three layers, each specified independently: + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Layer 1: Developer API │ +│ @durable_task decorator + TaskContext │ +│ (What agent developers write) │ +├─────────────────────────────────────────────────────────────────────┤ +│ Layer 2: Sandbox Runtime Contract │ +│ Startup recovery, lease management, graceful shutdown │ +│ (What the SDK manages automatically) │ +├─────────────────────────────────────────────────────────────────────┤ +│ Layer 3: Foundry Task Storage Protocol │ +│ CRUDL HTTP API + lease model │ +│ (Platform-managed service, always available) │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +### How the Parts Interact + +``` +┌───────────────────────────────────────────────────────────────────────┐ +│ Hosted Agent Sandbox │ +│ │ +│ ┌──────────────┐ ┌───────────────────────┐ ┌───────────────┐ │ +│ │ Protocol │ │ Durable Task │ │ Developer's │ │ +│ │ Layer │───▶│ Runtime │───▶│ Agent │ │ +│ │ (Invocations │ │ │ │ Function │ │ +│ │ / Responses)│ │ • Lease management │ │ │ │ +│ └──────────────┘ │ • State persistence │ │ @durable_task│ │ +│ │ • Crash detection │ └───────────────┘ │ +│ │ • Stream relay │ │ +│ └───────────┬───────────┘ │ +│ │ │ +└───────────────────────────────────┼───────────────────────────────────┘ + │ HTTP + ┌────────────▼────────────────┐ + │ Foundry Task Storage API │ + │ (Platform-managed) │ + │ │ + │ • CRUDL over task records │ + │ • Lease-based ownership │ + │ • Dual-identity model │ + │ • Optimistic concurrency │ + └──────────────────────────────┘ +``` + +### The Lease Model + +The lease is the core mechanism that makes crash recovery possible. 
It answers: **"Is this task still being actively worked on?"** + +``` +┌──────────────────────────────────────────────────────────────────┐ +│ Lease Lifecycle │ +│ │ +│ Sandbox A acquires lease Lease renewed every N sec │ +│ ┌─────┐ ┌─────┐ │ +│ │START│──── acquire ────▶ HELD ◀──────│RENEW│ │ +│ └─────┘ │ └─────┘ │ +│ │ │ +│ 💥 crash (no renewal) │ +│ │ │ +│ ▼ │ +│ EXPIRED ◀── lease_duration elapsed │ +│ │ │ +│ │ Sandbox B detects expired lease │ +│ ▼ │ +│ RE-ACQUIRED by Sandbox B │ +│ (generation counter++) │ +└──────────────────────────────────────────────────────────────────┘ +``` + +While the function runs, the framework renews the lease in the background. If the sandbox dies, renewals stop. The lease expires. The next sandbox instance detects the expired lease and re-enters the function — that's crash recovery. + +--- + +## The Lifecycle: Start → Crash → Recover + +``` + Sandbox A Sandbox B (after crash) + ────────── ──────────────────────── + + ① CREATE & START ④ DETECT & RECOVER + ┌──────────────────┐ ┌──────────────────────┐ + │ POST task to store│ │ On startup: query │ + │ Acquire lease │ │ tasks with expired │ + │ Persist input │ │ leases (dual-identity│ + │ Enter function │ │ match) │ + │ entry_mode="fresh"│ │ Re-acquire lease │ + └────────┬─────────┘ │ generation++ │ + │ └──────────┬───────────┘ + ▼ ▼ + ② EXECUTE ⑤ RE-ENTER FUNCTION + ┌──────────────────┐ ┌──────────────────────┐ + │ Run agent logic │ │ entry_mode="recovered"│ + │ ctx.stream(chunk) │ │ Same input + metadata │ + │ ctx.metadata[k]=v │ │ Developer branches on │ + │ Lease renewed │◀── heartbeat │ entry_mode to skip │ + │ automatically │ │ completed steps │ + └────────┬─────────┘ └──────────┬───────────┘ + │ ▼ + ▼ ⑥ COMPLETE + ③ CRASH 💥 ┌──────────────────────┐ + ┌──────────────────┐ │ return output │ + │ Sandbox dies │ │ Release lease │ + │ Lease stops │ │ Delete task (ephemeral│ + │ renewing │ │ =True) or mark done │ + │ Task remains in │ │ Notify consumers │ + │ storage (safe) │ 
└──────────────────────┘ + └──────────────────┘ +``` + +--- + +## Four State Buckets + +The framework splits task state into four clear buckets — not one opaque payload: + +``` +┌─ Task Record (live, in Foundry Task Storage) ─────────────────────┐ +│ │ +│ INPUT immutable, set at start │ +│ typed model passed to ctx.input │ +│ │ +│ METADATA mutable progress dict │ +│ ctx.metadata["phase"] = "reasoning" │ +│ debounced flush to store │ +│ │ +│ OUTPUT written on suspend (snapshot for observers) │ +│ on success: in-process delivery (ephemeral=True) │ +│ or persisted (ephemeral=False) │ +│ │ +│ ERROR structured failure detail │ +│ {type, message, traceback} │ +│ │ +├────────────────────────────────────────────────────────────────────┤ +│ STREAM NOT on the task record │ +│ Real-time relay: ctx.stream() → async for chunk │ +│ Pass-through, not persisted │ +└────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Developer Experience + +### Before: Fragile Long-Running Agent + +```python +@app.invoke_handler +async def research(request): + query = (await request.json())["query"] + results = await search(query) # 10 minutes... 
+ report = await synthesize(results) # 💥 crash = all lost + return JSONResponse({"report": report}) +``` + +### After: Crash-Resilient with One Decorator + +```python +@durable_task(name="research-agent") +async def research(ctx, query: str) -> str: + if ctx.entry_mode == "recovered": + # Framework tells you this is a recovery + step = ctx.metadata.get("step", "search") + if step == "synthesize": + return await synthesize_from_cache(ctx.task_id) + + results = await search(query) + ctx.metadata["step"] = "synthesize" + + report = await synthesize(results) + return report +``` + +### Starting and Consuming + +```python +# Fire-and-forget — returns immediately with a handle +run = await research.start(task_id="inv-123", input="quantum computing") + +# Stream incremental output (real-time, while function is running) +async for chunk in run: + print(chunk) + +# Get final result (blocks until terminal exit) +result = await run.result() +print(result.output) +``` + +--- + +## Key Patterns + +### 1. Multi-Turn Agents (Suspend & Resume) + +The function suspends when it needs external input, then resumes with new data: + +``` + User Framework Agent Function + │ │ │ + │── "plan trip" ────▶│── start ─────────────▶│ entry_mode="fresh" + │ │ │── do research... + │ │◀── ctx.suspend() ─────│ reason="need_approval" + │◀── "here's plan" ─│ │ + │ │ (task: suspended, lease released) + │ │ + │── "approved" ─────▶│── start (same ID) ───▶│ entry_mode="resumed" + │ │ │── execute plan... + │◀── "done!" ────────│◀── return result ─────│ +``` + +### 2. Streaming Progress + +``` + Consumer Framework Agent Function + │ │ │ + │── async for chunk ────▶│ │ + │ │◀── ctx.stream("10%") ─│ + │◀── "10%" ─────────────│ │ + │ │◀── ctx.stream("50%") ─│ + │◀── "50%" ─────────────│ │ + │ │◀── ctx.stream("done")─│ + │◀── "done" ────────────│ │ + │◀── StopAsyncIteration ─│◀── return output ────│ +``` + +Streams are **real-time relay** — not persisted. Late-joining consumers pick up from the next live chunk. 
For durable replay, the developer writes chunks to their own store. + +### 3. Steering (Mid-Flight Redirect) + +For chat agents where a new user message should redirect the agent without waiting for the current response to finish: + +``` + Time ──────────────────────────────────────────────────────────────▶ + + User sends "Tell me about Python" + │ + ▼ + [gen 0: running ████████── ctx.cancel fires ──│ suspend] + ▲ │ + │ ▼ + User sends "Actually, tell me about Rust" [gen 1: running ████ ✓] + ▲ + ctx.was_steered = True + ctx.previous_input = "...Python" +``` + +The framework atomically queues the new input, signals cancel, and re-enters with the next input. No manual orchestration needed. + +### 4. Crash Recovery (Transparent) + +``` + Time ────────────────────────────────────────────────────────────────▶ + + Sandbox A: [███ running ██████ 💥 dies] + │ + lease expires (~60s) + │ + Sandbox B: [startup ────┤── recover ──── running ──── ✓ ] + │ + entry_mode = "recovered" + Same input + metadata intact + generation++ (observable via ctx) +``` + +--- + +## What the Framework Manages + +| Concern | What the Framework Does | Developer Writes | +|---------|------------------------|-----------------| +| **Crash recovery** | Detects expired leases on startup, re-acquires, re-enters function | Branch on `ctx.entry_mode` for idempotency | +| **Lease renewal** | Background heartbeat keeps the lease alive while function runs | Nothing — automatic | +| **State persistence** | Input persisted at start, metadata flushed on debounce | `ctx.metadata["key"] = value` | +| **Streaming** | Real-time relay from `ctx.stream()` to handle iterators | `await ctx.stream(chunk)` | +| **Suspend/Resume** | PATCH status, release lease, re-enter on next `start()` | `return await ctx.suspend(...)` | +| **Cancellation** | Sets `ctx.cancel` event; waits for cooperative exit | Check `ctx.cancel.is_set()` at break points | +| **Graceful shutdown** | Sets `ctx.shutdown`; waits grace period; force-expires 
lease | Optionally checkpoint on `ctx.shutdown` | +| **Retry** | Re-enters function with `ctx.run_attempt++` on retryable errors | Configure `retry=RetryPolicy.exponential_backoff(...)` | +| **Source stamping** | Auto-stamps `type`, `name`, `server_version` on every task | Nothing — used for recovery routing | +| **Task listing** | Scoped by function name via auto-stamped tags | `my_task.list(status="suspended")` | + +--- + +## What This Is NOT + +| Not This | Use This Instead | +|----------|-----------------| +| A result store (outputs beyond task lifetime) | Developer-chosen store (DB, blob, conversation storage) | +| A stream log (durable replay of every chunk) | Write chunks to your own store as you emit | +| An audit trail (history of all executions) | Logs, telemetry, developer-owned event store | +| A workflow replay engine (deterministic replay, sub-step memoization) | Temporal, Durable Functions, or the framework underneath | + +The task store is for **lifecycle** — tracking that work is in flight, surviving crashes, and handing off across sandbox instances. Application data belongs in a store the developer chooses. + +--- + +## Summary + +```python +@durable_task(name="my-agent") +async def my_agent(ctx, input: str) -> str: + # Your agent logic here. + # Crashes, OOM kills, redeployments — all handled. + # The framework leases, persists, recovers, and re-enters. + ... +``` + +One decorator. Crash-resilient long-running agents. The platform can see, lease, restart, and resume your work — no matter what happens to the sandbox. 
diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index 5e19c7a03b89..abfbe502595e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -23,11 +23,18 @@ keywords = ["azure", "azure sdk", "agent", "agentserver", "core"] dependencies = [ "starlette>=0.45.0", "hypercorn>=0.17.0", + "httpx>=0.27.0", + "aiohttp>=3.9.0,<4.0.0a0", "opentelemetry-api>=1.40.0", "opentelemetry-sdk>=1.40.0", "microsoft-opentelemetry>=1.0.0", ] +[project.optional-dependencies] +hosted = [ + "azure-identity>=1.16.0", +] + [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" @@ -60,5 +67,7 @@ verifytypes = false latestdependency = false pylint = true type_check_samples = false +# apistub crashes on Generic[Input, Output] classes (Python 3.10 inspect.getsource bug) +apistub = false [tool.uv.sources] \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py new file mode 100644 index 000000000000..ef469da9dfd2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/durable_retry.py @@ -0,0 +1,117 @@ +"""Durable task with retry policies. + +Demonstrates using ``RetryPolicy`` presets to automatically retry tasks +that fail with transient errors. + +Usage:: + + pip install azure-ai-agentserver-core + + python durable_retry.py + +.. note:: + + This sample uses a **file-based** task store for simplicity. + In production, a proper persistence store **must** be used. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import timedelta + +from azure.ai.agentserver.core import AgentServerHost +from azure.ai.agentserver.core.durable import RetryPolicy, durable_task +from azure.ai.agentserver.core.durable._context import TaskContext + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Track call count to simulate transient failures +_call_count = 0 + + +@durable_task( + name="flaky_task", + retry=RetryPolicy.exponential_backoff( + max_attempts=4, + initial_delay=timedelta(milliseconds=100), + max_delay=timedelta(seconds=2), + ), +) +async def flaky_task(ctx: TaskContext[None]) -> str: + """Simulates a task that fails twice then succeeds. + + The exponential backoff policy retries up to 4 times with + increasing delays: 0.1s → 0.2s → 0.4s (capped at 2.0s). + """ + global _call_count # noqa: PLW0603 + _call_count += 1 + attempt = ctx.run_attempt + + logger.info("Attempt %d (call count=%d)", attempt, _call_count) + + if attempt < 2: + raise ConnectionError(f"Simulated transient error on attempt {attempt}") + + return f"Success after {attempt + 1} attempts" + + +@durable_task( + name="selective_retry", + retry=RetryPolicy( + initial_delay=timedelta(milliseconds=100), + max_delay=timedelta(milliseconds=100), + backoff_coefficient=1.0, + max_attempts=3, + retry_on=(ConnectionError, TimeoutError), + jitter=False, + ), +) +async def selective_retry_task(ctx: TaskContext[None]) -> str: + """Only retries ConnectionError and TimeoutError — not ValueError.""" + attempt = ctx.run_attempt + if attempt == 0: + raise ConnectionError("transient") + return f"Recovered on attempt {attempt}" + + +async def main(): + host = AgentServerHost() + manager = host._task_manager # noqa: SLF001 + + await manager.startup() + + try: + # Run with exponential backoff + logger.info("--- Exponential backoff demo ---") + result = await flaky_task.run(input=None) + logger.info("Result: %s", 
result.output) + + # Run with selective retry + logger.info("--- Selective retry demo ---") + result2 = await selective_retry_task.run(input=None) + logger.info("Result: %s", result2.output) + + # Show available presets + logger.info("--- Available retry presets ---") + presets = { + "exponential": RetryPolicy.exponential_backoff(), + "fixed": RetryPolicy.fixed_delay(), + "linear": RetryPolicy.linear_backoff(), + "none": RetryPolicy.no_retry(), + } + for name, policy in presets.items(): + logger.info( + " %s: max_attempts=%d, initial_delay=%.1fs", + name, + policy.max_attempts, + policy.initial_delay, + ) + finally: + await manager.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt new file mode 100644 index 000000000000..3f2b4e9ee6b4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_retry/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-core diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py new file mode 100644 index 000000000000..103e006fe1fb --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/durable_source.py @@ -0,0 +1,79 @@ +"""Durable task with source field tracking. + +Demonstrates using the ``source`` parameter to attach provenance +metadata at task creation time. The source is immutable after creation +and can be used for auditing, debugging, or routing. + +Usage:: + + pip install azure-ai-agentserver-core + + python durable_source.py + +.. note:: + + This sample uses a **file-based** task store for simplicity. + In production, a proper persistence store **must** be used. 
+""" + +from __future__ import annotations + +import asyncio +import logging + +from azure.ai.agentserver.core import AgentServerHost +from azure.ai.agentserver.core.durable import durable_task +from azure.ai.agentserver.core.durable._context import TaskContext + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@durable_task( + name="process_order", + source={"system": "order-service", "version": "2.1"}, +) +async def process_order_default(ctx: TaskContext[None]) -> dict: + """Task with source set at decorator level. + + The decorator-level source is used as a default — it can be + overridden at the call site. + """ + logger.info("Processing order with task_id=%s", ctx.task_id) + return {"status": "processed", "task_id": ctx.task_id} + + +async def main(): + host = AgentServerHost() + manager = host._task_manager # noqa: SLF001 + + await manager.startup() + + try: + # 1. Use decorator-level source (default) + logger.info("--- Decorator source ---") + result1 = await process_order_default.run(input={"order_id": "ORD-001"}) + logger.info("Result: %s", result1.output) + + # 2. Override source at call site + logger.info("--- Call-site source override ---") + result2 = await process_order_default.run( + input={"order_id": "ORD-002"}, + source={"system": "batch-processor", "batch_id": "B-42"}, + ) + logger.info("Result: %s", result2.output) + + # 3. 
Task without any source (None by default) + @durable_task(name="no_source_task") + async def no_source_task(ctx: TaskContext[None]) -> str: + return "done" + + logger.info("--- No source ---") + result3 = await no_source_task.run(input=None) + logger.info("Result: %s", result3.output) + finally: + await manager.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt new file mode 100644 index 000000000000..3f2b4e9ee6b4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_source/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-core diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py new file mode 100644 index 000000000000..af90178510a1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/durable_streaming.py @@ -0,0 +1,68 @@ +"""Durable task with streaming output. + +Demonstrates using ``ctx.stream()`` to emit incremental results from a +long-running task while the consumer iterates with ``async for``. + +The stream is in-memory only — items are **not** persisted. + +Usage:: + + pip install azure-ai-agentserver-core + + python durable_streaming.py + +.. note:: + + This sample uses a **file-based** task store for simplicity. + In production, a proper persistence store **must** be used. 
+""" + +from __future__ import annotations + +import asyncio +import logging + +from azure.ai.agentserver.core import AgentServerHost +from azure.ai.agentserver.core.durable import RetryPolicy, durable_task +from azure.ai.agentserver.core.durable._context import TaskContext + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@durable_task(name="stream_numbers") +async def stream_numbers(ctx: TaskContext[None]) -> str: + """Stream numbers 0-4 with a short delay, then return a summary.""" + for i in range(5): + await ctx.stream({"value": i, "message": f"Processing item {i}"}) + await asyncio.sleep(0.1) + return f"Streamed {5} items" + + +async def main(): + host = AgentServerHost() + manager = host._task_manager # noqa: SLF001 + + # Start the manager + await manager.startup() + + try: + # Start the task (non-blocking — returns a TaskRun handle) + run = await stream_numbers.start(input=None) + + # Consume streamed items as they arrive + items = [] + async for chunk in run: + logger.info("Received: %s", chunk) + items.append(chunk) + + # After streaming ends, get the final result + result = await run.result() + logger.info("Final result: %s", result.output) + logger.info("Total items streamed: %d", len(items)) + finally: + await manager.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt new file mode 100644 index 000000000000..3f2b4e9ee6b4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/durable_streaming/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-core diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py index 27b136ce5de8..f4670c21cf8e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py +++ 
b/sdk/agentserver/azure-ai-agentserver-core/tests/conftest.py @@ -11,7 +11,10 @@ def pytest_configure(config): - config.addinivalue_line("markers", "tracing_e2e: end-to-end tracing tests requiring live Azure resources") + config.addinivalue_line( + "markers", + "tracing_e2e: end-to-end tracing tests requiring live Azure resources", + ) @pytest.fixture() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py new file mode 100644 index 000000000000..d540fd20468c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py new file mode 100644 index 000000000000..c6ba64b8b2fa --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_callable_factories.py @@ -0,0 +1,280 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for callable tag and description factories on @durable_task.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +class TestCallableTags: + """Tests for callable tag factories on @durable_task.""" + + @pytest.mark.asyncio + async def test_static_tags_preserved(self, tmp_path): + """Static dict tags still work as before.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="static_tags", tags={"env": "prod"}, ephemeral=False) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.tags["env"] == "prod" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_factory(self, tmp_path): + """Callable tags factory receives (input, 
task_id) and sets tags.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="callable_tags", + tags=lambda inp, tid: {"tenant": inp["tenant"], "tid": tid[:8]}, + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input={"tenant": "acme"}) + + task = await manager.provider.get(task_id) + assert task.tags["tenant"] == "acme" + assert task.tags["tid"] == task_id[:8] + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_merged_with_callsite(self, tmp_path): + """Per-call tags merge on top of callable-resolved tags.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="merge_tags", + tags=lambda inp, tid: {"source": "factory"}, + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None, tags={"extra": "call-site"}) + + task = await manager.provider.get(task_id) + assert task.tags["source"] == "factory" + assert task.tags["extra"] == "call-site" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_tags_error_propagates(self, tmp_path): + """If callable tags factory raises, the error propagates at creation.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_tags", + tags=lambda inp, tid: 1 / 0, # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(ZeroDivisionError): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +class TestCallableDescription: + """Tests for callable description factory on @durable_task.""" + + @pytest.mark.asyncio + async def 
test_static_description(self, tmp_path): + """Static string description is stored on the task record.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="static_desc", description="A static description", ephemeral=False + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.description == "A static description" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_callable_description_factory(self, tmp_path): + """Callable description factory receives (input, task_id).""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="callable_desc", + description=lambda inp, tid: f"Processing {inp['doc']}", + ephemeral=False, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input={"doc": "report.pdf"}) + + task = await manager.provider.get(task_id) + assert task.description == "Processing report.pdf" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_description_backward_compat(self, tmp_path): + """Without description, the task record has no description.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="no_desc", ephemeral=False) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + task_id = uuid.uuid4().hex + await my_task.run(task_id=task_id, input=None) + + task = await manager.provider.get(task_id) + assert task.description is None + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +class TestFactoryValidation: + """Tests for return type validation on callable factories.""" + + @pytest.mark.asyncio + async def test_tags_callable_bad_return_type(self, tmp_path): 
+ """Tags callable returning non-dict raises TypeError.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_tags_type", + tags=lambda inp, tid: "not-a-dict", # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="tags callable must return dict"): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_description_callable_bad_return_type(self, tmp_path): + """Description callable returning non-str raises TypeError.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="bad_desc_type", + description=lambda inp, tid: 12345, # type: ignore[return-value] + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="description callable must return str"): + await my_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + def test_options_mixing_callable_and_dict_tags_raises(self): + """Mixing callable and dict tags in options() raises TypeError.""" + + @durable_task( + name="callable_tags_task", + tags=lambda inp, tid: {"k": "v"}, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + with pytest.raises(TypeError, match="Cannot mix callable and dict"): + my_task.options(tags={"override": "val"}) + + def test_options_callable_to_callable_ok(self): + """Replacing callable tags with another callable in options() works.""" + + @durable_task( + name="callable_swap", + tags=lambda inp, tid: {"old": "factory"}, + ) + async def my_task(ctx: TaskContext[Any]) -> str: + return "done" + + updated = my_task.options( + tags=lambda inp, tid: {"new": "factory"}, + ) + assert callable(updated._opts.tags) diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py new file mode 100644 index 000000000000..82ff8f614a13 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_cancellation_timeout.py @@ -0,0 +1,215 @@ +"""Tests for cancellation and timeout features (spec 005). + +Covers: +- Execution timeout (cooperative cancel → hard cancel) +- Wait timeout (caller-side timeout on result()) +- Terminate (forced termination via TaskRun.terminate()) +""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + TaskTerminated, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +# --------------------------------------------------------------------------- +# Execution timeout tests +# --------------------------------------------------------------------------- + + +class TestExecutionTimeout: + """Verify the timeout watchdog 
cooperatively and hard-cancels tasks.""" + + @pytest.mark.asyncio + async def test_timeout_cooperative_cancel(self, tmp_path): + """Task sees ctx.cancel set when timeout fires.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + cancel_observed = asyncio.Event() + + @durable_task( + name="timeout_coop", + timeout=timedelta(seconds=0.2), + ) + async def slow_task(ctx: TaskContext[Any]) -> str: + # Wait until cooperative cancel fires + while not ctx.cancel.is_set(): + await asyncio.sleep(0.01) + cancel_observed.set() + return "cooperated" + + run = await slow_task.start(task_id=uuid.uuid4().hex, input=None) + result = await run.result() + + assert cancel_observed.is_set() + assert result.output == "cooperated" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_timeout_regression(self, tmp_path): + """Task without timeout runs normally to completion.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="no_timeout") + async def quick_task(ctx: TaskContext[Any]) -> str: + return "done" + + run = await quick_task.start(task_id=uuid.uuid4().hex, input=None) + result = await run.result() + assert result.output == "done" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Terminate tests +# --------------------------------------------------------------------------- + + +class TestTerminate: + """Verify TaskRun.terminate() forces failure.""" + + @pytest.mark.asyncio + async def test_terminate_raises_task_terminated(self, tmp_path): + """terminate() causes result() to raise TaskTerminated.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="terminatable") + async def long_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(100) + return "never" + + run = await long_task.start(task_id=uuid.uuid4().hex, input=None) + await 
asyncio.sleep(0.05) # let it start + + await run.terminate() + with pytest.raises(TaskTerminated): + await run.result() + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_terminate_sets_failure_status(self, tmp_path): + """Terminated task is stored as failed (not in_progress).""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="term_status", ephemeral=False) + async def long_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(100) + return "never" + + task_id = uuid.uuid4().hex + run = await long_task.start(task_id=task_id, input=None) + await asyncio.sleep(0.05) + + await run.terminate() + with pytest.raises(TaskTerminated): + await run.result() + + # Give manager time to persist failure + await asyncio.sleep(0.1) + + info = await manager.provider.get(task_id) + assert info is not None + # Failures are stored as "completed" with an error dict + assert info.status == "completed" + assert info.error is not None + assert info.error["type"] == "TaskTerminated" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_cancel_vs_terminate_distinction(self, tmp_path): + """Cooperative cancel (ctx.cancel) raises TaskCancelled, not TaskTerminated.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + from azure.ai.agentserver.core.durable._exceptions import TaskCancelled + + @durable_task(name="cancel_test") + async def cancellable_task(ctx: TaskContext[Any]) -> str: + # Cooperatively check cancel + while not ctx.cancel.is_set(): + await asyncio.sleep(0.01) + raise asyncio.CancelledError() + + run = await cancellable_task.start(task_id=uuid.uuid4().hex, input=None) + await asyncio.sleep(0.05) + + # Use cancel (not terminate) — cooperative + await run.cancel() + with pytest.raises(TaskCancelled): + await run.result() + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def 
test_terminate_reason_propagated(self, tmp_path): + """Terminate reason is propagated to TaskTerminated exception.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="term_reason_task") + async def slow_task(ctx: TaskContext[Any]) -> str: + await asyncio.sleep(10) + return "never" + + run = await slow_task.start(task_id=uuid.uuid4().hex, input=None) + await asyncio.sleep(0.05) + + await run.terminate(reason="user requested stop") + with pytest.raises(TaskTerminated) as exc_info: + await run.result() + assert exc_info.value.reason == "user requested stop" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py new file mode 100644 index 000000000000..76aae10f0d0a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_decorator.py @@ -0,0 +1,157 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for @durable_task decorator and DurableTask class.""" + +import asyncio + +import pytest + +from azure.ai.agentserver.core.durable import ( + DurableTask, + DurableTaskOptions, + TaskContext, + durable_task, +) + + +class TestDurableTaskDecorator: + """Tests for the @durable_task decorator.""" + + def test_bare_decorator(self) -> None: + """@durable_task with no arguments produces a DurableTask.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 42 + + assert isinstance(my_task, DurableTask) + # Name includes class/method scope when defined inside a method + assert "my_task" in my_task.name + + def test_decorator_with_name(self) -> None: + """@durable_task(name=...) 
sets a custom name.""" + + @durable_task(name="custom_name") + async def my_task(ctx: TaskContext[str]) -> int: + return 0 + + assert my_task.name == "custom_name" + + def test_decorator_with_all_options(self) -> None: + """All decorator options are forwarded to DurableTaskOptions.""" + from datetime import timedelta + + @durable_task( + name="full", + ephemeral=False, + lease_duration_seconds=120, + store_input=True, + title="My Title", + tags={"env": "test"}, + timeout=timedelta(minutes=5), + ) + async def my_task(ctx: TaskContext[dict]) -> str: + return "" + + assert my_task.name == "full" + assert my_task._opts.ephemeral is False + assert my_task._opts.lease_duration_seconds == 120 + assert my_task._opts.store_input is True + assert my_task._opts.title == "My Title" + assert my_task._opts.tags == {"env": "test"} + assert my_task._opts.timeout == timedelta(minutes=5) + + def test_rejects_sync_function(self) -> None: + """@durable_task rejects synchronous functions.""" + with pytest.raises(TypeError, match="async function"): + + @durable_task + def sync_fn(ctx: TaskContext[str]) -> int: + return 1 + + def test_rejects_non_callable(self) -> None: + """@durable_task(...) 
rejects non-callable objects.""" + with pytest.raises((TypeError, AttributeError)): + durable_task(42) # type: ignore[arg-type] + + +class TestDurableTaskOptions: + """Tests for DurableTaskOptions merge via .options().""" + + def test_options_returns_new_instance(self) -> None: + """options() returns a new DurableTask, original unchanged.""" + + @durable_task(ephemeral=True) + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(ephemeral=False) + assert updated is not my_task + assert updated._opts.ephemeral is False + assert my_task._opts.ephemeral is True + + def test_options_merges_tags(self) -> None: + """options() merges tags with existing ones.""" + + @durable_task(tags={"a": "1"}) + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(tags={"b": "2"}) + assert updated._opts.tags == {"a": "1", "b": "2"} + + def test_options_overrides_title(self) -> None: + """options() overrides title.""" + + @durable_task(title="original") + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + updated = my_task.options(title="override") + assert updated._opts.title == "override" + + def test_default_options(self) -> None: + """Default DurableTaskOptions has sensible defaults.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + opts = my_task._opts + assert opts.ephemeral is True + assert opts.lease_duration_seconds == 60 + assert opts.store_input is True # default is True + assert opts.tags == {} + assert opts.timeout is None + + +class TestTypeExtraction: + """Tests for generic type parameter extraction.""" + + def test_input_type_str(self) -> None: + """Extracts str as Input type from TaskContext[str].""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + assert my_task._input_type is str + + def test_input_type_dict(self) -> None: + """Extracts dict as Input type.""" + + @durable_task + async def my_task(ctx: 
TaskContext[dict]) -> str: + return "" + + assert my_task._input_type is dict + + def test_output_type_int(self) -> None: + """Extracts int as Output type from return annotation.""" + + @durable_task + async def my_task(ctx: TaskContext[str]) -> int: + return 1 + + assert my_task._output_type is int diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py new file mode 100644 index 000000000000..1a888eab5c8f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_entry_mode.py @@ -0,0 +1,181 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for TaskContext.entry_mode across all lifecycle paths.""" + +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class TestEntryMode: + """Verify ctx.entry_mode is set correctly for each lifecycle path.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_fresh_start_entry_mode(self, tmp_path) -> None: + """First call to 
.run() produces entry_mode='fresh'.""" + observed_modes: list[str] = [] + + @durable_task(title="test-fresh") + async def my_task(ctx: TaskContext[str]) -> str: + observed_modes.append(ctx.entry_mode) + return "done" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="fresh-1", input="hello") + assert result.output == "done" + assert observed_modes == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_developer_resume_entry_mode(self, tmp_path) -> None: + """Calling .run() on a suspended task produces entry_mode='resumed' with new input.""" + observed: list[tuple[str, str]] = [] + + @durable_task(title="test-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append((ctx.entry_mode, ctx.input)) + return await ctx.suspend(output={"partial": True}) + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # First call — fresh start, suspends + result1 = await my_task.run(task_id="resume-1", input="turn-1") + assert result1.is_suspended + assert observed == [("fresh", "turn-1")] + + # Second call — should resume with new input + result2 = await my_task.run(task_id="resume-1", input="turn-2") + assert result2.is_suspended + assert observed[-1] == ("resumed", "turn-2") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_platform_resume_entry_mode(self, tmp_path) -> None: + """Platform-initiated resume (handle_resume) produces entry_mode='resumed'.""" + observed: list[str] = [] + + @durable_task(title="test-platform-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append(ctx.entry_mode) + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Fresh start — suspends + result = await my_task.run(task_id="platform-resume-1", input="init") + assert result.is_suspended + 
assert observed == ["fresh"] + + # Platform-initiated resume + await manager.handle_resume("platform-resume-1") + # Give the background task time to run + import asyncio + + await asyncio.sleep(0.2) + assert "resumed" in observed + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_recovered_entry_mode(self, tmp_path) -> None: + """Calling .run() on a stale in_progress task produces entry_mode='recovered'.""" + observed: list[str] = [] + + @durable_task(title="test-recover", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append(ctx.entry_mode) + return "recovered-ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + ) + + # Manually create a stale in_progress task + await manager.provider.create( + TaskCreateRequest( + id="stale-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="stale-test", + payload={"input": "old-data"}, + ) + ) + + # Backdate the updated_at to make it stale + task_file = ( + Path(str(tmp_path)) / "test-agent" / "test-session" / "stale-1.json" + ) + if task_file.exists(): + import json + + data = json.loads(task_file.read_text()) + data["updated_at"] = "2020-01-01T00:00:00+00:00" + task_file.write_text(json.dumps(data)) + + result = await my_task.run( + task_id="stale-1", + input="new-data", + stale_timeout=1.0, + ) + assert result.output == "recovered-ok" + assert observed == ["recovered"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_ignoring_entry_mode_works(self, tmp_path) -> None: + """A function that never reads entry_mode still works fine.""" + + @durable_task(title="test-ignore") + async def my_task(ctx: TaskContext[str]) -> str: + # Deliberately NOT reading ctx.entry_mode + return f"processed: {ctx.input}" + + manager, mgr_mod = await self._setup_manager(tmp_path) + 
try: + result = await my_task.run(task_id="ignore-1", input="data") + assert result.output == "processed: data" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py new file mode 100644 index 000000000000..8da515a20cb0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_get.py @@ -0,0 +1,140 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for DurableTask.get() method.""" + +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) + + +class TestGet: + """Verify DurableTask.get() returns TaskInfo or None.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_get_existing_task(self, tmp_path) -> None: + """get() returns TaskInfo for an existing task.""" + + @durable_task(title="get-test", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="paused") + + manager, 
mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="get-1", input="data") + assert result.is_suspended + + info = await my_task.get("get-1") + assert info is not None + assert info.id == "get-1" + assert info.status == "suspended" + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_get_nonexistent_task(self, tmp_path) -> None: + """get() returns None for a non-existent task.""" + + @durable_task(title="get-test") + async def my_task(ctx: TaskContext[str]) -> str: + return "ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + info = await my_task.get("does-not-exist") + assert info is None + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_get_returns_correct_state(self, tmp_path) -> None: + """get() returns correct info for various task states.""" + + @durable_task(title="get-states", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Create tasks in different states via the provider + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="state-suspended", + agent_name="test-agent", + session_id="test-session", + status="suspended", + title="suspended-task", + payload={"output": "half-done"}, + ) + ) + await manager.provider.create( + TaskCreateRequest( + id="state-completed", + agent_name="test-agent", + session_id="test-session", + status="completed", + title="done-task", + payload={"output": "final"}, + ) + ) + await manager.provider.create( + TaskCreateRequest( + id="state-in-progress", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running-task", + payload={}, + ) + ) + + suspended = await my_task.get("state-suspended") + assert suspended is not None + 
assert suspended.status == "suspended" + + completed = await my_task.get("state-completed") + assert completed is not None + assert completed.status == "completed" + + in_progress = await my_task.get("state-in-progress") + assert in_progress is not None + assert in_progress.status == "in_progress" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py new file mode 100644 index 000000000000..cdc3f7ced790 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_lifecycle.py @@ -0,0 +1,321 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for lifecycle-aware .run() and .start() on DurableTask.""" + +import json +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + durable_task, +) +from azure.ai.agentserver.core.durable._exceptions import TaskConflictError + + +class TestLifecycle: + """Verify .run()/.start() lifecycle automation.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + 
mgr_mod._manager = None + + def _create_stale_task(self, tmp_path, task_id, status="in_progress"): + """Write a stale task file directly to simulate a crashed task.""" + from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + ) + import asyncio + + async def _create(provider): + await provider.create( + TaskCreateRequest( + id=task_id, + agent_name="test-agent", + session_id="test-session", + status=status, + title="stale-test", + payload={"input": "old-data"}, + ) + ) + + return _create + + def _backdate_task(self, tmp_path, task_id): + """Set updated_at far in the past.""" + task_file = ( + Path(str(tmp_path)) / "test-agent" / "test-session" / f"{task_id}.json" + ) + if task_file.exists(): + data = json.loads(task_file.read_text()) + data["updated_at"] = "2020-01-01T00:00:00+00:00" + task_file.write_text(json.dumps(data)) + + @pytest.mark.asyncio + async def test_run_fresh_no_existing_task(self, tmp_path) -> None: + """run() on non-existent task → creates and starts, entry_mode='fresh'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-fresh") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "result" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result = await my_task.run(task_id="lc-fresh-1", input="data") + assert result.output == "result" + assert observed_mode == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_pending_task(self, tmp_path) -> None: + """run() on pending task → starts it, entry_mode='fresh'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-pending") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "started" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + 
TaskCreateRequest( + id="lc-pending-1", + agent_name="test-agent", + session_id="test-session", + status="pending", + title="pending-test", + payload={"input": "pending-data"}, + ) + ) + result = await my_task.run(task_id="lc-pending-1", input="new-data") + assert result.output == "started" + assert observed_mode == ["fresh"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_suspended_task(self, tmp_path) -> None: + """run() on suspended task → resumes with new input, entry_mode='resumed'.""" + observed: list[tuple[str, str]] = [] + + @durable_task(title="lifecycle-resume", ephemeral=False) + async def my_task(ctx: TaskContext[str]) -> str: + observed.append((ctx.entry_mode, ctx.input)) + return await ctx.suspend(output="waiting") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + result1 = await my_task.run(task_id="lc-resume-1", input="turn-1") + assert result1.is_suspended + assert observed[-1] == ("fresh", "turn-1") + + result2 = await my_task.run(task_id="lc-resume-1", input="turn-2") + assert result2.is_suspended + assert observed[-1] == ("resumed", "turn-2") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_in_progress_not_stale_raises(self, tmp_path) -> None: + """run() on in_progress (not stale) task → TaskConflictError.""" + + @durable_task(title="lifecycle-conflict") + async def my_task(ctx: TaskContext[str]) -> str: + return "never" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-conflict-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running-test", + payload={}, + ) + ) + with pytest.raises(TaskConflictError) as exc_info: + await my_task.run(task_id="lc-conflict-1", input="data") + assert exc_info.value.task_id == 
"lc-conflict-1" + assert exc_info.value.current_status == "in_progress" + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_stale_task_recovers(self, tmp_path) -> None: + """run() on stale in_progress task → recovers, entry_mode='recovered'.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-stale") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "recovered" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-stale-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="stale-test", + payload={"input": "old"}, + ) + ) + self._backdate_task(tmp_path, "lc-stale-1") + + result = await my_task.run( + task_id="lc-stale-1", + input="new", + stale_timeout=1.0, + ) + assert result.output == "recovered" + assert observed_mode == ["recovered"] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_run_completed_task_raises(self, tmp_path) -> None: + """run() on completed task → TaskConflictError (no restart).""" + + @durable_task(title="lifecycle-completed") + async def my_task(ctx: TaskContext[str]) -> str: + return "never" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-completed-1", + agent_name="test-agent", + session_id="test-session", + status="completed", + title="done-test", + payload={"output": "final"}, + ) + ) + with pytest.raises(TaskConflictError) as exc_info: + await my_task.run(task_id="lc-completed-1", input="data") + assert exc_info.value.current_status == "completed" + finally: + await self._teardown_manager(manager, mgr_mod) + + 
@pytest.mark.asyncio + async def test_start_follows_lifecycle_rules(self, tmp_path) -> None: + """start() follows same lifecycle rules as run() — fresh + conflict.""" + observed_mode: list[str] = [] + + @durable_task(title="lifecycle-start") + async def my_task(ctx: TaskContext[str]) -> str: + observed_mode.append(ctx.entry_mode) + return "started" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + # Fresh start via .start() + handle = await my_task.start(task_id="lc-start-1", input="data") + result = await handle.result() + assert result.output == "started" + assert observed_mode == ["fresh"] + + # Conflict: create in_progress task and try .start() + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-start-conflict", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="running", + payload={}, + ) + ) + with pytest.raises(TaskConflictError): + await my_task.start(task_id="lc-start-conflict", input="data") + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_stale_timeout_parameter(self, tmp_path) -> None: + """stale_timeout controls when in_progress is considered stale.""" + + @durable_task(title="stale-timeout") + async def my_task(ctx: TaskContext[str]) -> str: + return "ok" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + from azure.ai.agentserver.core.durable._models import TaskCreateRequest + + await manager.provider.create( + TaskCreateRequest( + id="lc-timeout-1", + agent_name="test-agent", + session_id="test-session", + status="in_progress", + title="timeout-test", + payload={"input": "old"}, + ) + ) + self._backdate_task(tmp_path, "lc-timeout-1") + + # Very large timeout → not stale → conflict + with pytest.raises(TaskConflictError): + await my_task.run( + task_id="lc-timeout-1", + input="new", + stale_timeout=999999999.0, + ) + + # Small timeout → 
stale → recover + result = await my_task.run( + task_id="lc-timeout-1", + input="new", + stale_timeout=1.0, + ) + assert result.output == "ok" + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py new file mode 100644 index 000000000000..62d66fd3e5ee --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_local_provider.py @@ -0,0 +1,182 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the LocalFileDurableTaskProvider.""" + +import json +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, +) +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskPatchRequest, +) + + +@pytest.fixture +def provider(tmp_path: Path) -> LocalFileDurableTaskProvider: + """Create a local provider backed by a temp directory.""" + return LocalFileDurableTaskProvider(base_dir=tmp_path) + + +@pytest.fixture +def sample_create_request() -> TaskCreateRequest: + """A minimal task creation request.""" + return TaskCreateRequest( + agent_name="test-agent", + session_id="session-001", + status="pending", + payload={"input": {"data": "hello"}}, + lease_owner="owner-1", + lease_instance_id="inst-1", + lease_duration_seconds=60, + ) + + +class TestLocalProviderCRUD: + """Create, read, update operations on the local provider.""" + + @pytest.mark.asyncio + async def test_create_and_get( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """create returns a TaskInfo; get retrieves it.""" + task = await provider.create(sample_create_request) + assert task.id + 
assert task.status == "pending" + assert task.agent_name == "test-agent" + + fetched = await provider.get(task.id) + assert fetched is not None + assert fetched.id == task.id + + @pytest.mark.asyncio + async def test_update_status( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update changes the status.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + status="in_progress", + if_match=task.etag, + ) + updated = await provider.update(task.id, patch) + assert updated.status == "in_progress" + + @pytest.mark.asyncio + async def test_update_payload( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update merges payload.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + payload={"output": {"result": 42}}, + if_match=task.etag, + ) + updated = await provider.update(task.id, patch) + assert updated.payload is not None + assert updated.payload["output"]["result"] == 42 + # Original input preserved + assert updated.payload["input"]["data"] == "hello" + + @pytest.mark.asyncio + async def test_etag_mismatch_raises( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: TaskCreateRequest, + ) -> None: + """update raises on ETag mismatch.""" + task = await provider.create(sample_create_request) + patch = TaskPatchRequest( + status="in_progress", + if_match="wrong-etag", + ) + with pytest.raises(ValueError, match="ETag mismatch"): + await provider.update(task.id, patch) + + @pytest.mark.asyncio + async def test_get_nonexistent_returns_none( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """get returns None for nonexistent task.""" + result = await provider.get("nonexistent-id") + assert result is None + + @pytest.mark.asyncio + async def test_delete_task( + self, + provider: LocalFileDurableTaskProvider, + sample_create_request: 
TaskCreateRequest, + ) -> None: + """delete removes a task.""" + task = await provider.create(sample_create_request) + await provider.delete(task.id) + result = await provider.get(task.id) + assert result is None + + +class TestLocalProviderListing: + """Tests for listing/querying tasks.""" + + @pytest.mark.asyncio + async def test_list_tasks_by_agent( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """list filters by agent_name and session_id.""" + req1 = TaskCreateRequest( + agent_name="agent-a", + session_id="s1", + status="pending", + payload={}, + ) + req2 = TaskCreateRequest( + agent_name="agent-b", + session_id="s1", + status="pending", + payload={}, + ) + await provider.create(req1) + await provider.create(req2) + + tasks = await provider.list(agent_name="agent-a", session_id="s1") + assert len(tasks) == 1 + assert tasks[0].agent_name == "agent-a" + + @pytest.mark.asyncio + async def test_list_tasks_by_status( + self, provider: LocalFileDurableTaskProvider + ) -> None: + """list filters by status.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="s1", + status="pending", + payload={}, + ) + task = await provider.create(req) + patch = TaskPatchRequest( + status="in_progress", + if_match=task.etag, + ) + await provider.update(task.id, patch) + + pending = await provider.list( + agent_name="agent", session_id="s1", status="pending" + ) + assert len(pending) == 0 + + active = await provider.list( + agent_name="agent", session_id="s1", status="in_progress" + ) + assert len(active) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py new file mode 100644 index 000000000000..8bafd3bc8102 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_metadata.py @@ -0,0 +1,247 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for TaskMetadata operations (set, get, increment, append, flush).""" + +import asyncio +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable._metadata import TaskMetadata + + +class TestTaskMetadataOperations: + """Tests for basic metadata operations.""" + + def test_set_and_get(self) -> None: + """set() stores a value, get() retrieves it.""" + meta = TaskMetadata() + meta.set("key", "value") + assert meta.get("key") == "value" + + def test_get_default(self) -> None: + """get() returns default when key is missing.""" + meta = TaskMetadata() + assert meta.get("missing") is None + assert meta.get("missing", 42) == 42 + + def test_set_marks_dirty(self) -> None: + """set() marks the metadata as dirty.""" + meta = TaskMetadata() + assert not meta._dirty + meta.set("key", "value") + assert meta._dirty + + def test_increment(self) -> None: + """increment() increases a counter by the given amount.""" + meta = TaskMetadata() + meta.increment("counter") + assert meta.get("counter") == 1 + meta.increment("counter", 5) + assert meta.get("counter") == 6 + + def test_increment_non_numeric_raises(self) -> None: + """increment() raises TypeError on non-numeric existing value.""" + meta = TaskMetadata() + meta.set("key", "not a number") + with pytest.raises(TypeError): + meta.increment("key") + + def test_append(self) -> None: + """append() adds items to a list.""" + meta = TaskMetadata() + meta.append("log", "entry1") + meta.append("log", "entry2") + assert meta.get("log") == ["entry1", "entry2"] + + def test_append_non_list_raises(self) -> None: + """append() raises TypeError when existing value is not a list.""" + meta = TaskMetadata() + meta.set("key", "not a list") + with pytest.raises(TypeError): + meta.append("key", "item") + + def test_snapshot_returns_copy(self) -> None: + """Snapshot returns a copy, not a reference.""" + meta = TaskMetadata() + meta.set("key", "value") + 
snap = dict(meta._data) + meta.set("key", "changed") + assert snap["key"] == "value" + assert meta.get("key") == "changed" + + +class TestTaskMetadataFlush: + """Tests for flush and auto-flush behavior.""" + + @pytest.mark.asyncio + async def test_flush_calls_callback(self) -> None: + """flush() calls the flush_callback with current data.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback) + meta.set("key", "value") + await meta.flush() + + assert len(captured) == 1 + assert captured[0]["key"] == "value" + + @pytest.mark.asyncio + async def test_flush_clears_dirty(self) -> None: + """flush() clears the dirty flag after success.""" + + async def callback(data: dict[str, Any]) -> None: + pass + + meta = TaskMetadata(flush_callback=callback) + meta.set("key", "value") + assert meta._dirty + await meta.flush() + assert not meta._dirty + + @pytest.mark.asyncio + async def test_flush_noop_when_clean(self) -> None: + """flush() is a no-op when metadata is not dirty.""" + call_count = 0 + + async def callback(data: dict[str, Any]) -> None: + nonlocal call_count + call_count += 1 + + meta = TaskMetadata(flush_callback=callback) + await meta.flush() + assert call_count == 0 + + @pytest.mark.asyncio + async def test_flush_noop_without_callback(self) -> None: + """flush() is a no-op without a callback configured.""" + meta = TaskMetadata() + meta.set("key", "value") + # Should not raise + await meta.flush() + + @pytest.mark.asyncio + async def test_stop_auto_flush_final_flush(self) -> None: + """stop_auto_flush() does a final flush before stopping.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback, flush_interval=100) + meta.start_auto_flush() + meta.set("key", "value") + await meta.stop_auto_flush() + + assert len(captured) == 1 + assert 
captured[0]["key"] == "value" + + +class TestTaskMetadataDictProtocol: + """Tests for dict-like access (MutableMapping protocol).""" + + def test_setitem_getitem(self) -> None: + """[] assignment and retrieval works.""" + meta = TaskMetadata() + meta["key"] = "value" + assert meta["key"] == "value" + + def test_getitem_missing_raises_keyerror(self) -> None: + """[] on missing key raises KeyError.""" + meta = TaskMetadata() + with pytest.raises(KeyError): + _ = meta["missing"] + + def test_setitem_marks_dirty(self) -> None: + """[] assignment marks metadata as dirty.""" + meta = TaskMetadata() + assert not meta._dirty + meta["key"] = "value" + assert meta._dirty + + def test_setitem_non_string_key_raises(self) -> None: + """[] with non-string key raises TypeError.""" + meta = TaskMetadata() + with pytest.raises(TypeError): + meta[42] = "value" # type: ignore[index] + + def test_delitem(self) -> None: + """del removes a key and marks dirty.""" + meta = TaskMetadata() + meta["key"] = "value" + meta._dirty = False + del meta["key"] + assert "key" not in meta + assert meta._dirty + + def test_delitem_missing_raises_keyerror(self) -> None: + """del on missing key raises KeyError.""" + meta = TaskMetadata() + with pytest.raises(KeyError): + del meta["missing"] + + def test_contains(self) -> None: + """'in' operator works.""" + meta = TaskMetadata() + meta["key"] = "value" + assert "key" in meta + assert "missing" not in meta + + def test_len(self) -> None: + """len() returns number of keys.""" + meta = TaskMetadata() + assert len(meta) == 0 + meta["a"] = 1 + meta["b"] = 2 + assert len(meta) == 2 + + def test_iter(self) -> None: + """Iteration yields keys.""" + meta = TaskMetadata() + meta["a"] = 1 + meta["b"] = 2 + assert sorted(meta) == ["a", "b"] + + def test_keys_values_items(self) -> None: + """keys(), values(), items() delegate to internal dict.""" + meta = TaskMetadata() + meta["x"] = 10 + meta["y"] = 20 + assert set(meta.keys()) == {"x", "y"} + assert 
set(meta.values()) == {10, 20} + assert set(meta.items()) == {("x", 10), ("y", 20)} + + def test_isinstance_mutable_mapping(self) -> None: + """TaskMetadata is registered as MutableMapping.""" + import collections.abc + + meta = TaskMetadata() + assert isinstance(meta, collections.abc.MutableMapping) + + def test_existing_methods_still_work(self) -> None: + """Existing .set(), .get(), .increment(), .append() are unchanged.""" + meta = TaskMetadata() + meta.set("counter", 0) + meta.increment("counter", 5) + assert meta.get("counter") == 5 + meta.append("log", "entry") + assert meta.get("log") == ["entry"] + assert meta.to_dict() == {"counter": 5, "log": ["entry"]} + + @pytest.mark.asyncio + async def test_setitem_triggers_auto_flush(self) -> None: + """[] assignment triggers flush via dirty-tracking.""" + captured: list[dict[str, Any]] = [] + + async def callback(data: dict[str, Any]) -> None: + captured.append(data) + + meta = TaskMetadata(flush_callback=callback) + meta["key"] = "value" + await meta.flush() + assert len(captured) == 1 + assert captured[0]["key"] == "value" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py new file mode 100644 index 000000000000..e1e3d43de37c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_models.py @@ -0,0 +1,115 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for data models and exceptions.""" + +import pytest + +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, +) +from azure.ai.agentserver.core.durable._exceptions import ( + TaskCancelled, + TaskFailed, + TaskNotFound, + TaskSuspended, +) + + +class TestTaskStatus: + """Tests for TaskStatus literal type.""" + + def test_valid_status_strings(self) -> None: + """Valid status values are plain strings.""" + statuses = ["pending", "in_progress", "suspended", "completed"] + for s in statuses: + assert isinstance(s, str) + + +class TestTaskCreateRequest: + """Tests for TaskCreateRequest.""" + + def test_minimal(self) -> None: + """Minimal request has required fields.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + status="pending", + payload={}, + ) + assert req.agent_name == "agent" + assert req.status == "pending" + + def test_default_status(self) -> None: + """Default status is 'pending'.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + ) + assert req.status == "pending" + + def test_optional_fields_default_none(self) -> None: + """Optional fields default to None.""" + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + ) + assert req.lease_owner is None + assert req.lease_instance_id is None + assert req.lease_duration_seconds is None + assert req.id is None + assert req.title is None + + +class TestTaskPatchRequest: + """Tests for TaskPatchRequest.""" + + def test_empty_patch(self) -> None: + """An empty patch is valid.""" + patch = TaskPatchRequest() + assert patch.status is None + assert patch.payload is None + assert patch.if_match is None + + def test_status_patch(self) -> None: + """Patch can set status.""" + patch = TaskPatchRequest(status="in_progress") + assert patch.status == "in_progress" + + +class TestExceptions: + """Tests for custom 
durable task exceptions.""" + + def test_task_failed_message(self) -> None: + """TaskFailed stores task_id and error.""" + exc = TaskFailed("task-1", error={"message": "boom", "type": "ValueError"}) + assert exc.task_id == "task-1" + assert "boom" in str(exc) + assert exc.error["type"] == "ValueError" + + def test_task_suspended_reason(self) -> None: + """TaskSuspended stores task_id and reason.""" + exc = TaskSuspended("task-2", reason="waiting for approval") + assert exc.task_id == "task-2" + assert "waiting for approval" in str(exc) + + def test_task_cancelled(self) -> None: + """TaskCancelled stores task_id.""" + exc = TaskCancelled("task-3") + assert exc.task_id == "task-3" + assert "task-3" in str(exc) + + def test_task_not_found(self) -> None: + """TaskNotFound stores task_id.""" + exc = TaskNotFound("task-123") + assert exc.task_id == "task-123" + assert "task-123" in str(exc) + + def test_exception_hierarchy(self) -> None: + """All exceptions inherit from Exception.""" + assert issubclass(TaskFailed, Exception) + assert issubclass(TaskSuspended, Exception) + assert issubclass(TaskCancelled, Exception) + assert issubclass(TaskNotFound, Exception) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py new file mode 100644 index 000000000000..8e48069b5f2a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_resume_route.py @@ -0,0 +1,95 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for the resume HTTP route.""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest +from starlette.testclient import TestClient +from starlette.applications import Starlette + +from azure.ai.agentserver.core.durable._resume_route import create_resume_route + + +def _build_test_app() -> Starlette: + """Create a minimal Starlette app with the resume route.""" + return Starlette(routes=[create_resume_route()]) + + +class TestResumeRoute: + """Tests for POST /tasks/resume.""" + + def test_missing_body_returns_400(self) -> None: + """Request without body returns 400.""" + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", content=b"not json") + assert resp.status_code == 400 + + def test_missing_task_id_returns_400(self) -> None: + """Request without task_id returns 400.""" + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={}) + assert resp.status_code == 400 + + def test_non_string_task_id_returns_400(self) -> None: + """Request with non-string task_id returns 400.""" + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={"task_id": 123}) + assert resp.status_code == 400 + + @patch("azure.ai.agentserver.core.durable._manager.get_task_manager") + def test_successful_resume_returns_202(self, mock_get: AsyncMock) -> None: + """Successful resume returns 202 with empty body.""" + mock_manager = AsyncMock() + mock_manager.handle_resume = AsyncMock() + mock_get.return_value = mock_manager + + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={"task_id": "task-123"}) + assert resp.status_code == 202 + assert resp.content == b"" + + 
@patch("azure.ai.agentserver.core.durable._manager.get_task_manager") + def test_not_found_returns_404(self, mock_get: AsyncMock) -> None: + """Resume of nonexistent task returns 404.""" + mock_manager = AsyncMock() + mock_manager.handle_resume = AsyncMock( + side_effect=ValueError("Task 'xyz' not found") + ) + mock_get.return_value = mock_manager + + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={"task_id": "xyz"}) + assert resp.status_code == 404 + + @patch("azure.ai.agentserver.core.durable._manager.get_task_manager") + def test_conflict_returns_409(self, mock_get: AsyncMock) -> None: + """Resume of task not in 'suspended' state returns 409.""" + mock_manager = AsyncMock() + mock_manager.handle_resume = AsyncMock( + side_effect=ValueError("Task is 'in_progress', not 'suspended'") + ) + mock_get.return_value = mock_manager + + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={"task_id": "task-123"}) + assert resp.status_code == 409 + + @patch( + "azure.ai.agentserver.core.durable._manager.get_task_manager", + side_effect=RuntimeError("No manager"), + ) + def test_no_manager_returns_503(self, mock_get: AsyncMock) -> None: + """When no manager is configured, returns 503.""" + app = _build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/tasks/resume", json={"task_id": "task-123"}) + assert resp.status_code == 503 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_retry.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_retry.py new file mode 100644 index 000000000000..92ea5a1347fd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_retry.py @@ -0,0 +1,353 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+"""Tests for RetryPolicy — construction, delay computation, presets, and integration.""" + +from __future__ import annotations + +import asyncio +from datetime import timedelta +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from azure.ai.agentserver.core.durable import ( + RetryPolicy, + TaskContext, + TaskFailed, + durable_task, +) + + +# --------------------------------------------------------------------------- +# Construction & validation +# --------------------------------------------------------------------------- + + +class TestRetryPolicyConstruction: + def test_default_construction(self) -> None: + p = RetryPolicy() + assert p.initial_delay == timedelta(seconds=1) + assert p.backoff_coefficient == 2.0 + assert p.max_delay == timedelta(seconds=60) + assert p.max_attempts == 3 + assert p.retry_on is None + assert p.jitter is True + + def test_custom_construction(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=5), + backoff_coefficient=3.0, + max_delay=timedelta(seconds=120), + max_attempts=10, + retry_on=(ValueError, ConnectionError), + jitter=False, + ) + assert p.initial_delay == timedelta(seconds=5) + assert p.backoff_coefficient == 3.0 + assert p.max_delay == timedelta(seconds=120) + assert p.max_attempts == 10 + assert p.retry_on == (ValueError, ConnectionError) + assert p.jitter is False + + def test_validation_initial_delay_negative(self) -> None: + with pytest.raises(ValueError, match="initial_delay must be >= 0"): + RetryPolicy(initial_delay=timedelta(seconds=-1)) + + def test_validation_backoff_coefficient_below_one(self) -> None: + with pytest.raises(ValueError, match="backoff_coefficient must be >= 1.0"): + RetryPolicy(backoff_coefficient=0.5) + + def test_validation_max_delay_below_initial(self) -> None: + with pytest.raises(ValueError, match="max_delay.*must be >= initial_delay"): + RetryPolicy( + initial_delay=timedelta(seconds=10), max_delay=timedelta(seconds=5) + ) + + def 
test_validation_max_attempts_zero(self) -> None: + with pytest.raises(ValueError, match="max_attempts must be >= 1"): + RetryPolicy(max_attempts=0) + + def test_validation_retry_on_non_exception(self) -> None: + with pytest.raises( + TypeError, match="retry_on entries must be Exception subclasses" + ): + RetryPolicy(retry_on=(str,)) # type: ignore[arg-type] + + def test_repr(self) -> None: + p = RetryPolicy(max_attempts=5) + r = repr(p) + assert "RetryPolicy" in r + assert "max_attempts=5" in r + + def test_eq(self) -> None: + a = RetryPolicy(max_attempts=3) + b = RetryPolicy(max_attempts=3) + c = RetryPolicy(max_attempts=5) + assert a == b + assert a != c + assert a != "not a policy" + + +# --------------------------------------------------------------------------- +# Delay computation +# --------------------------------------------------------------------------- + + +class TestComputeDelay: + def test_exponential(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=1), + backoff_coefficient=2.0, + max_delay=timedelta(seconds=120), + jitter=False, + ) + assert p.compute_delay(0) == 1.0 # 1 * 2^0 + assert p.compute_delay(1) == 2.0 # 1 * 2^1 + assert p.compute_delay(2) == 4.0 # 1 * 2^2 + assert p.compute_delay(3) == 8.0 # 1 * 2^3 + assert p.compute_delay(5) == 32.0 # 1 * 2^5 + + def test_fixed_delay(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=5), + backoff_coefficient=1.0, + max_delay=timedelta(seconds=5), + jitter=False, + ) + for attempt in range(5): + assert p.compute_delay(attempt) == 5.0 + + def test_capped_at_max(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=1), + backoff_coefficient=10.0, + max_delay=timedelta(seconds=30), + jitter=False, + ) + # 1 * 10^2 = 100, but capped at 30 + assert p.compute_delay(2) == 30.0 + + def test_jitter_bounds(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=10), + backoff_coefficient=1.0, + max_delay=timedelta(seconds=10), + jitter=True, + ) + for _ 
in range(100): + delay = p.compute_delay(0) + assert 7.5 <= delay <= 12.5 # 10 * [0.75, 1.25] + + def test_no_jitter_exact(self) -> None: + p = RetryPolicy( + initial_delay=timedelta(seconds=2), + backoff_coefficient=3.0, + max_delay=timedelta(seconds=200), + jitter=False, + ) + assert p.compute_delay(0) == 2.0 # 2 * 3^0 + assert p.compute_delay(1) == 6.0 # 2 * 3^1 + assert p.compute_delay(2) == 18.0 # 2 * 3^2 + + def test_linear_preset_delay(self) -> None: + p = RetryPolicy.linear_backoff(initial_delay=timedelta(seconds=2)) + assert p.compute_delay(0) == 2.0 # 2 * (0+1) = 2 + assert p.compute_delay(1) == 4.0 # 2 * (1+1) = 4 + assert p.compute_delay(2) == 6.0 # 2 * (2+1) = 6 + assert p.compute_delay(3) == 8.0 # 2 * (3+1) = 8 + + +# --------------------------------------------------------------------------- +# should_retry +# --------------------------------------------------------------------------- + + +class TestShouldRetry: + def test_within_attempts(self) -> None: + p = RetryPolicy(max_attempts=3, jitter=False) + assert p.should_retry(0, RuntimeError("test")) is True + assert p.should_retry(1, RuntimeError("test")) is True + + def test_exhausted(self) -> None: + p = RetryPolicy(max_attempts=3, jitter=False) + assert ( + p.should_retry(2, RuntimeError("test")) is False + ) # attempt 2 is the 3rd try + assert p.should_retry(5, RuntimeError("test")) is False + + def test_matching_exception(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(ValueError,), jitter=False) + assert p.should_retry(0, ValueError("bad")) is True + + def test_non_matching_exception(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(ValueError,), jitter=False) + assert p.should_retry(0, RuntimeError("nope")) is False + + def test_none_means_all_exceptions(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=None, jitter=False) + assert p.should_retry(0, ValueError("a")) is True + assert p.should_retry(0, ConnectionError("b")) is True + assert p.should_retry(0, 
RuntimeError("c")) is True + + def test_subclass_matching(self) -> None: + p = RetryPolicy(max_attempts=5, retry_on=(OSError,), jitter=False) + assert ( + p.should_retry(0, ConnectionError("net")) is True + ) # ConnectionError is OSError subclass + + +# --------------------------------------------------------------------------- +# Presets +# --------------------------------------------------------------------------- + + +class TestPresets: + def test_exponential_backoff(self) -> None: + p = RetryPolicy.exponential_backoff(max_attempts=5) + assert p.backoff_coefficient == 2.0 + assert p.max_attempts == 5 + assert p.jitter is True + assert p.initial_delay == timedelta(seconds=1) + + def test_fixed_delay(self) -> None: + p = RetryPolicy.fixed_delay(delay=timedelta(seconds=10), max_attempts=4) + assert p.backoff_coefficient == 1.0 + assert p.initial_delay == timedelta(seconds=10) + assert p.max_delay == timedelta(seconds=10) + assert p.max_attempts == 4 + assert p.jitter is False + + def test_linear_backoff(self) -> None: + p = RetryPolicy.linear_backoff( + initial_delay=timedelta(seconds=2), max_attempts=6 + ) + assert p.backoff_coefficient == 1.0 + assert p.initial_delay == timedelta(seconds=2) + assert p.max_attempts == 6 + assert p.jitter is False + + def test_no_retry(self) -> None: + p = RetryPolicy.no_retry() + assert p.max_attempts == 1 + assert p.jitter is False + assert p.should_retry(0, RuntimeError("x")) is False + + +# --------------------------------------------------------------------------- +# Integration tests (require manager) +# --------------------------------------------------------------------------- + + +class TestRetryIntegration: + """Integration tests that run tasks through the full DurableTaskManager.""" + + async def _setup_manager(self, tmp_path): + """Create a manager with local file provider pointing to tmp_path.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from 
azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_retry_success_after_failures(self, tmp_path) -> None: + """Task fails twice then succeeds on attempt 2.""" + call_log: list[int] = [] + + @durable_task(title="retry-test") + async def flaky(ctx: TaskContext[str]) -> str: + call_log.append(ctx.run_attempt) + if ctx.run_attempt < 2: + raise ConnectionError(f"fail attempt {ctx.run_attempt}") + return "success" + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await flaky.run( + task_id="retry-1", + input="test", + retry=RetryPolicy.exponential_backoff(max_attempts=3), + ) + assert result.output == "success" + assert call_log == [0, 1, 2] + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_retry_exhausted(self, tmp_path) -> None: + """Task always fails — retries exhaust and TaskFailed is raised.""" + + @durable_task(title="always-fail") + async def always_fail(ctx: TaskContext[str]) -> str: + raise ValueError(f"boom on attempt {ctx.run_attempt}") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(TaskFailed) as exc_info: + await always_fail.run( + task_id="exhaust-1", + input="test", + retry=RetryPolicy( + max_attempts=3, + retry_on=(ValueError,), + 
jitter=False, + ), + ) + error = exc_info.value.error + assert error["type"] == "exhausted_retries" + assert error["attempts"] == 3 + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_non_retryable_exception(self, tmp_path) -> None: + """Wrong exception type — fails immediately without retry.""" + attempts: list[int] = [] + + @durable_task(title="wrong-exc") + async def wrong_exc(ctx: TaskContext[str]) -> str: + attempts.append(ctx.run_attempt) + raise TypeError("not retryable") + + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + with pytest.raises(TaskFailed): + await wrong_exc.run( + task_id="nonretry-1", + input="test", + retry=RetryPolicy( + max_attempts=5, + retry_on=(ValueError,), + jitter=False, + ), + ) + # Only ran once — no retries for TypeError + assert attempts == [0] + finally: + await self._teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py new file mode 100644 index 000000000000..6d8aa0c5fb09 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_sample_e2e.py @@ -0,0 +1,1856 @@ +"""End-to-end tests for durable task samples. + +Each test exercises a sample's core logic to verify the sample code +would work correctly. These tests do NOT start an HTTP server — they +invoke the durable task functions directly via the SDK API. + +This follows the constitution requirement (v1.2.0): + "Every sample MUST have a corresponding e2e test." 
+""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta +from pathlib import Path +from typing import Any +from typing_extensions import TypedDict + +import pytest + +from azure.ai.agentserver.core.durable import ( + RetryPolicy, + TaskContext, + TaskConflictError, + durable_task, +) + + +class _ManagerFixture: + """Helper to set up a DurableTaskManager with local file storage.""" + + @staticmethod + async def setup(tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + @staticmethod + async def teardown(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +# --------------------------------------------------------------------------- +# Sample 1: Streaming (durable_streaming) +# --------------------------------------------------------------------------- + + +class TestStreamingSampleE2E: + """E2E for the durable_streaming sample.""" + + @pytest.mark.asyncio + async def test_streaming_sample(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_stream_numbers") + async def stream_numbers(ctx: TaskContext[Any]) -> str: + for i in range(5): + await ctx.stream({"value": i, "message": f"Processing item {i}"}) + return f"Streamed 5 items" + + run = await stream_numbers.start(task_id=uuid.uuid4().hex, input=None) + + items = [] + async for chunk in run: + items.append(chunk) 
+ + result = await run.result() + + assert len(items) == 5 + assert items[0] == {"value": 0, "message": "Processing item 0"} + assert items[4] == {"value": 4, "message": "Processing item 4"} + assert result.output == "Streamed 5 items" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 2: Retry (durable_retry) +# --------------------------------------------------------------------------- + + +class TestRetrySampleE2E: + """E2E for the durable_retry sample.""" + + @pytest.mark.asyncio + async def test_retry_with_exponential_backoff(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + call_count = 0 + + @durable_task( + name="e2e_flaky", + retry=RetryPolicy.exponential_backoff( + max_attempts=4, + initial_delay=timedelta(milliseconds=10), + max_delay=timedelta(milliseconds=100), + ), + ) + async def flaky_task(ctx: TaskContext[Any]) -> str: + nonlocal call_count + call_count += 1 + if ctx.run_attempt < 2: + raise ConnectionError(f"Attempt {ctx.run_attempt}") + return f"Success after {ctx.run_attempt + 1} attempts" + + result = await flaky_task.run(task_id=uuid.uuid4().hex, input=None) + assert result.output == "Success after 3 attempts" + assert call_count == 3 + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_selective_retry(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task( + name="e2e_selective", + retry=RetryPolicy( + initial_delay=timedelta(milliseconds=10), + max_delay=timedelta(milliseconds=10), + backoff_coefficient=1.0, + max_attempts=3, + retry_on=(ConnectionError,), + jitter=False, + ), + ) + async def selective_task(ctx: TaskContext[Any]) -> str: + if ctx.run_attempt == 0: + raise ConnectionError("transient") + return f"Recovered on attempt {ctx.run_attempt}" + + result = await 
selective_task.run(task_id=uuid.uuid4().hex, input=None) + assert result.output == "Recovered on attempt 1" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 3: Source (durable_source) +# --------------------------------------------------------------------------- + + +class TestSourceSampleE2E: + """E2E for source auto-stamping (framework-owned, not user-overridable).""" + + @pytest.mark.asyncio + async def test_source_auto_stamped(self, tmp_path): + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_with_source") + async def process_order(ctx: TaskContext[Any]) -> dict: + return {"task_id": ctx.task_id} + + result = await process_order.run( + task_id=uuid.uuid4().hex, input={"order_id": "ORD-001"} + ) + assert "task_id" in result.output + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_source_auto_stamp_fields(self, tmp_path): + """Verify auto-stamped source contains type, name, server_version.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + task_id = uuid.uuid4().hex + + @durable_task(name="e2e_source_fields") + async def with_source(ctx: TaskContext[Any]) -> str: + return "done" + + result = await with_source.run( + task_id=task_id, + input=None, + ) + assert result.output == "done" + + # Verify source was auto-stamped on the task record + task_info = await manager.provider.get(task_id) + if task_info is not None and task_info.source is not None: + assert task_info.source["type"] == "agentserver.durable_task" + assert task_info.source["name"] == "e2e_source_fields" + assert "server_version" in task_info.source + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# task.list() — scoped listing +# 
--------------------------------------------------------------------------- + + +class TestListE2E: + """E2E for ``DurableTask.list()`` — per-function scoped task listing.""" + + @pytest.mark.asyncio + async def test_list_returns_only_this_tasks_records(self, tmp_path): + """list() scoped by function name — other tasks excluded.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_list_alpha", ephemeral=False) + async def alpha(ctx: TaskContext[Any]) -> str: + return "alpha_done" + + @durable_task(name="e2e_list_beta", ephemeral=False) + async def beta(ctx: TaskContext[Any]) -> str: + return "beta_done" + + # Create tasks for both functions + a1 = await alpha.run(task_id="alpha-1", input=None) + a2 = await alpha.run(task_id="alpha-2", input=None) + b1 = await beta.run(task_id="beta-1", input=None) + assert a1.output == "alpha_done" + assert a2.output == "alpha_done" + assert b1.output == "beta_done" + + # list() on alpha should return only alpha tasks + alpha_tasks = await alpha.list() + alpha_ids = {t.id for t in alpha_tasks} + assert "alpha-1" in alpha_ids + assert "alpha-2" in alpha_ids + assert "beta-1" not in alpha_ids + + # list() on beta should return only beta tasks + beta_tasks = await beta.list() + beta_ids = {t.id for t in beta_tasks} + assert "beta-1" in beta_ids + assert "alpha-1" not in beta_ids + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_list_with_status_filter(self, tmp_path): + """list(status=...) 
filters by task status.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_list_status", ephemeral=False) + async def suspendable(ctx: TaskContext[Any]) -> str: + if ctx.entry_mode == "fresh": + return await ctx.suspend(reason="waiting") + return "resumed" + + # Create a suspended task + handle = await suspendable.start(task_id="status-1", input=None) + result = await handle.result() + assert result.is_suspended + + @durable_task(name="e2e_list_status", ephemeral=False) + async def completer(ctx: TaskContext[Any]) -> str: + return "done" + + # Create a completed task (different id, same name) + result2 = await completer.run(task_id="status-2", input=None) + assert result2.output == "done" + + # list with status filter + suspended = await suspendable.list(status="suspended") + suspended_ids = {t.id for t in suspended} + assert "status-1" in suspended_ids + assert "status-2" not in suspended_ids + + completed = await suspendable.list(status="completed") + completed_ids = {t.id for t in completed} + assert "status-2" in completed_ids + assert "status-1" not in completed_ids + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_list_empty_when_no_tasks(self, tmp_path): + """list() returns empty when no tasks exist for this function.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + + @durable_task(name="e2e_list_empty") + async def no_tasks(ctx: TaskContext[Any]) -> str: + return "never called" + + tasks = await no_tasks.list() + assert tasks == [] + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_list_auto_stamped_tag(self, tmp_path): + """Verify _durable_task_name tag is auto-stamped on created tasks.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + task_id = uuid.uuid4().hex + + @durable_task(name="e2e_tag_stamp", ephemeral=False) + async def stamped(ctx: TaskContext[Any]) -> 
str: + return "done" + + await stamped.run(task_id=task_id, input=None) + + # Check the raw task record for the tag + task_info = await manager.provider.get(task_id) + assert task_info is not None + assert task_info.tags is not None + assert task_info.tags.get("_durable_task_name") == "e2e_tag_stamp" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_reserved_tag_cannot_be_overridden(self, tmp_path): + """Developer-provided _durable_task_ tags are stripped; framework wins.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + task_id = uuid.uuid4().hex + + @durable_task( + name="e2e_reserved_tag", + ephemeral=False, + tags={ + "_durable_task_name": "evil_override", + "_durable_task_custom": "should_be_stripped", + "user_tag": "kept", + }, + ) + async def protected(ctx: TaskContext[Any]) -> str: + return "done" + + await protected.run(task_id=task_id, input=None) + + task_info = await manager.provider.get(task_id) + assert task_info is not None + assert task_info.tags is not None + # Framework-stamped tag wins + assert task_info.tags["_durable_task_name"] == "e2e_reserved_tag" + # Other reserved tags are stripped + assert "_durable_task_custom" not in task_info.tags + # User tag is preserved + assert task_info.tags["user_tag"] == "kept" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_reserved_tag_stripped_from_callsite(self, tmp_path): + """Call-site tags with reserved prefix are stripped.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + task_id = uuid.uuid4().hex + + @durable_task(name="e2e_callsite_tag", ephemeral=False) + async def callsite(ctx: TaskContext[Any]) -> str: + return "done" + + await callsite.run( + task_id=task_id, + input=None, + tags={"_durable_task_name": "evil", "safe_tag": "ok"}, + ) + + task_info = await manager.provider.get(task_id) + assert task_info is not None + assert task_info.tags is not 
None + assert task_info.tags["_durable_task_name"] == "e2e_callsite_tag" + assert task_info.tags["safe_tag"] == "ok" + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 4: Multi-turn durable session (durable_multiturn) +# --------------------------------------------------------------------------- + + +class TestMultiturnSampleE2E: + """E2E for the durable_multiturn sample — suspend/resume per turn.""" + + @pytest.mark.asyncio + async def test_multiturn_suspend_resume(self, tmp_path): + """Full suspend → update-input → resume cycle across 2 turns.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + try: + # Simple file checkpoint store (mirrors sample pattern) + import json as _json + + def _save(sid, state): + (checkpoint_dir / f"{sid}.json").write_text(_json.dumps(state)) + + def _load(sid): + p = checkpoint_dir / f"{sid}.json" + if p.exists(): + return _json.loads(p.read_text()) + return {"history": [], "turn_count": 0} + + @durable_task(name="e2e_session_workflow") + async def session_workflow(ctx: TaskContext[Any]) -> dict: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + + state = _load(session_id) + + # Explicit end + if message == "done": + return {"turn": state["turn_count"], "finished": True} + + state["history"].append({"role": "user", "content": message}) + state["turn_count"] += 1 + + await ctx.stream({"status": "thinking", "turn": state["turn_count"]}) + + reply = f"Reply #{state['turn_count']}: {message}" + state["history"].append({"role": "assistant", "content": reply}) + _save(session_id, state) + + return await ctx.suspend( + reason="awaiting_user_input", + output={"reply": reply, "turn": state["turn_count"]}, + ) + + task_id = "e2e-session-001" + + # --- Turn 1: start --- + run1 = await session_workflow.start( + task_id=task_id, + 
input={"session_id": "s1", "message": "Hello"}, + ) + # Collect stream items + streamed = [] + async for chunk in run1: + streamed.append(chunk) + assert len(streamed) == 1 + assert streamed[0]["status"] == "thinking" + + # result() should return TaskResult with is_suspended + result1 = await run1.result() + assert result1.is_suspended + assert result1.output["reply"] == "Reply #1: Hello" + assert result1.output["turn"] == 1 + + # Verify task is suspended in the store + task = await manager._provider.get(task_id) + assert task is not None + assert task.status == "suspended" + + # Verify checkpoint file exists + assert (checkpoint_dir / "s1.json").exists() + saved = _json.loads((checkpoint_dir / "s1.json").read_text()) + assert saved["turn_count"] == 1 + assert len(saved["history"]) == 2 + + # --- Turn 2: update input → resume --- + from azure.ai.agentserver.core.durable._models import TaskPatchRequest + + await manager._provider.update( + task_id, + TaskPatchRequest( + payload={"input": {"session_id": "s1", "message": "Continue"}}, + ), + ) + await manager.handle_resume(task_id) + + # Wait for the task to suspend again + for _ in range(100): + await asyncio.sleep(0.02) + task = await manager._provider.get(task_id) + if task and task.status == "suspended": + break + assert task.status == "suspended" + assert task.payload["output"]["turn"] == 2 + assert "Continue" in task.payload["output"]["reply"] + + # Verify checkpoint updated + saved2 = _json.loads((checkpoint_dir / "s1.json").read_text()) + assert saved2["turn_count"] == 2 + assert len(saved2["history"]) == 4 # 2 user + 2 assistant + + # --- Turn 3: end session --- + await manager._provider.update( + task_id, + TaskPatchRequest( + payload={"input": {"session_id": "s1", "message": "done"}}, + ), + ) + await manager.handle_resume(task_id) + + # Wait for completion + for _ in range(100): + await asyncio.sleep(0.02) + task = await manager._provider.get(task_id) + if task and task.status == "completed": + break + 
assert task.status == "completed" + assert task.payload["output"]["finished"] is True + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Sample 5: LangGraph multi-turn (durable_langgraph) +# --------------------------------------------------------------------------- + + +langgraph = pytest.importorskip("langgraph", reason="langgraph not installed") + +# LangGraph needs real Annotated types at runtime (not stringified by +# ``from __future__ import annotations``). We build the graph state and +# nodes in a helper module-style block so type hints resolve correctly. + +import typing # noqa: E402 + +from langchain_core.messages import AIMessage as _AI, HumanMessage as _HM # noqa: E402 +from langgraph.checkpoint.sqlite import SqliteSaver as _SqliteSaver # noqa: E402 +from langgraph.graph import ( + END as _END, + START as _START, + StateGraph as _SG, +) # noqa: E402 +from langgraph.types import Command as _Cmd, interrupt as _interrupt # noqa: E402 + + +def _lg_add_messages(left: list, right: list) -> list: + return left + right + + +# Use typing.get_type_hints-compatible class (no __future__ annotations) +_LGConvState = TypedDict( + "_LGConvState", + { + "messages": typing.Annotated[list, _lg_add_messages], + "is_complete": bool, + }, +) + + +def _lg_process_input(state: dict) -> dict: + messages = state["messages"] + user_msgs = [m for m in messages if isinstance(m, _HM)] + turn = len(user_msgs) + last = user_msgs[-1].content if user_msgs else "" + return {"messages": [_AI(content=f"Reply #{turn}: {last}")]} + + +def _lg_wait_for_user(state: dict) -> dict: + user_input: str = _interrupt({"prompt": "Next?"}) + if user_input.strip().lower() == "done": + return {"is_complete": True} + return {"messages": [_HM(content=user_input)], "is_complete": False} + + +def _lg_should_continue(state: dict) -> str: + return "end" if state.get("is_complete") else "continue" + + +def 
class TestLangGraphSampleE2E:
    """E2E for the durable_langgraph sample — LangGraph interrupt/resume."""

    @pytest.mark.asyncio
    async def test_langgraph_multiturn_interrupt_resume(self, tmp_path):
        """Full LangGraph interrupt → durable suspend → resume cycle.

        Turn 1 starts the durable task. Turns 2 and 3 patch a new ``input``
        into the stored task payload and call ``manager.handle_resume``; the
        final turn sends ``"done"`` so the graph finishes and the task
        completes.
        """
        import sqlite3

        from azure.ai.agentserver.core.durable._models import TaskPatchRequest

        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        conn = None

        # Everything after manager setup runs inside the try so that a failure
        # while building the checkpointer/graph still tears the manager down.
        try:
            # Use SqliteSaver with a temp file — mirrors sample's persistent pattern.
            # check_same_thread=False: graph calls below go through
            # asyncio.to_thread, i.e. not the thread that opened the connection.
            db_path = tmp_path / "langgraph_checkpoints.db"
            conn = sqlite3.connect(str(db_path), check_same_thread=False)
            checkpointer = _SqliteSaver(conn)
            checkpointer.setup()
            graph = _build_lg_graph(checkpointer)

            @durable_task(name="e2e_langgraph_session")
            async def lg_session(ctx: TaskContext[Any]) -> dict:
                session_id = ctx.input["session_id"]
                message = ctx.input["message"]
                thread_config = {"configurable": {"thread_id": session_id}}

                state = await asyncio.to_thread(graph.get_state, thread_config)

                if state.next:
                    # Graph is parked on an interrupt: feed the message as the resume value.
                    await asyncio.to_thread(
                        graph.invoke, _Cmd(resume=message), thread_config
                    )
                else:
                    # No pending node: first turn for this thread.
                    await asyncio.to_thread(
                        graph.invoke,
                        {"messages": [_HM(content=message)], "is_complete": False},
                        thread_config,
                    )

                state = await asyncio.to_thread(graph.get_state, thread_config)

                if state.next:
                    # Still interrupted → suspend the durable task with the latest reply.
                    msgs = state.values.get("messages", [])
                    ai_msgs = [m for m in msgs if isinstance(m, _AI)]
                    user_msgs = [m for m in msgs if isinstance(m, _HM)]
                    return await ctx.suspend(
                        reason="awaiting_user_input",
                        output={
                            "reply": ai_msgs[-1].content if ai_msgs else "",
                            "turn": len(user_msgs),
                        },
                    )

                msgs = state.values.get("messages", [])
                user_count = len([m for m in msgs if isinstance(m, _HM)])
                return {"finished": True, "turn_count": user_count}

            task_id = "e2e-lg-session-001"

            async def _wait_for_status(expected: str):
                """Poll the provider (100 × 20 ms) until the task reports *expected*.

                Fails with a plain assertion — not an AttributeError — when the
                task is missing or never reaches the expected status.
                """
                task = None
                for _ in range(100):
                    await asyncio.sleep(0.02)
                    task = await manager._provider.get(task_id)
                    if task and task.status == expected:
                        break
                assert task is not None, f"task {task_id!r} not found in provider"
                assert task.status == expected
                return task

            # --- Turn 1: start ---
            run1 = await lg_session.start(
                task_id=task_id,
                input={"session_id": "lg-s1", "message": "Hello"},
            )

            result1 = await run1.result()
            assert result1.is_suspended
            assert result1.output["reply"] == "Reply #1: Hello"
            assert result1.output["turn"] == 1

            task = await manager._provider.get(task_id)
            assert task.status == "suspended"

            # --- Turn 2: resume with new input ---
            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={
                        "input": {"session_id": "lg-s1", "message": "Tell me more"}
                    },
                ),
            )
            await manager.handle_resume(task_id)

            task = await _wait_for_status("suspended")
            assert task.payload["output"]["turn"] == 2
            assert "Tell me more" in task.payload["output"]["reply"]

            # --- Turn 3: end session ---
            await manager._provider.update(
                task_id,
                TaskPatchRequest(
                    payload={"input": {"session_id": "lg-s1", "message": "done"}},
                ),
            )
            await manager.handle_resume(task_id)

            task = await _wait_for_status("completed")
            assert task.payload["output"]["finished"] is True
            assert task.payload["output"]["turn_count"] == 2

        finally:
            # Nested so a failing close() cannot skip manager teardown.
            try:
                if conn is not None:
                    conn.close()
            finally:
                await _ManagerFixture.teardown(manager, mgr_mod)
class TestLifecycleE2E:
    """E2E for lifecycle-aware .start() and .get() — spec 003."""

    @pytest.mark.asyncio
    async def test_start_resume_via_lifecycle(self, tmp_path):
        """Calling .start() on a suspended task auto-resumes it.

        Conversation state lives in JSON files under ``tmp_path/checkpoints``;
        the task body records ``ctx.entry_mode`` on every entry so the
        fresh → resumed → resumed sequence can be asserted at the end.
        """
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)

        try:
            import json as _json

            # Created inside the try so a filesystem error still tears the manager down.
            checkpoint_dir = tmp_path / "checkpoints"
            checkpoint_dir.mkdir()

            def _save(sid, state):
                (checkpoint_dir / f"{sid}.json").write_text(_json.dumps(state))

            def _load(sid):
                p = checkpoint_dir / f"{sid}.json"
                if p.exists():
                    return _json.loads(p.read_text())
                return {"history": [], "turn_count": 0}

            entry_modes: list[str] = []

            @durable_task(name="e2e_lifecycle_session")
            async def lifecycle_session(ctx: TaskContext[Any]) -> dict:
                entry_modes.append(ctx.entry_mode)
                session_id = ctx.input["session_id"]
                message = ctx.input["message"]
                state = _load(session_id)

                if message == "done":
                    return {"turn": state["turn_count"], "finished": True}

                state["history"].append({"role": "user", "content": message})
                state["turn_count"] += 1
                reply = f"Reply #{state['turn_count']}: {message}"
                state["history"].append({"role": "assistant", "content": reply})
                _save(session_id, state)

                return await ctx.suspend(
                    reason="awaiting_user_input",
                    output={"reply": reply, "turn": state["turn_count"]},
                )

            task_id = "e2e-lifecycle-001"

            # Turn 1: fresh start
            run1 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "Hello"},
            )
            result1 = await run1.result()
            assert result1.is_suspended

            # Verify .get() returns suspended task
            info = await lifecycle_session.get(task_id)
            assert info is not None
            assert info.status == "suspended"

            # Turn 2: auto-resume via .start()
            run2 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "Continue"},
            )
            result2 = await run2.result()
            assert result2.is_suspended
            assert result2.output["turn"] == 2

            # Turn 3: end session via .start()
            run3 = await lifecycle_session.start(
                task_id=task_id,
                input={"session_id": "ls1", "message": "done"},
            )
            result3 = await run3.result()
            assert result3.output["finished"] is True

            # Verify entry modes: fresh, resumed, resumed
            assert entry_modes == ["fresh", "resumed", "resumed"]

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_start_on_completed_raises_conflict(self, tmp_path):
        """.start() on a completed non-ephemeral task raises TaskConflictError."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_completes_fast", ephemeral=False)
            async def completes_fast(ctx: TaskContext[Any]) -> str:
                return "done"

            task_id = "e2e-completed-conflict"

            await completes_fast.run(task_id=task_id, input=None)

            with pytest.raises(TaskConflictError):
                await completes_fast.start(task_id=task_id, input=None)

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_crash_recovery_via_lifecycle(self, tmp_path):
        """Baseline for crash recovery: a first entry reports ``entry_mode == "fresh"``.

        NOTE(review): despite its name, this test does not simulate a crash.
        It starts two independent tasks and only checks the first recorded
        entry mode, so the ``"recovered"`` path is not exercised here.
        TODO: put a task into a stale ``in_progress`` state and assert that
        re-entry reports ``entry_mode == "recovered"``.
        """
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            entry_modes: list[str] = []

            @durable_task(name="e2e_recoverable")
            async def recoverable_task(ctx: TaskContext[Any]) -> str:
                entry_modes.append(ctx.entry_mode)
                return f"entry={ctx.entry_mode}"

            async def _poll_until_completed(tid: str) -> None:
                """Poll ``.get()`` (50 × 20 ms); stop early once the task is completed."""
                for _ in range(50):
                    await asyncio.sleep(0.02)
                    info = await recoverable_task.get(tid)
                    if info and info.status == "completed":
                        break

            # Two plain fresh starts under different task ids.
            await recoverable_task.start(task_id="e2e-crash-recovery", input="first")
            await _poll_until_completed("e2e-crash-recovery")

            await recoverable_task.start(
                task_id="e2e-crash-recovery-2", input="crash-sim"
            )
            await _poll_until_completed("e2e-crash-recovery-2")

            # Verify first run was fresh (explicit failure instead of IndexError
            # if the task body never executed).
            assert entry_modes, "durable task body never ran"
            assert entry_modes[0] == "fresh"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_get_returns_none_for_missing(self, tmp_path):
        """.get() returns None for a nonexistent task."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:

            @durable_task(name="e2e_get_missing")
            async def some_task(ctx: TaskContext[Any]) -> str:
                return "ok"

            info = await some_task.get("nonexistent-task-id")
            assert info is None

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)
{"reply": "hello", "turn": 1} + _inv_save(inv_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) + + inv_id = f"inv-{uuid.uuid4()}" + run = await inv_suspend_task.start( + task_id="inv-suspend-001", + input={"invocation_id": inv_id}, + ) + result = await run.result() + assert result.is_suspended + + # Invocation store was written inside the durable boundary + stored = _inv_load(inv_id) + assert stored is not None + assert stored["status"] == "completed" + assert stored["output"]["reply"] == "hello" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_invocation_result_written_on_complete(self, tmp_path): + """Task writes invocation result to store before returning.""" + import json as _json + + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + inv_dir = tmp_path / "invocations" + inv_dir.mkdir() + + def _inv_load(key): + p = inv_dir / f"{key}.json" + if p.exists(): + return _json.loads(p.read_text()) + return None + + def _inv_save(key, data): + (inv_dir / f"{key}.json").write_text(_json.dumps(data)) + + try: + + @durable_task(name="e2e_inv_complete") + async def inv_complete_task(ctx: TaskContext[Any]) -> dict: + inv_id = ctx.input["invocation_id"] + _inv_save(inv_id, {"status": "running"}) + result = {"finished": True, "turn_count": 3} + _inv_save(inv_id, {"status": "completed", "output": result}) + return result + + inv_id = f"inv-{uuid.uuid4()}" + result = await inv_complete_task.run( + task_id="inv-complete-001", + input={"invocation_id": inv_id}, + ) + assert result.output["finished"] is True + + stored = _inv_load(inv_id) + assert stored is not None + assert stored["status"] == "completed" + assert stored["output"]["finished"] is True + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_no_invocation_stored_on_conflict(self, tmp_path): + """Conflict means invocation never 
class _MockTextStream:
    """Simulates ``anthropic.AsyncAnthropic().messages.stream().text_stream``.

    An async iterator that yields the given text chunks in order, sleeping
    ``delay`` seconds before each one, so cancel checks between chunks
    exercise the same ``async for text in stream.text_stream`` path
    as the real sample.

    Args:
        chunks: Text pieces to yield. They are snapshotted at construction,
            so the caller's list is never mutated.
        delay: Seconds to sleep before yielding each chunk.
    """

    def __init__(self, chunks: list[str], delay: float = 0.1):
        # Iterate a private snapshot instead of repeatedly calling
        # list.pop(0), which shifts every remaining element on each chunk.
        self._pending = iter(list(chunks))
        self._delay = delay

    def __aiter__(self) -> "_MockTextStream":
        return self

    async def __anext__(self) -> str:
        """Return the next chunk after ``delay`` seconds; stop when exhausted."""
        try:
            chunk = next(self._pending)
        except StopIteration:
            # Exhaustion is signalled immediately, without sleeping.
            raise StopAsyncIteration from None
        await asyncio.sleep(self._delay)
        return chunk


class _MockStreamCtx:
    """Simulates the ``async with client.messages.stream(...) as stream:`` context.

    Exposes a ``text_stream`` attribute (a ``_MockTextStream``) on the object
    returned by ``async with``.
    """

    def __init__(self, chunks: list[str], delay: float = 0.1):
        self.text_stream = _MockTextStream(chunks, delay)

    async def __aenter__(self) -> "_MockStreamCtx":
        return self

    async def __aexit__(self, *args) -> None:
        # Nothing to release; returning None lets any exception propagate.
        return None
_MockStreamCtx([f"Echo: ", message]) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + user_turns = len([m for m in history if m["role"] == "user"]) + output = { + "invocation_id": invocation_id, + "reply": reply, + "turn": user_turns, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", output=output) + + run = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Hello", + "invocation_id": "inv-1", + }, + ) + result = await asyncio.wait_for(run.result(), timeout=5.0) + assert result.is_suspended + assert result.output["reply"] == "Echo: Hello" + assert result.output["partial"] is False + assert store["inv-1"]["status"] == "completed" + # History stored externally, not in metadata + assert len(conv_store["s1"]) == 2 # user + assistant + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_claude_steering_preserves_reply(self, tmp_path): + """Steering queues B while A is streaming. 
A's partial reply saved as superseded.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + conv_store: dict[str, list[dict[str, str]]] = {} + + @durable_task(name="e2e_claude_chat", steerable=True) + async def claude_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + history = list(conv_store.get(session_id, [])) + history.append({"role": "user", "content": message}) + if ctx.cancel.is_set(): + conv_store[session_id] = history + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + reply = "" + was_aborted = False + async with _MockStreamCtx( + ["chunk1-", "chunk2-", "chunk3"], delay=0.15 + ) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", output=output) + + run_a = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Hello", + "invocation_id": "inv-a", + }, + ) + await asyncio.sleep(0.05) + + store["inv-b"] = {"status": "queued"} + run_b = await claude_chat.start( + task_id="claude-s1", + input={ + "session_id": "s1", + "message": "Nevermind", + "invocation_id": "inv-b", + }, + ) + + assert store["inv-b"]["status"] == "queued" + + result_a = await 
asyncio.wait_for(run_a.result(), timeout=5.0) + assert result_a.is_superseded + + result_b = await asyncio.wait_for(run_b.result(), timeout=5.0) + assert result_b.is_suspended + assert result_b.output["reply"] == "chunk1-chunk2-chunk3" + + assert store["inv-a"]["status"] == "superseded" + assert "output" in store["inv-a"] + assert len(store["inv-a"]["output"]["reply"]) > 0 + assert store["inv-b"]["status"] == "completed" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_claude_rapid_fire_preserves_intermediate_messages(self, tmp_path): + """Rapid-fire: A→B→C. B is short-circuited but its message is preserved in external store.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + conv_store: dict[str, list[dict[str, str]]] = {} + + @durable_task(name="e2e_claude_chat", steerable=True) + async def claude_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + session_id = ctx.input["session_id"] + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + history = list(conv_store.get(session_id, [])) + history.append({"role": "user", "content": message}) + if ctx.cancel.is_set(): + conv_store[session_id] = history + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + reply = "" + was_aborted = False + async with _MockStreamCtx([f"Reply to {message}"], delay=0.3) as stream: + async for text in stream.text_stream: + reply += text + if ctx.cancel.is_set(): + was_aborted = True + break + if reply: + history.append({"role": "assistant", "content": reply}) + conv_store[session_id] = history + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return 
class _MockCopilotSession:
    """Simulates a Copilot SDK session with event-based send + abort.

    Mirrors the real pattern: ``session.on(handler)`` registers an event
    listener, ``session.send(msg)`` fires ``AssistantMessageData`` events
    then ``IdleData``, and ``session.abort()`` stops further events.

    Args:
        reply_chunks: Contents delivered, one event per chunk, after ``send``.
        delay: Seconds to sleep before delivering each chunk.
    """

    def __init__(self, reply_chunks: list[str], delay: float = 0.1):
        self._chunks = reply_chunks
        self._delay = delay
        self._handler: Any = None
        self._aborted = False
        self._idle_event = asyncio.Event()
        # Strong references to in-flight delivery tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._delivery_tasks: set = set()

    def on(self, handler: Any) -> None:
        """Register the callable that receives every simulated event."""
        self._handler = handler

    async def send(self, message: str) -> None:
        """Deliver reply chunks as events, then fire idle.

        Returns immediately; delivery happens in a background task on the
        running loop. ``message`` itself is not inspected.
        """
        task = asyncio.get_running_loop().create_task(self._deliver_events())
        self._delivery_tasks.add(task)
        task.add_done_callback(self._delivery_tasks.discard)

    async def _deliver_events(self) -> None:
        """Emit one content event per chunk, then an idle event, unless aborted."""
        for chunk in self._chunks:
            if self._aborted:
                break
            await asyncio.sleep(self._delay)
            # Re-check: abort() may have been called while sleeping.
            if self._aborted:
                break
            if self._handler:
                # Simulate AssistantMessageData event
                event = type("E", (), {"data": type("D", (), {"content": chunk})()})()
                self._handler(event)
        if not self._aborted and self._handler:
            # Simulate IdleData event — handlers recognise it by class name.
            idle_data = type("IdleData", (), {})()
            event = type("E", (), {"data": idle_data})()
            self._handler(event)
        # Set even after an abort, marking the end of delivery.
        self._idle_event.set()

    async def abort(self) -> None:
        """Stop further events; the delivery task exits at its next abort check."""
        self._aborted = True
+ """ + + @pytest.mark.asyncio + async def test_copilot_normal_turn(self, tmp_path): + """Normal turn completes with full reply via event-based send.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_copilot_chat", steerable=True) + async def copilot_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + + # Event-based send (mirrors session.on + session.send) + session = _MockCopilotSession([f"Echo: {message}"]) + reply_parts: list[str] = [] + idle_event = asyncio.Event() + + def on_event(event: Any) -> None: + if hasattr(event.data, "content"): + reply_parts.append(event.data.content or "") + elif type(event.data).__name__ == "IdleData": + idle_event.set() + + session.on(on_event) + await session.send(message) + + # Wait for idle or cancel + cancel_task = asyncio.create_task(ctx.cancel.wait()) + idle_task = asyncio.create_task(idle_event.wait()) + was_aborted = False + try: + done, pending = await asyncio.wait( + {cancel_task, idle_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if cancel_task in done and idle_task not in done: + was_aborted = True + await session.abort() + finally: + for t in (cancel_task, idle_task): + if not t.done(): + t.cancel() + + reply = "".join(reply_parts) + output = { + "invocation_id": invocation_id, + "reply": reply, + "partial": was_aborted, + } + if was_aborted or ctx.cancel.is_set(): + store[invocation_id] = {"status": "superseded", "output": output} + return await ctx.suspend(reason="steered") + store[invocation_id] = {"status": "completed", "output": output} + return await ctx.suspend(reason="awaiting_user_input", 
output=output) + + run = await copilot_chat.start( + task_id="copilot-s1", + input={ + "session_id": "s1", + "message": "Explain decorators", + "invocation_id": "inv-1", + }, + ) + result = await asyncio.wait_for(run.result(), timeout=5.0) + assert result.is_suspended + assert result.output["reply"] == "Echo: Explain decorators" + assert result.output["partial"] is False + assert store["inv-1"]["status"] == "completed" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_copilot_steering_preserves_reply(self, tmp_path): + """Steering queues B while A is streaming. A's partial reply saved as superseded.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_copilot_chat", steerable=True) + async def copilot_chat(ctx: TaskContext[dict]) -> dict[str, Any]: + message = ctx.input["message"] + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "cancelled", + "reason": "steered", + "message_preserved": True, + } + return await ctx.suspend(reason="steered") + + session = _MockCopilotSession(["part1-", "part2-", "part3"], delay=0.15) + reply_parts: list[str] = [] + idle_event = asyncio.Event() + + def on_event(event: Any) -> None: + if hasattr(event.data, "content"): + reply_parts.append(event.data.content or "") + elif type(event.data).__name__ == "IdleData": + idle_event.set() + + session.on(on_event) + await session.send(message) + + cancel_task = asyncio.create_task(ctx.cancel.wait()) + idle_task = asyncio.create_task(idle_event.wait()) + was_aborted = False + try: + done, pending = await asyncio.wait( + {cancel_task, idle_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if cancel_task in done and idle_task not in done: + was_aborted = True + await session.abort() + finally: + for t in (cancel_task, 
class TestLangGraphSteeringSampleE2E:
    """E2E for the durable_langgraph sample's steering path.

    Exercises the framework steering lifecycle (queued → cancel → drain →
    re-enter) using a simplified LangGraph-like pattern with checkpointing
    and invocation store writes.

    No real LangGraph graph is built here: the task bodies simulate graph
    steps with ``asyncio.sleep`` and record per-invocation status in an
    in-memory ``store`` dict that the assertions inspect.
    """

    @pytest.mark.asyncio
    async def test_langgraph_steering_cancels_and_resumes(self, tmp_path):
        """Steer while A is running → A cancelled → B processes from checkpoint.

        NOTE(review): checkpoint ids are recorded in ``checkpoints`` and in
        task metadata, but nothing in this test asserts on them.
        """
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            # invocation_id -> status record written by the task body (and by
            # the test itself for the "queued" marker).
            store: dict[str, dict[str, Any]] = {}
            checkpoints: list[str] = []

            @durable_task(name="e2e_lg_session", steerable=True)
            async def lg_session(ctx: TaskContext[dict]) -> dict[str, Any]:
                message = ctx.input["message"]
                invocation_id = ctx.input["invocation_id"]
                store[invocation_id] = {"status": "running"}

                # Cancel already requested on entry: record it and yield.
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                # Simulate multi-step graph processing
                await asyncio.sleep(0.1)  # Step 1: analyze
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                await asyncio.sleep(0.1)  # Step 2: generate
                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                reply = f"[graph] Processed: {message}"

                # Save checkpoint
                cp_id = f"cp-{ctx.generation}"
                checkpoints.append(cp_id)
                ctx.metadata.set("stable_checkpoint_id", cp_id)

                output = {"invocation_id": invocation_id, "reply": reply}
                store[invocation_id] = {"status": "completed", "output": output}
                return await ctx.suspend(reason="awaiting_user_input", output=output)

            run_a = await lg_session.start(
                task_id="lg-s1",
                input={
                    "session_id": "s1",
                    "message": "Plan a trip",
                    "invocation_id": "lg-a",
                },
            )
            # Shorter than the first 0.1 s step, so A is expected to still be
            # inside step 1 when B is started (timing-dependent).
            await asyncio.sleep(0.05)

            # Steer while A is running
            store["lg-b"] = {"status": "queued"}
            run_b = await lg_session.start(
                task_id="lg-s1",
                input={
                    "session_id": "s1",
                    "message": "Go to Paris",
                    "invocation_id": "lg-b",
                },
            )
            # Still the marker written above: the task body has not yet run
            # for B when start() returns.
            assert store["lg-b"]["status"] == "queued"

            result_a = await asyncio.wait_for(run_a.result(), timeout=5.0)
            assert result_a.is_superseded

            result_b = await asyncio.wait_for(run_b.result(), timeout=5.0)
            assert result_b.is_suspended
            assert result_b.output["reply"] == "[graph] Processed: Go to Paris"

            assert store["lg-a"]["status"] == "cancelled"
            assert store["lg-b"]["status"] == "completed"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)

    @pytest.mark.asyncio
    async def test_langgraph_multi_turn_then_steer(self, tmp_path):
        """Normal turn 1 → resume turn 2 → steer during turn 2 with turn 3."""
        manager, mgr_mod = await _ManagerFixture.setup(tmp_path)
        try:
            # invocation_id -> status record, as in the test above.
            store: dict[str, dict[str, Any]] = {}

            @durable_task(name="e2e_lg_session", steerable=True, ephemeral=False)
            async def lg_session(ctx: TaskContext[dict]) -> dict[str, Any]:
                message = ctx.input["message"]
                invocation_id = ctx.input["invocation_id"]
                store[invocation_id] = {"status": "running"}

                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                await asyncio.sleep(0.3)  # Simulated processing

                if ctx.cancel.is_set():
                    store[invocation_id] = {"status": "cancelled", "reason": "steered"}
                    return await ctx.suspend(reason="steered")

                reply = f"[graph] {message} (gen={ctx.generation})"
                output = {"invocation_id": invocation_id, "reply": reply}
                store[invocation_id] = {"status": "completed", "output": output}
                return await ctx.suspend(reason="awaiting_user_input", output=output)

            # Turn 1: normal
            run1 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn1", "invocation_id": "mt-1"},
            )
            result1 = await asyncio.wait_for(run1.result(), timeout=5.0)
            assert result1.is_suspended
            assert store["mt-1"]["status"] == "completed"

            # Turn 2: resume
            run2 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn2", "invocation_id": "mt-2"},
            )
            # Well inside turn 2's 0.3 s of simulated processing (timing-dependent).
            await asyncio.sleep(0.05)

            # Turn 3: steer while turn 2 is running
            store["mt-3"] = {"status": "queued"}
            run3 = await lg_session.start(
                task_id="lg-mt",
                input={"session_id": "s1", "message": "Turn3", "invocation_id": "mt-3"},
            )
            assert store["mt-3"]["status"] == "queued"

            result2 = await asyncio.wait_for(run2.result(), timeout=5.0)
            assert result2.is_superseded

            result3 = await asyncio.wait_for(run3.result(), timeout=5.0)
            assert result3.is_suspended
            assert "Turn3" in result3.output["reply"]
            assert store["mt-2"]["status"] == "cancelled"
            assert store["mt-3"]["status"] == "completed"

        finally:
            await _ManagerFixture.teardown(manager, mgr_mod)
# Then three text deltas + assert chunks[1] == {"type": "text_delta", "delta": "Hello"} + assert chunks[2] == {"type": "text_delta", "delta": " "} + assert chunks[3] == {"type": "text_delta", "delta": "world"} + assert len(chunks) == 4 + assert result.output["reply"] == "Hello world" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_steering_produces_superseded_stream(self, tmp_path): + """When steering cancels a running task, the stream ends after cancel.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_sse_steer", steerable=True) + async def sse_steer(ctx: TaskContext[dict]) -> dict[str, Any]: + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + await ctx.stream({"type": "lifecycle", "status": "running"}) + + if ctx.cancel.is_set(): + store[invocation_id] = {"status": "cancelled", "reason": "steered"} + return await ctx.suspend(reason="steered") + + # Simulate slow generation that gets interrupted + reply = "" + for token in ["Slow", " ", "reply", " ", "here"]: + reply += token + await ctx.stream({"type": "text_delta", "delta": token}) + await asyncio.sleep(0.05) + if ctx.cancel.is_set(): + store[invocation_id] = { + "status": "superseded", + "partial_reply": reply, + } + return await ctx.suspend(reason="steered") + + store[invocation_id] = {"status": "completed", "reply": reply} + return await ctx.suspend( + reason="awaiting_user_input", + output={"invocation_id": invocation_id, "reply": reply}, + ) + + # Start turn 1 + run1 = await sse_steer.start( + task_id="sse-steer-1", + input={"invocation_id": "inv-s1"}, + ) + + # Collect some chunks from turn 1 + chunks1: list[dict[str, Any]] = [] + async for chunk in run1: + chunks1.append(chunk) + if len(chunks1) >= 2: + # Steer with turn 2 while turn 1 is streaming + await sse_steer.start( + task_id="sse-steer-1", + 
input={"invocation_id": "inv-s2"}, + ) + break + + # Turn 1 should have been superseded + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + assert store["inv-s1"]["status"] in ("superseded", "cancelled") + + # First chunk was lifecycle:running + assert chunks1[0] == {"type": "lifecycle", "status": "running"} + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_stream_with_invocation_store_snapshots(self, tmp_path): + """Dual-write: ctx.stream() for live SSE + store for GET snapshots.""" + manager, mgr_mod = await _ManagerFixture.setup(tmp_path) + try: + store: dict[str, dict[str, Any]] = {} + + @durable_task(name="e2e_sse_snapshot") + async def sse_snapshot(ctx: TaskContext[dict]) -> dict[str, Any]: + invocation_id = ctx.input["invocation_id"] + store[invocation_id] = {"status": "running"} + await ctx.stream({"type": "lifecycle", "status": "running"}) + + reply = "" + for token in ["A", "B", "C"]: + reply += token + await ctx.stream({"type": "text_delta", "delta": token}) + store[invocation_id] = {"status": "streaming", "text": reply} + + store[invocation_id] = { + "status": "completed", + "reply": reply, + } + return {"invocation_id": invocation_id, "reply": reply} + + run = await sse_snapshot.start( + task_id="sse-snap-1", + input={"invocation_id": "inv-snap-1"}, + ) + + chunks: list[dict[str, Any]] = [] + async for chunk in run: + chunks.append(chunk) + + result = await asyncio.wait_for(run.result(), timeout=5.0) + + # Stream had lifecycle + 3 deltas + assert len(chunks) == 4 + assert chunks[0]["type"] == "lifecycle" + + # Store has final snapshot + assert store["inv-snap-1"]["status"] == "completed" + assert store["inv-snap-1"]["reply"] == "ABC" + assert result.output["reply"] == "ABC" + + finally: + await _ManagerFixture.teardown(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py 
b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py new file mode 100644 index 000000000000..6faed9e06f38 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_source.py @@ -0,0 +1,140 @@ +"""Tests for source field support on TaskInfo and TaskCreateRequest.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from azure.ai.agentserver.core.durable._models import ( + TaskCreateRequest, + TaskInfo, +) + + +class TestTaskInfoSource: + """Source field on TaskInfo.""" + + def test_default_none(self): + info = TaskInfo(id="t1", agent_name="a", session_id="s", status="pending") + assert info.source is None + + def test_set_at_construction(self): + src = {"type": "user", "origin": "cli"} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + assert info.source == src + + def test_to_dict_includes_source(self): + src = {"type": "api", "request_id": "r1"} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + d = info.to_dict() + assert d["source"] == src + + def test_to_dict_omits_none_source(self): + info = TaskInfo(id="t1", agent_name="a", session_id="s", status="pending") + d = info.to_dict() + assert "source" not in d + + def test_from_dict_with_source(self): + data = { + "id": "t1", + "agent_name": "a", + "session_id": "s", + "status": "pending", + "source": {"type": "workflow", "step": 3}, + } + info = TaskInfo.from_dict(data) + assert info.source == {"type": "workflow", "step": 3} + + def test_from_dict_without_source(self): + data = {"id": "t1", "agent_name": "a", "session_id": "s", "status": "pending"} + info = TaskInfo.from_dict(data) + assert info.source is None + + def test_round_trip(self): + src = {"origin": "test", "nested": {"a": 1}} + info = TaskInfo( + id="t1", agent_name="a", session_id="s", status="pending", source=src + ) + restored = TaskInfo.from_dict(info.to_dict()) + 
assert restored.source == src + + +class TestTaskCreateRequestSource: + """Source field on TaskCreateRequest.""" + + def test_default_none(self): + req = TaskCreateRequest(agent_name="a", session_id="s") + assert req.source is None + + def test_set_at_construction(self): + src = {"type": "decorator"} + req = TaskCreateRequest(agent_name="a", session_id="s", source=src) + assert req.source == src + + +class TestSourceLocalProvider: + """Source persisted via LocalFileDurableTaskProvider.""" + + @pytest.mark.asyncio + async def test_source_persisted_and_retrieved(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + src = {"type": "test", "run_id": "abc123"} + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + source=src, + ) + created = await provider.create(req) + assert created.source == src + + # Re-read from disk + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source == src + + @pytest.mark.asyncio + async def test_source_none_not_persisted(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + req = TaskCreateRequest(agent_name="agent", session_id="test-session") + created = await provider.create(req) + assert created.source is None + + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source is None + + @pytest.mark.asyncio + async def test_source_immutable_after_create(self, tmp_path): + """Source must not be changeable via PATCH — TaskPatchRequest has no source field.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._models import TaskPatchRequest + + provider = 
LocalFileDurableTaskProvider(Path(str(tmp_path))) + req = TaskCreateRequest( + agent_name="agent", + session_id="test-session", + source={"type": "original"}, + ) + created = await provider.create(req) + + # Patch does not touch source + await provider.update(created.id, TaskPatchRequest(tags={"k": "v"})) + fetched = await provider.get(created.id) + assert fetched is not None + assert fetched.source == {"type": "original"} diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py new file mode 100644 index 000000000000..0f930ac863e3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_steering.py @@ -0,0 +1,679 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for steerable durable tasks — steering, drain, context, and recovery.""" + +import asyncio +import json +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + TaskContext, + TaskResult, + durable_task, + EntryMode, + EtagConflict, + SteeringQueueFull, + TaskConflictError, +) + + +class TestSteering: + """Core steering functionality: append, drain, short-circuit.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await 
manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + # ------------------------------------------------------------------ + # US1: Basic steering + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steerable_start_on_in_progress_queues_input(self, tmp_path): + """start() on in_progress steerable task appends input, not raises.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work with small delay + await asyncio.sleep(0.5) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + # Start first invocation + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay for A to enter function body + await asyncio.sleep(0.1) + + # Steer while in progress — should NOT raise + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + # run2 should be a TaskRun (ack), not raise TaskConflictError + assert run2.task_id == "t1" + + # Verify queue has the input + task_info = await manager.provider.get("t1") + steering = task_info.payload.get("_steering", {}) + assert len(steering["pending_inputs"]) >= 1 + assert steering["cancel_requested"] is True + + # run1 should be superseded (A was cancelled) + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + + # run2 should complete (B runs after drain) + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + assert result2.output == {"msg": "B"} + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_non_steerable_raises_conflict(self, tmp_path): + """start() on 
in_progress non-steerable task still raises.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="regular") + async def regular(ctx: TaskContext[dict]) -> dict: + await gate.wait() + return {"msg": "done"} + + run1 = await regular.start(task_id="t1", input={"msg": "A"}) + + with pytest.raises(TaskConflictError): + await regular.start(task_id="t1", input={"msg": "B"}) + + gate.set() + await asyncio.wait_for(run1.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_steering_queue_full(self, tmp_path): + """start() raises SteeringQueueFull when queue is at capacity.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="chat", steerable=True, max_pending=2) + async def chat(ctx: TaskContext[dict]) -> dict: + await gate.wait() + return {"msg": "done"} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Fill the queue + await chat.start(task_id="t1", input={"msg": "B"}) + await chat.start(task_id="t1", input={"msg": "C"}) + + # Queue is full — should raise + with pytest.raises(SteeringQueueFull) as exc_info: + await chat.start(task_id="t1", input={"msg": "D"}) + + assert exc_info.value.max_pending == 2 + + gate.set() + await asyncio.wait_for(run1.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_superseded_result_status(self, tmp_path): + """Superseded generation's TaskRun resolves with status=superseded.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + # Always check cancel and suspend if set + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work — gives time for cancel signal + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + 
return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay to ensure task is running + await asyncio.sleep(0.1) + + # Steer + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + # run1 should be superseded + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert result1.is_superseded + + # run2 should complete + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + assert result2.output == {"msg": "B"} + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # US2: Rapid-fire short-circuit + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_rapid_fire_only_last_completes(self, tmp_path): + """3 rapid-fire steers: only the last gen runs to completion.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + entries: list[tuple[str, bool]] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + entries.append((ctx.input.get("msg", "?"), ctx.cancel.is_set())) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Small delay for A to start + await asyncio.sleep(0.05) + + # Rapid-fire B, C, D + run_b = await chat.start(task_id="t1", input={"msg": "B"}) + run_c = await chat.start(task_id="t1", input={"msg": "C"}) + run_d = await chat.start(task_id="t1", input={"msg": "D"}) + + # D should be the one that completes + result_d = await asyncio.wait_for(run_d.result(), timeout=5.0) + assert result_d.is_completed + assert result_d.output == {"msg": "D"} + + # B and C should be superseded + result_b = await asyncio.wait_for(run_b.result(), timeout=5.0) + assert result_b.is_superseded + 
+ result_c = await asyncio.wait_for(run_c.result(), timeout=5.0) + assert result_c.is_superseded + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_cancel_pre_set_when_queue_has_items(self, tmp_path): + """ctx.cancel is pre-set at function entry when queue has items.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + cancel_states: list[bool] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + cancel_states.append(ctx.cancel.is_set()) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.05) + + # Queue B and C + run_b = await chat.start(task_id="t1", input={"msg": "B"}) + run_c = await chat.start(task_id="t1", input={"msg": "C"}) + + result_c = await asyncio.wait_for(run_c.result(), timeout=5.0) + assert result_c.is_completed + + # A: cancel set by steering signal + # B: cancel pre-set (C still queued) + # C: cancel NOT set (nothing queued after C) + # cancel_states should have at least 3 entries + assert len(cancel_states) >= 3 + # The last one (C) should be False + assert cancel_states[-1] is False + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # US3: Context enrichment + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steered_context_fields(self, tmp_path): + """Steered generation has was_steered=True, previous_input set.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + contexts: list[dict[str, Any]] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + contexts.append( + { + "entry_mode": ctx.entry_mode, + "was_steered": ctx.was_steered, + "previous_input": ctx.previous_input, 
+ "generation": ctx.generation, + "msg": ctx.input.get("msg", "?"), + } + ) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + # Simulate work — gives time for steering signal + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.1) + + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + assert result2.is_completed + + # First entry: fresh, not steered + assert contexts[0]["entry_mode"] == "fresh" + assert contexts[0]["was_steered"] is False + assert contexts[0]["generation"] == 0 + + # Second entry: steered (entry_mode="resumed" with was_steered=True) + steered = [c for c in contexts if c["was_steered"] is True] + assert len(steered) >= 1 + assert steered[0]["entry_mode"] == "resumed" + assert steered[0]["generation"] > 0 + + finally: + await self._teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_entry_mode_steered(self, tmp_path): + """Steered generations enter with entry_mode='resumed' and was_steered=True.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + modes: list[str] = [] + steered_flags: list[bool] = [] + + @durable_task(name="chat", steerable=True) + async def chat(ctx: TaskContext[dict]) -> dict: + modes.append(ctx.entry_mode) + steered_flags.append(ctx.was_steered) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + await asyncio.sleep(0.3) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": "done"} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.1) + run2 = await chat.start(task_id="t1", input={"msg": "B"}) + + await asyncio.wait_for(run2.result(), timeout=5.0) + + assert "fresh" in modes + assert "resumed" in modes + # The steered generation 
should have was_steered=True + assert True in steered_flags + + finally: + await self._teardown_manager(manager, mgr_mod) + + # ------------------------------------------------------------------ + # TaskResult.is_superseded + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_task_result_is_superseded(self): + """TaskResult with status=superseded has is_superseded=True.""" + result = TaskResult(task_id="t1", status="superseded") + assert result.is_superseded is True + assert result.is_completed is False + assert result.is_suspended is False + assert result.output is None + + @pytest.mark.asyncio + async def test_task_result_completed_not_superseded(self): + """TaskResult with status=completed has is_superseded=False.""" + result = TaskResult(task_id="t1", status="completed", output=42) + assert result.is_superseded is False + assert result.is_completed is True + + # ------------------------------------------------------------------ + # Options passthrough + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steerable_via_options(self, tmp_path): + """steerable can be set via .options().""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + gate = asyncio.Event() + + @durable_task(name="chat") + async def chat(ctx: TaskContext[dict]) -> dict: + await gate.wait() + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": "done"} + + steerable_chat = chat.options(steerable=True) + + run1 = await steerable_chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.sleep(0.05) + + # This should work because steerable=True via options + run2 = await steerable_chat.start(task_id="t1", input={"msg": "B"}) + assert run2.task_id == "t1" + + gate.set() + await asyncio.wait_for(run2.result(), timeout=5.0) + + finally: + await self._teardown_manager(manager, mgr_mod) + + # 
------------------------------------------------------------------ + # DurableTaskOptions validation + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_max_pending_validation(self): + """max_pending < 1 raises ValueError at decoration time.""" + with pytest.raises(ValueError, match="max_pending"): + + @durable_task(name="bad", max_pending=0) + async def bad(ctx: TaskContext[dict]) -> dict: + return {} + + # ------------------------------------------------------------------ + # Exceptions + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_etag_conflict_exception(self): + """EtagConflict has task_id attribute.""" + exc = EtagConflict("t1", "test message") + assert exc.task_id == "t1" + assert "test message" in str(exc) + + @pytest.mark.asyncio + async def test_steering_queue_full_exception(self): + """SteeringQueueFull has task_id and max_pending attributes.""" + exc = SteeringQueueFull("t1", 10) + assert exc.task_id == "t1" + assert exc.max_pending == 10 + assert "10" in str(exc) + + # ------------------------------------------------------------------ + # Steering with function that completes (not suspends) + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_steering_function_ignores_cancel_completes(self, tmp_path): + """If function ignores cancel and returns, steering still works via drain.""" + manager, mgr_mod = await self._setup_manager(tmp_path) + try: + call_count = 0 + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + nonlocal call_count + call_count += 1 + # Intentionally ignores ctx.cancel + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + + # Wait for A to complete + result1 = await asyncio.wait_for(run1.result(), timeout=5.0) + assert 
result1.is_completed + + # For non-ephemeral completed tasks, steerable or not, raises conflict + with pytest.raises(TaskConflictError): + await chat.start(task_id="t1", input={"msg": "B"}) + + finally: + await self._teardown_manager(manager, mgr_mod) + + +class TestSteeringRecovery: + """Crash recovery for steerable tasks.""" + + async def _setup_manager(self, tmp_path): + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + async def _teardown_manager(self, manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_recovery_with_drain_in_progress(self, tmp_path): + """Recovery after crash mid-drain uses active_input from steering state.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + from azure.ai.agentserver.core.durable._models import ( + TaskPatchRequest, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + + # Phase 1: Create a task and simulate crash mid-drain + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await 
manager.startup() + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + return {"msg": ctx.input.get("msg", "?")} + + run1 = await chat.start(task_id="t1", input={"msg": "A"}) + await asyncio.wait_for(run1.result(), timeout=5.0) + + # Simulate crash state: task is in_progress with drain_in_progress + # Reset status to in_progress and inject steering state + await provider.update( + "t1", + TaskPatchRequest( + status="in_progress", + payload={ + "_steering": { + "generation": 1, + "active_input": {"msg": "B"}, + "previous_input": {"msg": "A"}, + "pending_inputs": [], + "cancel_requested": False, + "drain_in_progress": True, + }, + }, + ), + ) + + await manager.shutdown() + mgr_mod._manager = None + + # Phase 2: Recover — new manager picks up the crashed task + manager2 = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager2 + await manager2.startup() + + inputs_seen: list[dict] = [] + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat2(ctx: TaskContext[dict]) -> dict: + inputs_seen.append(dict(ctx.input)) + return {"msg": ctx.input.get("msg", "?")} + + # Start with recovery input (doesn't matter — active_input overrides) + run2 = await chat2.start( + task_id="t1", input={"msg": "recovery"}, stale_timeout=0.0 + ) + result2 = await asyncio.wait_for(run2.result(), timeout=5.0) + + # Should have used active_input "B", not the recovery caller input + assert result2.output == {"msg": "B"} + assert inputs_seen[-1] == {"msg": "B"} + + await manager2.shutdown() + mgr_mod._manager = None + + @pytest.mark.asyncio + async def test_recovery_with_pending_inputs(self, tmp_path): + """Recovery with pending inputs drains them after function completes.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import ( + DurableTaskManager, + ) + from 
azure.ai.agentserver.core.durable._models import ( + TaskPatchRequest, + ) + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + + # Phase 1: Create a task normally, then simulate crash with pending + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat_setup(ctx: TaskContext[dict]) -> dict: + # Long sleep — we'll kill the manager before it completes + await asyncio.sleep(10) + return {"msg": "never"} + + run1 = await chat_setup.start(task_id="t2", input={"msg": "X"}) + await asyncio.sleep(0.1) # let it start + + # Force shutdown (simulates crash) + await manager.shutdown() + mgr_mod._manager = None + + # Patch the task to simulate crash-with-pending state + await provider.update( + "t2", + TaskPatchRequest( + status="in_progress", + payload={ + "input": {"msg": "X"}, + "_steering": { + "generation": 0, + "active_input": {"msg": "X"}, + "pending_inputs": [{"msg": "Y"}, {"msg": "Z"}], + "cancel_requested": True, + "drain_in_progress": False, + }, + }, + ), + ) + + # Phase 2: New manager recovers the task + manager2 = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager2 + await manager2.startup() + + inputs_seen: list[str] = [] + + @durable_task(name="chat", steerable=True, ephemeral=False) + async def chat(ctx: TaskContext[dict]) -> dict: + inputs_seen.append(ctx.input.get("msg", "?")) + if ctx.cancel.is_set(): + return await ctx.suspend(reason="steered") + return {"msg": ctx.input.get("msg", "?")} + + run2 = await chat.start( + task_id="t2", input={"msg": "recover"}, stale_timeout=0.0 + ) + result = await asyncio.wait_for(run2.result(), timeout=5.0) + + # Should 
have drained through X (cancel set) → Y (cancel set) → Z (complete) + assert result.output == {"msg": "Z"} + assert "Z" in inputs_seen + + await manager2.shutdown() + mgr_mod._manager = None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_stream_handler.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_stream_handler.py new file mode 100644 index 000000000000..b65f758bc8ad --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_stream_handler.py @@ -0,0 +1,614 @@ +"""Tests for pluggable StreamHandler protocol (spec 009). + +Covers: +- T010: Custom handler receives items via put()/get() +- T011: Default behavior unchanged when no handler provided +- T013: Steerable task with custom handler across generations +- T015: close() called on success +- T016: close() called on failure +- T017: close() error logged but doesn't change task outcome +- T018: put() error propagates to ctx.stream() +- T021: Late-join consumer iterates stream via get_active_run() +""" + +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path +from typing import Any + +import pytest + +from azure.ai.agentserver.core.durable import ( + QueueStreamHandler, + StreamHandler, + TaskContext, + durable_task, +) +from azure.ai.agentserver.core.durable._stream import QueueStreamHandler as _QSH + + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +async def _setup_manager(tmp_path): + """Create a DurableTaskManager with local file storage.""" + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import DurableTaskManager + + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + 
{ + "agent_name": "test-agent", + "session_id": "test-session", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + return manager, mgr_mod + + +async def _teardown_manager(manager, mgr_mod): + await manager.shutdown() + mgr_mod._manager = None + + +# --------------------------------------------------------------------------- +# Custom handler for testing +# --------------------------------------------------------------------------- + + +class RecordingHandler: + """A StreamHandler that records all put/get/close calls.""" + + def __init__(self) -> None: + self.items_put: list[Any] = [] + self.close_called: bool = False + self._queue: asyncio.Queue[Any] = asyncio.Queue() + self._sentinel = object() + + async def put(self, item: Any) -> None: + self.items_put.append(item) + await self._queue.put(item) + + async def get(self) -> Any: + item = await self._queue.get() + if item is self._sentinel: + raise StopAsyncIteration + return item + + async def close(self) -> None: + self.close_called = True + await self._queue.put(self._sentinel) + + +class FailingPutHandler: + """A StreamHandler whose put() always raises.""" + + async def put(self, item: Any) -> None: + raise RuntimeError("put() failed") + + async def get(self) -> Any: + raise StopAsyncIteration + + async def close(self) -> None: + pass + + +class FailingCloseHandler: + """A StreamHandler whose close() always raises.""" + + def __init__(self) -> None: + self._queue: asyncio.Queue[Any] = asyncio.Queue() + self._sentinel = object() + + async def put(self, item: Any) -> None: + await self._queue.put(item) + + async def get(self) -> Any: + item = await self._queue.get() + if item is self._sentinel: + raise StopAsyncIteration + return item + + async def close(self) -> None: + await self._queue.put(self._sentinel) + raise RuntimeError("close() failed") + + +# 
--------------------------------------------------------------------------- +# Phase 3: Custom Handler Dispatch (T010, T011) +# --------------------------------------------------------------------------- + + +class TestCustomHandlerDispatch: + """T010/T011: custom handler receives items; default unchanged.""" + + @pytest.mark.asyncio + async def test_custom_handler_receives_items(self, tmp_path): + """T010: Custom handler receives all items via put(), consumer + gets them via get().""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = RecordingHandler() + + @durable_task(name="t010_custom_stream") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("chunk-1") + await ctx.stream("chunk-2") + await ctx.stream("chunk-3") + return "done" + + run = await my_task.start( + task_id="t010-1", + input="hello", + stream_handler=handler, + ) + + collected = [] + async for chunk in run: + collected.append(chunk) + + result = await run.result() + assert result.output == "done" + assert collected == ["chunk-1", "chunk-2", "chunk-3"] + assert handler.items_put == ["chunk-1", "chunk-2", "chunk-3"] + assert handler.close_called is True + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_default_handler_when_none_provided(self, tmp_path): + """T011: When no handler provided, default QueueStreamHandler works.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + + @durable_task(name="t011_default_stream") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("a") + await ctx.stream("b") + return "ok" + + run = await my_task.start( + task_id="t011-1", + input="test", + ) + + collected = [] + async for chunk in run: + collected.append(chunk) + + result = await run.result() + assert result.output == "ok" + assert collected == ["a", "b"] + finally: + await _teardown_manager(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# 
Phase 4: Steering Carry-Over (T013) +# --------------------------------------------------------------------------- + + +class TestSteeringCarryOver: + """T013: Handler survives steering re-entries.""" + + @pytest.mark.asyncio + async def test_handler_carries_across_steering(self, tmp_path): + """T013: Items from both generations flow through same handler.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = RecordingHandler() + gen1_started = asyncio.Event() + + @durable_task(name="t013_steerable", steerable=True) + async def steerable_task(ctx: TaskContext[dict]) -> dict: + gen = ctx.generation + await ctx.stream({"gen": gen, "event": "start"}) + + if gen == 0: + gen1_started.set() + # Wait for cancel (steering) + while not ctx.cancel.is_set(): + await asyncio.sleep(0.01) + await ctx.stream({"gen": gen, "event": "cancelled"}) + return await ctx.suspend(reason="steered") + + await ctx.stream({"gen": gen, "event": "finish"}) + return {"gen": gen, "status": "completed"} + + # Start gen 0 with custom handler + run1 = await steerable_task.start( + task_id="t013-1", + input={"msg": "first"}, + stream_handler=handler, + ) + + # Wait for gen 0 to start streaming + await gen1_started.wait() + + # Steer — gen 0 gets cancelled, gen 1 starts + run2 = await steerable_task.start( + task_id="t013-1", + input={"msg": "second"}, + ) + + # Consume all items from run1 (which carries the handler) + collected = [] + async for chunk in run1: + collected.append(chunk) + + # Handler should have items from both generations + assert handler.close_called is True + assert any(item.get("gen") == 0 for item in handler.items_put) + assert any(item.get("gen") == 1 for item in handler.items_put) + finally: + await _teardown_manager(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Phase 5: Stream Closure (T015, T016, T017, T018) +# --------------------------------------------------------------------------- + + +class 
TestStreamClosure: + """T015–T018: close() lifecycle and error propagation.""" + + @pytest.mark.asyncio + async def test_close_called_on_success(self, tmp_path): + """T015: close() is called when task succeeds.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = RecordingHandler() + + @durable_task(name="t015_success") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("data") + return "success" + + run = await my_task.start( + task_id="t015-1", + input="x", + stream_handler=handler, + ) + result = await run.result() + assert result.output == "success" + assert handler.close_called is True + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_close_called_on_failure(self, tmp_path): + """T016: close() is called when task fails.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = RecordingHandler() + + @durable_task(name="t016_failure") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("before-error") + raise ValueError("boom") + + run = await my_task.start( + task_id="t016-1", + input="x", + stream_handler=handler, + ) + + # Drain stream + collected = [] + async for chunk in run: + collected.append(chunk) + + assert handler.close_called is True + assert collected == ["before-error"] + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_close_error_logged_not_propagated(self, tmp_path, caplog): + """T017: close() error is logged but doesn't change task outcome.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = FailingCloseHandler() + + @durable_task(name="t017_close_error") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("data") + return "ok" + + run = await my_task.start( + task_id="t017-1", + input="x", + stream_handler=handler, + ) + + collected = [] + async for chunk in run: + collected.append(chunk) + + # Task should still succeed despite close() 
error + result = await run.result() + assert result.output == "ok" + assert collected == ["data"] + + # close() error should be logged + assert any( + "close() failed" in record.message + for record in caplog.records + if record.levelno >= logging.WARNING + ) + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_put_error_propagates(self, tmp_path): + """T018: put() error propagates to ctx.stream() call site.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + handler = FailingPutHandler() + + @durable_task(name="t018_put_error") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("this will fail") + return "should not reach" + + run = await my_task.start( + task_id="t018-1", + input="x", + stream_handler=handler, + ) + + # The task should fail because put() raised + from azure.ai.agentserver.core.durable import TaskFailed + + with pytest.raises(TaskFailed): + await run.result() + finally: + await _teardown_manager(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Phase 6: Late-Join Consumer (T021) +# --------------------------------------------------------------------------- + + +class TestLateJoinConsumer: + """T021: Late-join consumer via get_active_run().""" + + @pytest.mark.asyncio + async def test_late_join_gets_stream_items(self, tmp_path): + """T021: Code that didn't call start() gets a TaskRun handle + and iterates stream items via get_active_run().""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + task_started = asyncio.Event() + proceed = asyncio.Event() + + @durable_task(name="t021_late_join") + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("chunk-1") + task_started.set() + await proceed.wait() + await ctx.stream("chunk-2") + return "done" + + # Start the task + run = await my_task.start( + task_id="t021-1", + input="hello", + ) + + # Wait for first chunk to be streamed + await 
task_started.wait() + + # Late-join: get a handle without being the original caller + late_run = my_task.get_active_run("t021-1") + assert late_run is not None + + # Let the task finish + proceed.set() + + # Both runs should be able to get the result + result = await run.result() + assert result.output == "done" + + late_result = await late_run.result() + assert late_result.output == "done" + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_get_active_run_returns_none_for_inactive(self, tmp_path): + """get_active_run returns None for a task not currently active.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + + @durable_task(name="t021_inactive") + async def my_task(ctx: TaskContext[str]) -> str: + return "done" + + result = my_task.get_active_run("nonexistent-task") + assert result is None + finally: + await _teardown_manager(manager, mgr_mod) + + +# --------------------------------------------------------------------------- +# Protocol conformance +# --------------------------------------------------------------------------- + + +class TestProtocolConformance: + """Verify QueueStreamHandler and custom handlers satisfy Protocol.""" + + def test_queue_handler_is_stream_handler(self): + handler = QueueStreamHandler() + assert isinstance(handler, StreamHandler) + + def test_recording_handler_is_stream_handler(self): + handler = RecordingHandler() + assert isinstance(handler, StreamHandler) + + def test_failing_put_handler_is_stream_handler(self): + handler = FailingPutHandler() + assert isinstance(handler, StreamHandler) + + +# --------------------------------------------------------------------------- +# stream_handler_factory on decorator (recovery uses factory) +# --------------------------------------------------------------------------- + + +class TestStreamHandlerFactory: + """Verify stream_handler_factory on the decorator is used for recovery.""" + + @pytest.mark.asyncio + async def 
test_factory_used_on_fresh_start(self, tmp_path): + """When no call-site handler provided, factory creates the handler.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + created_handlers: list[RecordingHandler] = [] + + def _factory(task_id: str) -> RecordingHandler: + h = RecordingHandler() + created_handlers.append(h) + return h + + @durable_task( + name="t_factory_fresh", + stream_handler_factory=_factory, + ) + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("x") + return "ok" + + run = await my_task.start(task_id="factory-1", input="hi") + collected = [] + async for chunk in run: + collected.append(chunk) + result = await run.result() + + assert result.output == "ok" + assert collected == ["x"] + assert len(created_handlers) == 1 + assert created_handlers[0].items_put == ["x"] + assert created_handlers[0].close_called is True + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_call_site_handler_overrides_factory(self, tmp_path): + """Call-site stream_handler takes precedence over factory.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + factory_called = False + + def _factory(task_id: str) -> RecordingHandler: + nonlocal factory_called + factory_called = True + return RecordingHandler() + + @durable_task( + name="t_factory_override", + stream_handler_factory=_factory, + ) + async def my_task(ctx: TaskContext[str]) -> str: + await ctx.stream("y") + return "ok" + + call_site_handler = RecordingHandler() + run = await my_task.start( + task_id="override-1", + input="hi", + stream_handler=call_site_handler, + ) + collected = [] + async for chunk in run: + collected.append(chunk) + await run.result() + + assert collected == ["y"] + assert call_site_handler.items_put == ["y"] + assert factory_called is False + finally: + await _teardown_manager(manager, mgr_mod) + + @pytest.mark.asyncio + async def test_factory_used_on_recovery(self, tmp_path): + """On crash recovery, 
factory creates the handler, not QueueStreamHandler.""" + manager, mgr_mod = await _setup_manager(tmp_path) + try: + created_handlers: list[RecordingHandler] = [] + + def _factory(task_id: str) -> RecordingHandler: + h = RecordingHandler() + created_handlers.append(h) + return h + + @durable_task( + name="t_factory_recovery", + stream_handler_factory=_factory, + ephemeral=False, + ) + async def my_task(ctx: TaskContext[str]) -> str: + if ctx.entry_mode == "recovered": + await ctx.stream("recovered-chunk") + return "recovered" + await ctx.stream("fresh-chunk") + return "fresh" + + # First run — fresh + run1 = await my_task.start(task_id="recovery-1", input="hi") + collected1 = [] + async for chunk in run1: + collected1.append(chunk) + result1 = await run1.result() + assert result1.output == "fresh" + assert collected1 == ["fresh-chunk"] + assert len(created_handlers) == 1 + + # Simulate crash: force task back to in_progress + stale + # Write directly to the local file store to backdate updated_at + import json + + task_file = ( + Path(str(tmp_path)) / "test-agent" / "test-session" / "recovery-1.json" + ) + with open(task_file, "r") as f: + data = json.load(f) + data["status"] = "in_progress" + data["updated_at"] = "2000-01-01T00:00:00+00:00" + with open(task_file, "w") as f: + json.dump(data, f) + + # Recovery — should use factory, not QueueStreamHandler + run2 = await my_task.start( + task_id="recovery-1", + input="hi", + stale_timeout=1.0, + ) + collected2 = [] + async for chunk in run2: + collected2.append(chunk) + result2 = await run2.result() + + assert result2.output == "recovered" + assert collected2 == ["recovered-chunk"] + # Factory should have been called twice total (fresh + recovery) + assert len(created_handlers) == 2 + assert created_handlers[1].items_put == ["recovered-chunk"] + finally: + await _teardown_manager(manager, mgr_mod) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py 
b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py new file mode 100644 index 000000000000..ca77256e2913 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_streaming.py @@ -0,0 +1,181 @@ +"""Tests for streaming support (ctx.stream + async-for on TaskRun).""" + +from __future__ import annotations + +import asyncio + +import pytest + +from azure.ai.agentserver.core.durable._context import TaskContext +from azure.ai.agentserver.core.durable._metadata import TaskMetadata +from azure.ai.agentserver.core.durable._run import TaskRun +from azure.ai.agentserver.core.durable._stream import QueueStreamHandler + + +def _make_ctx(stream_handler=None, **overrides): + defaults = dict( + task_id="t1", + title="test", + session_id="s1", + agent_name="a1", + tags={}, + input=None, + metadata=TaskMetadata(), + stream_handler=stream_handler, + ) + defaults.update(overrides) + return TaskContext(**defaults) + + +def _make_run(stream_handler=None, result_future=None, **overrides): + loop = asyncio.get_event_loop() + if result_future is None: + result_future = loop.create_future() + defaults = dict( + task_id="t1", + provider=None, + result_future=result_future, + metadata=TaskMetadata(), + cancel_event=asyncio.Event(), + stream_handler=stream_handler, + ) + defaults.update(overrides) + return TaskRun(**defaults) + + +class TestContextStream: + """ctx.stream() puts items via the handler.""" + + @pytest.mark.asyncio + async def test_stream_puts_item(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + await ctx.stream("hello") + assert await handler.get() == "hello" + + @pytest.mark.asyncio + async def test_stream_multiple_items(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + await ctx.stream(1) + await ctx.stream(2) + await ctx.stream(3) + assert await handler.get() == 1 + assert await handler.get() == 2 + assert await handler.get() == 3 + + @pytest.mark.asyncio + 
async def test_stream_no_handler_noop(self): + ctx = _make_ctx(stream_handler=None) + # Should not raise + await ctx.stream("ignored") + + @pytest.mark.asyncio + async def test_stream_various_types(self): + handler = QueueStreamHandler() + ctx = _make_ctx(stream_handler=handler) + items = ["text", 42, {"key": "val"}, [1, 2], None, True] + for item in items: + await ctx.stream(item) + collected = [await handler.get() for _ in range(len(items))] + assert collected == items + + +class TestTaskRunAsyncIter: + """TaskRun.__aiter__ / __anext__ consume via the stream handler.""" + + @pytest.mark.asyncio + async def test_iterate_items(self): + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.put("a") + await handler.put("b") + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["a", "b"] + + @pytest.mark.asyncio + async def test_empty_stream(self): + """close() immediately → no items.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == [] + + @pytest.mark.asyncio + async def test_no_handler_stops_immediately(self): + run = _make_run(stream_handler=None) + collected = [] + async for item in run: + collected.append(item) + assert collected == [] + + @pytest.mark.asyncio + async def test_stream_and_result(self): + """Stream items, then also await result().""" + handler = QueueStreamHandler() + loop = asyncio.get_event_loop() + fut: asyncio.Future = loop.create_future() + run = _make_run(stream_handler=handler, result_future=fut) + + await handler.put("chunk1") + await handler.put("chunk2") + await handler.close() + fut.set_result("final") + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["chunk1", "chunk2"] + result = await run.result() + assert result == "final" # Unit test uses raw future, not 
manager pipeline + + @pytest.mark.asyncio + async def test_concurrent_producer_consumer(self): + """Producer streams while consumer iterates.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + + async def produce(): + for i in range(5): + await handler.put(i) + await asyncio.sleep(0.01) + await handler.close() + + collected = [] + + async def consume(): + async for item in run: + collected.append(item) + + await asyncio.gather(produce(), consume()) + assert collected == [0, 1, 2, 3, 4] + + +class TestStreamingErrorCases: + """Streaming under error/suspend/cancel conditions.""" + + @pytest.mark.asyncio + async def test_close_terminates_iteration(self): + """close() terminates iteration cleanly.""" + handler = QueueStreamHandler() + run = _make_run(stream_handler=handler) + await handler.put("partial") + await handler.close() + + collected = [] + async for item in run: + collected.append(item) + assert collected == ["partial"] + + @pytest.mark.asyncio + async def test_aiter_returns_self(self): + run = _make_run(stream_handler=QueueStreamHandler()) + assert run.__aiter__() is run diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py new file mode 100644 index 000000000000..960311ebb6dc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/durable/test_task_result.py @@ -0,0 +1,130 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for the TaskResult wrapper class.""" + +import pytest + +from azure.ai.agentserver.core.durable import TaskResult + + +class TestTaskResult: + """Tests for TaskResult construction and properties.""" + + def test_completed_result(self): + """A completed result has is_completed=True, is_suspended=False.""" + r = TaskResult(task_id="t1", output="hello", status="completed") + assert r.is_completed + assert not r.is_suspended + assert r.output == "hello" + assert r.task_id == "t1" + assert r.suspension_reason is None + + def test_suspended_result(self): + """A suspended result has is_suspended=True, is_completed=False.""" + r = TaskResult( + task_id="t2", + output={"turn": 1}, + status="suspended", + suspension_reason="awaiting_user", + ) + assert r.is_suspended + assert not r.is_completed + assert r.output == {"turn": 1} + assert r.suspension_reason == "awaiting_user" + + def test_suspended_without_output(self): + """A suspended result can have no output (output=None).""" + r = TaskResult(task_id="t3", status="suspended") + assert r.is_suspended + assert r.output is None + assert r.suspension_reason is None + + def test_completed_with_none_output(self): + """A completed result can return None explicitly.""" + r = TaskResult(task_id="t4", output=None, status="completed") + assert r.is_completed + assert r.output is None + + def test_completed_with_complex_output(self): + """TaskResult works with dict outputs.""" + data = {"items": [1, 2, 3], "total": 3} + r = TaskResult(task_id="t5", output=data, status="completed") + assert r.output["items"] == [1, 2, 3] + assert r.output["total"] == 3 + + def test_repr_completed(self): + """__repr__ shows status and output for completed results.""" + r = TaskResult(task_id="t6", output="done", status="completed") + s = repr(r) + assert "t6" in s + assert "completed" in s + assert "done" in s + assert "suspension_reason" not in s + + def 
test_repr_suspended(self): + """__repr__ includes suspension_reason when present.""" + r = TaskResult( + task_id="t7", output=None, status="suspended", suspension_reason="waiting" + ) + s = repr(r) + assert "suspended" in s + assert "waiting" in s + + def test_repr_truncates_long_output(self): + """__repr__ truncates output longer than 60 chars.""" + long_val = "x" * 100 + r = TaskResult(task_id="t8", output=long_val, status="completed") + s = repr(r) + assert "..." in s + assert len(s) < 200 + + +class TestNestedTaskResultGuard: + """Test that returning TaskResult from a task function raises TypeError.""" + + @pytest.mark.asyncio + async def test_returning_taskresult_raises_typeerror(self, tmp_path): + """Task function that returns TaskResult directly gets TypeError.""" + import uuid + from pathlib import Path + from azure.ai.agentserver.core.durable import TaskContext, durable_task + from azure.ai.agentserver.core.durable._local_provider import ( + LocalFileDurableTaskProvider, + ) + from azure.ai.agentserver.core.durable._manager import DurableTaskManager + import azure.ai.agentserver.core.durable._manager as mgr_mod + + provider = LocalFileDurableTaskProvider(Path(str(tmp_path))) + config = type( + "C", + (), + { + "agent_name": "test", + "session_id": "test", + "agent_version": "1.0.0", + "is_hosted": False, + }, + )() + manager = DurableTaskManager(config=config, provider=provider) + mgr_mod._manager = manager + await manager.startup() + + try: + from typing import Any + from azure.ai.agentserver.core.durable import TaskContext + + @durable_task(name="bad_return") + async def bad_task(ctx: TaskContext[Any]) -> Any: + return TaskResult( + task_id=ctx.task_id, output="data", status="completed" + ) + + from azure.ai.agentserver.core.durable._exceptions import TaskFailed + + with pytest.raises(TaskFailed): + await bad_task.run(task_id=uuid.uuid4().hex, input=None) + + finally: + await manager.shutdown() + mgr_mod._manager = None diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py index be194e6ec0fd..f1de9e022dea 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_config.py @@ -10,25 +10,33 @@ class TestAgentConfigIsHosted: """Tests for AgentConfig.is_hosted snapshotting behavior.""" - def test_is_hosted_false_when_env_var_absent(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_false_when_env_var_absent( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is False when FOUNDRY_HOSTING_ENVIRONMENT is not set.""" monkeypatch.delenv("FOUNDRY_HOSTING_ENVIRONMENT", raising=False) config = AgentConfig.from_env() assert config.is_hosted is False - def test_is_hosted_false_when_env_var_empty(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_false_when_env_var_empty( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is False when FOUNDRY_HOSTING_ENVIRONMENT is set to an empty string.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "") config = AgentConfig.from_env() assert config.is_hosted is False - def test_is_hosted_true_when_env_var_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_true_when_env_var_set( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted is True when FOUNDRY_HOSTING_ENVIRONMENT is set to a non-empty value.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "production") config = AgentConfig.from_env() assert config.is_hosted is True - def test_is_hosted_snapshotted_at_creation(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_is_hosted_snapshotted_at_creation( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """is_hosted reflects the env var value at creation time, not at access time.""" monkeypatch.setenv("FOUNDRY_HOSTING_ENVIRONMENT", "production") config = AgentConfig.from_env() diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py index 7c538c0ddc31..c15bccfd85b0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_graceful_shutdown.py @@ -11,7 +11,10 @@ import pytest from azure.ai.agentserver.core import AgentServerHost -from azure.ai.agentserver.core._config import resolve_graceful_shutdown_timeout, _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT +from azure.ai.agentserver.core._config import ( + resolve_graceful_shutdown_timeout, + _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT, +) # ------------------------------------------------------------------ # @@ -26,7 +29,10 @@ def test_explicit_wins(self) -> None: assert resolve_graceful_shutdown_timeout(10) == 10 def test_default(self) -> None: - assert resolve_graceful_shutdown_timeout(None) == _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + assert ( + resolve_graceful_shutdown_timeout(None) + == _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + ) def test_non_int_explicit_raises(self) -> None: with pytest.raises(ValueError, match="expected an integer"): @@ -193,7 +199,10 @@ async def send(message): await agent(scope, receive, send) # The error should be logged - assert any("on_shutdown" in r.message.lower() or "error" in r.message.lower() for r in caplog.records) + assert any( + "on_shutdown" in r.message.lower() or "error" in r.message.lower() + for r in caplog.records + ) # Server should still complete shutdown assert any(m["type"] == "lifespan.shutdown.complete" for m in sent_messages) @@ -204,7 +213,9 @@ async def send(message): @pytest.mark.asyncio -async def test_slow_shutdown_cancelled_with_warning(caplog: pytest.LogCaptureFixture) -> None: +async def test_slow_shutdown_cancelled_with_warning( + caplog: pytest.LogCaptureFixture, +) -> None: """A shutdown handler exceeding the timeout is cancelled and a warning is logged.""" agent = 
AgentServerHost(graceful_shutdown_timeout=1) @@ -230,7 +241,10 @@ async def send(message): with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): await agent(scope, receive, send) - assert any("did not complete" in r.message.lower() or "timeout" in r.message.lower() for r in caplog.records) + assert any( + "did not complete" in r.message.lower() or "timeout" in r.message.lower() + for r in caplog.records + ) assert any(m["type"] == "lifespan.shutdown.complete" for m in sent_messages) @@ -341,7 +355,9 @@ def fake_asyncio_run(coroutine): finally: signal.signal(signal.SIGTERM, original) - def test_sigterm_handler_logs_and_re_raises(self, caplog: pytest.LogCaptureFixture) -> None: + def test_sigterm_handler_logs_and_re_raises( + self, caplog: pytest.LogCaptureFixture + ) -> None: """The installed SIGTERM handler logs then re-raises via os.kill.""" original = signal.getsignal(signal.SIGTERM) try: diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py index a95e4980d530..9b2c05287882 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_logger.py @@ -16,4 +16,5 @@ def test_log_level_preserved_across_imports() -> None: lib_logger = logging.getLogger("azure.ai.agentserver") lib_logger.setLevel(logging.ERROR) from azure.ai.agentserver.core import _base # noqa: F401 + assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py index 85e28c1bf15e..4ea165ee3f84 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_server_routes.py @@ -12,7 +12,6 @@ from azure.ai.agentserver.core._config import resolve_port - # ------------------------------------------------------------------ # 
# Port resolution # ------------------------------------------------------------------ # diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py index d6af2accb52c..1802f52b0644 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_startup_logging.py @@ -20,7 +20,9 @@ def test_none_like_empty_returns_not_set(self) -> None: assert _mask_uri(" ") == _NOT_SET def test_https_uri_strips_path_and_query(self) -> None: - result = _mask_uri("https://myproject.azure.com/subscriptions/abc?api-version=2024") + result = _mask_uri( + "https://myproject.azure.com/subscriptions/abc?api-version=2024" + ) assert result == "https://myproject.azure.com" def test_http_uri_with_port(self) -> None: @@ -68,7 +70,9 @@ def _clean_env(self, monkeypatch: pytest.MonkeyPatch) -> None: @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_platform_environment(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_startup_logs_platform_environment( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits platform environment log line.""" from azure.ai.agentserver.core import AgentServerHost @@ -78,7 +82,9 @@ async def test_startup_logs_platform_environment(self, caplog: pytest.LogCapture async with app.router.lifespan_context(app): pass - platform_logs = [r for r in caplog.records if "Platform environment" in r.message] + platform_logs = [ + r for r in caplog.records if "Platform environment" in r.message + ] assert len(platform_logs) == 1 msg = platform_logs[0].message assert "is_hosted=False" in msg @@ -86,7 +92,9 @@ async def test_startup_logs_platform_environment(self, caplog: pytest.LogCapture @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_connectivity(self, caplog: pytest.LogCaptureFixture) 
-> None: + async def test_startup_logs_connectivity( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits connectivity log line with masked URIs.""" from azure.ai.agentserver.core import AgentServerHost @@ -104,7 +112,9 @@ async def test_startup_logs_connectivity(self, caplog: pytest.LogCaptureFixture) @pytest.mark.usefixtures("_clean_env") @pytest.mark.asyncio - async def test_startup_logs_host_options(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_startup_logs_host_options( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Lifespan startup emits host options log line.""" from azure.ai.agentserver.core import AgentServerHost @@ -125,7 +135,9 @@ async def test_startup_masks_project_endpoint( self, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: """Project endpoint URI is masked to scheme://host only.""" - monkeypatch.setenv("FOUNDRY_PROJECT_ENDPOINT", "https://myproject.azure.com/sub/123?key=secret") + monkeypatch.setenv( + "FOUNDRY_PROJECT_ENDPOINT", "https://myproject.azure.com/sub/123?key=secret" + ) monkeypatch.delenv("FOUNDRY_HOSTING_ENVIRONMENT", raising=False) monkeypatch.delenv("PORT", raising=False) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py index 5eefa9ac2a27..c5560d11e2d8 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py @@ -7,7 +7,11 @@ from opentelemetry import baggage as _otel_baggage, context as _otel_context from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult +from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, +) from opentelemetry.sdk.resources import Resource from azure.ai.agentserver.core import AgentServerHost @@ -34,6 +38,8 
@@ def shutdown(self): def force_flush(self, timeout_millis=30000): return True + + # ------------------------------------------------------------------ # # Tracing enabled / disabled # ------------------------------------------------------------------ # @@ -53,14 +59,24 @@ def test_observability_always_called(self) -> None: mock_configure.assert_called_once() def test_observability_receives_appinsights_env_var(self) -> None: - with mock.patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): + with mock.patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + ): mock_configure = mock.MagicMock() AgentServerHost(configure_observability=mock_configure) mock_configure.assert_called_once() - assert mock_configure.call_args[1]["connection_string"] == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + assert ( + mock_configure.call_args[1]["connection_string"] + == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) def test_observability_receives_otlp_env_var(self) -> None: - with mock.patch.dict(os.environ, {"OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"}): + with mock.patch.dict( + os.environ, {"OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"} + ): mock_configure = mock.MagicMock() AgentServerHost(configure_observability=mock_configure) mock_configure.assert_called_once() @@ -79,7 +95,12 @@ def test_observability_receives_constructor_connection_string(self) -> None: def test_observability_disabled_when_none(self) -> None: """Passing configure_observability=None disables all SDK-managed observability.""" - with mock.patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000"}): + with mock.patch.dict( + os.environ, + { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=00000000-0000-0000-0000-000000000000" + }, + 
): # Should not raise even with App Insights configured AgentServerHost(configure_observability=None) @@ -93,14 +114,19 @@ class TestAppInsightsConnectionString: """Tests for resolve_appinsights_connection_string().""" def test_explicit_wins(self) -> None: - assert resolve_appinsights_connection_string("InstrumentationKey=abc") == "InstrumentationKey=abc" + assert ( + resolve_appinsights_connection_string("InstrumentationKey=abc") + == "InstrumentationKey=abc" + ) def test_env_var(self) -> None: with mock.patch.dict( os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=env"}, ): - assert resolve_appinsights_connection_string(None) == "InstrumentationKey=env" + assert ( + resolve_appinsights_connection_string(None) == "InstrumentationKey=env" + ) def test_none_when_unset(self) -> None: env = os.environ.copy() @@ -113,7 +139,9 @@ def test_explicit_overrides_env_var(self) -> None: os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=env"}, ): - result = resolve_appinsights_connection_string("InstrumentationKey=explicit") + result = resolve_appinsights_connection_string( + "InstrumentationKey=explicit" + ) assert result == "InstrumentationKey=explicit" @@ -126,18 +154,29 @@ class TestSetupDistroExport: """Verify _configure_tracing calls the distro with the right args.""" def test_distro_called_when_conn_str_provided(self) -> None: - with mock.patch("azure.ai.agentserver.core._tracing._setup_distro_export") as mock_distro: + with mock.patch( + "azure.ai.agentserver.core._tracing._setup_distro_export" + ) as mock_distro: from azure.ai.agentserver.core import _tracing - _tracing._configure_tracing(connection_string="InstrumentationKey=00000000-0000-0000-0000-000000000000") + + _tracing._configure_tracing( + connection_string="InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) mock_distro.assert_called_once() kwargs = mock_distro.call_args[1] - assert kwargs["connection_string"] == 
"InstrumentationKey=00000000-0000-0000-0000-000000000000" + assert ( + kwargs["connection_string"] + == "InstrumentationKey=00000000-0000-0000-0000-000000000000" + ) assert len(kwargs["span_processors"]) >= 1 assert len(kwargs["log_record_processors"]) >= 1 def test_distro_called_without_conn_str(self) -> None: - with mock.patch("azure.ai.agentserver.core._tracing._setup_distro_export") as mock_distro: + with mock.patch( + "azure.ai.agentserver.core._tracing._setup_distro_export" + ) as mock_distro: from azure.ai.agentserver.core import _tracing + _tracing._configure_tracing(connection_string=None) mock_distro.assert_called_once() kwargs = mock_distro.call_args[1] @@ -189,8 +228,10 @@ def _create_provider(processor): def test_agent_attrs_present_on_exported_span(self) -> None: proc = _FoundryEnrichmentSpanProcessor( - agent_name="my-agent", agent_version="1.0", - agent_id="my-agent:1.0", project_id="proj-123", + agent_name="my-agent", + agent_version="1.0", + agent_id="my-agent:1.0", + project_id="proj-123", ) provider, collector = self._create_provider(proc) tracer = provider.get_tracer("test") @@ -207,8 +248,10 @@ def test_agent_attrs_present_on_exported_span(self) -> None: def test_agent_attrs_survive_framework_overwrite(self) -> None: """A framework setting agent attrs mid-span must not win.""" proc = _FoundryEnrichmentSpanProcessor( - agent_name="my-agent", agent_version="1.0", - agent_id="my-agent:1.0", project_id="proj-123", + agent_name="my-agent", + agent_version="1.0", + agent_id="my-agent:1.0", + project_id="proj-123", ) provider, collector = self._create_provider(proc) tracer = provider.get_tracer("test") @@ -223,8 +266,10 @@ def test_agent_attrs_survive_framework_overwrite(self) -> None: def test_none_fields_are_skipped(self) -> None: proc = _FoundryEnrichmentSpanProcessor( - agent_name=None, agent_version=None, - agent_id=None, project_id=None, + agent_name=None, + agent_version=None, + agent_id=None, + project_id=None, ) provider, collector = 
self._create_provider(proc) tracer = provider.get_tracer("test") @@ -241,7 +286,9 @@ def test_none_fields_are_skipped(self) -> None: def test_no_crash_when_span_lacks_attributes(self) -> None: """If the SDK changes internals, _on_ending must not raise.""" proc = _FoundryEnrichmentSpanProcessor( - agent_name="a", agent_version="1", agent_id="a:1", + agent_name="a", + agent_version="1", + agent_id="a:1", ) fake_span = object() # no _attributes at all proc._on_ending(fake_span) # should not raise @@ -255,7 +302,8 @@ def test_session_id_from_baggage(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) with tracer.start_as_current_span("span", context=ctx): pass @@ -271,7 +319,8 @@ def test_conversation_id_from_baggage(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-123", + "azure.ai.agentserver.conversation_id", + "conv-123", ) with tracer.start_as_current_span("span", context=ctx): pass @@ -287,10 +336,13 @@ def test_both_session_and_conversation_set_independently(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-123", context=ctx, + "azure.ai.agentserver.conversation_id", + "conv-123", + context=ctx, ) with tracer.start_as_current_span("span", context=ctx): pass @@ -319,10 +371,13 @@ def test_baggage_ids_propagate_to_child_spans(self) -> None: tracer = provider.get_tracer("test") ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.session_id", "session-456", + "azure.ai.agentserver.session_id", + "session-456", ) ctx = _otel_baggage.set_baggage( - "azure.ai.agentserver.conversation_id", "conv-789", context=ctx, + 
"azure.ai.agentserver.conversation_id", + "conv-789", + context=ctx, ) token = _otel_context.attach(ctx) try: @@ -414,6 +469,3 @@ def test_agent_version_default_empty(self) -> None: env.pop("FOUNDRY_AGENT_VERSION", None) with mock.patch.dict(os.environ, env, clear=True): assert resolve_agent_version() == "" - - - diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md index f173752e130d..02481d165530 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added +- **Durable invocation samples** — Added `durable_langgraph` and `durable_multiturn` sample applications demonstrating crash-resilient long-running agents using `@durable_task` with the invocations protocol. - Error source classification headers: All HTTP error responses now include `x-platform-error-source` with a value of `user`, `platform`, or `upstream` to indicate which component caused the error. Developer handler exceptions and missing handler registrations are classified as `upstream`. Exceptions tagged with the platform error tag are classified as `platform` and additionally include `x-platform-error-detail` with truncated exception details (max 2048 characters) for diagnostics. - WebSocket protocol support — `InvocationAgentServerHost` now hosts `/invocations_ws` alongside `POST /invocations`. Register the handler with the new `@app.ws_handler` decorator. The route is registered lazily on first decoration, so hosts without a registered handler return HTTP 404. - WebSocket Ping/Pong keep-alive — disabled by default; enable by setting the `WS_KEEPALIVE_INTERVAL` env var (auto-injected by AgentService into hosted-agent containers; surfaced on `app.config.ws_ping_interval` in `azure-ai-agentserver-core>=2.0.0b4`). `0` (or unset) disables keep-alive. 
Wired through to Hypercorn's `websocket_ping_interval` by `AgentServerHost._build_hypercorn_config`. @@ -17,6 +18,7 @@ ### Other Changes +- Bumped minimum `azure-ai-agentserver-core` dependency to `>=2.0.0b4`. - Platform header name constants (e.g. `x-platform-error-source`, `x-platform-error-detail`) are now imported from `azure-ai-agentserver-core` (`_platform_headers` module) instead of being defined locally. Error source classification helpers remain internal to this package. - Simplified request handling: baggage entries (`invocation_id`, `session_id`) are still set on each request, but span creation and lifecycle management are left to downstream frameworks. diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml index b70d8ea30022..2aa7e016dc39 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml @@ -68,6 +68,8 @@ mypy = true pyright = true verifytypes = false latestdependency = false +# azure-ai-agentserver-core>=2.0.0b4 is not yet on PyPI +mindependency = false pylint = true type_check_samples = false diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py index cde877039960..227bda4ca2f5 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/async_invoke_agent.py @@ -38,6 +38,7 @@ curl -X POST http://localhost:8088/invocations/abc-123/cancel # -> {"invocation_id": "abc-123", "status": "cancelled"} """ + import asyncio import json @@ -46,7 +47,6 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # In-memory state for demo purposes (see module docstring for production 
caveats) _tasks: dict[str, asyncio.Task] = {} _results: dict[str, bytes] = {} @@ -65,11 +65,13 @@ async def _do_work(invocation_id: str, data: dict) -> bytes: :rtype: bytes """ await asyncio.sleep(5) - result = json.dumps({ - "invocation_id": invocation_id, - "status": "completed", - "output": f"Processed: {data}", - }).encode() + result = json.dumps( + { + "invocation_id": invocation_id, + "status": "completed", + "output": f"Processed: {data}", + } + ).encode() _results[invocation_id] = result return result @@ -89,10 +91,12 @@ async def handle_invoke(request: Request) -> Response: task = asyncio.create_task(_do_work(invocation_id, data)) _tasks[invocation_id] = task - return JSONResponse({ - "invocation_id": invocation_id, - "status": "running", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "running", + } + ) @app.get_invocation_handler @@ -112,10 +116,12 @@ async def handle_get_invocation(request: Request) -> Response: if invocation_id in _tasks: task = _tasks[invocation_id] if not task.done(): - return JSONResponse({ - "invocation_id": invocation_id, - "status": "running", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "running", + } + ) result = task.result() _results[invocation_id] = result del _tasks[invocation_id] @@ -137,11 +143,13 @@ async def handle_cancel_invocation(request: Request) -> Response: # Already completed — cannot cancel if invocation_id in _results: - return JSONResponse({ - "invocation_id": invocation_id, - "status": "completed", - "error": "invocation already completed", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + } + ) if invocation_id in _tasks: task = _tasks[invocation_id] @@ -149,17 +157,21 @@ async def handle_cancel_invocation(request: Request) -> Response: # Task finished between check — treat as completed _results[invocation_id] = task.result() del _tasks[invocation_id] - return 
JSONResponse({ - "invocation_id": invocation_id, - "status": "completed", - "error": "invocation already completed", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + } + ) task.cancel() del _tasks[invocation_id] - return JSONResponse({ - "invocation_id": invocation_id, - "status": "cancelled", - }) + return JSONResponse( + { + "invocation_id": invocation_id, + "status": "cancelled", + } + ) return JSONResponse({"error": "not found"}, status_code=404) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/.state-change b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/.state-change new file mode 100644 index 000000000000..a69a36368c0b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/.state-change @@ -0,0 +1 @@ +2026-05-20T06:08:10Z \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/config.json b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/config.json new file mode 100644 index 000000000000..017e5d26dc28 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/config.json @@ -0,0 +1 @@ +{"version":1,"defaultEnvironment":"demo-dev"} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/.env b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/.env new file mode 100644 index 000000000000..aab98f6371e3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/.env @@ -0,0 +1,37 @@ +AGENT_DURABLE_RESEARCH_AGENT_ENDPOINT="https://e2e-tests-westus2-account.services.ai.azure.com/api/projects/e2e-tests-westus2/agents/durable-research-agent/versions/28" 
+AGENT_DURABLE_RESEARCH_AGENT_INVOCATIONS_ENDPOINT="https://e2e-tests-westus2-account.services.ai.azure.com/api/projects/e2e-tests-westus2/agents/durable-research-agent/endpoint/protocols/invocations?api-version=2025-11-15-preview" +AGENT_DURABLE_RESEARCH_AGENT_NAME="durable-research-agent" +AGENT_DURABLE_RESEARCH_AGENT_VERSION=28 +AI_PROJECT_CONNECTION_IDS_JSON="[]" +AI_PROJECT_DEPLOYMENTS="[{\\\"name\\\":\\\"gpt-4.1-mini\\\",\\\"model\\\":{\\\"name\\\":\\\"gpt-4.1-mini\\\",\\\"format\\\":\\\"OpenAI\\\",\\\"version\\\":\\\"2025-04-14\\\"},\\\"sku\\\":{\\\"name\\\":\\\"GlobalStandard\\\",\\\"capacity\\\":1053}}]" +APPLICATIONINSIGHTS_CONNECTION_NAME="appInsights-connection-7543" +APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=f25baa58-da74-4602-a955-ce257ff3a5d8;IngestionEndpoint=https://uksouth-1.in.applicationinsights.azure.com/;LiveEndpoint=https://uksouth.livediagnostics.monitor.azure.com/;ApplicationId=9b8190bd-1b0b-4264-89c3-e31ee47b0745" +APPLICATIONINSIGHTS_RESOURCE_ID="" +AZURE_AI_ACCOUNT_ID="/subscriptions/921496dc-987f-410f-bd57-426eb2611356/resourceGroups/agents-e2e-tests-westus2/providers/Microsoft.CognitiveServices/accounts/e2e-tests-westus2-account" +AZURE_AI_ACCOUNT_NAME="e2e-tests-westus2-account" +AZURE_AI_FOUNDRY_PROJECT_ID="/subscriptions/921496dc-987f-410f-bd57-426eb2611356/resourceGroups/agents-e2e-tests-westus2/providers/Microsoft.CognitiveServices/accounts/e2e-tests-westus2-account/projects/e2e-tests-westus2" +AZURE_AI_MODEL_DEPLOYMENT_NAME="gpt-4.1-mini" +AZURE_AI_PROJECT_ACR_CONNECTION_NAME="crdyt765he4tmsy" +AZURE_AI_PROJECT_ENDPOINT="https://e2e-tests-westus2-account.services.ai.azure.com/api/projects/e2e-tests-westus2" +AZURE_AI_PROJECT_ID="/subscriptions/921496dc-987f-410f-bd57-426eb2611356/resourceGroups/agents-e2e-tests-westus2/providers/Microsoft.CognitiveServices/accounts/e2e-tests-westus2-account/projects/e2e-tests-westus2" +AZURE_AI_PROJECT_NAME="e2e-tests-westus2" +AZURE_AI_SEARCH_CONNECTION_NAME="" 
+AZURE_AI_SEARCH_SERVICE_NAME="" +AZURE_CONTAINER_REGISTRY_ENDPOINT="crdyt765he4tmsy.azurecr.io" +AZURE_ENV_NAME="demo-dev" +AZURE_LOCATION="westus2" +AZURE_OPENAI_ENDPOINT="https://e2e-tests-westus2-account.openai.azure.com/" +AZURE_RESOURCE_GROUP="agents-e2e-tests-westus2" +AZURE_STORAGE_ACCOUNT_NAME="" +AZURE_STORAGE_CONNECTION_NAME="" +AZURE_SUBSCRIPTION_ID="921496dc-987f-410f-bd57-426eb2611356" +AZURE_TENANT_ID="72f988bf-86f1-41af-91ab-2d7cd011db47" +BING_CUSTOM_GROUNDING_CONNECTION_ID="" +BING_CUSTOM_GROUNDING_CONNECTION_NAME="" +BING_CUSTOM_GROUNDING_NAME="" +BING_GROUNDING_CONNECTION_ID="" +BING_GROUNDING_CONNECTION_NAME="" +BING_GROUNDING_RESOURCE_NAME="" +ENABLE_CAPABILITY_HOST="false" +ENABLE_HOSTED_AGENTS="true" +USE_EXISTING_AI_PROJECT="true" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/.env.lock b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/.env.lock new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/config.json b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/config.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.azure/demo-dev/config.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.gitignore b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.gitignore new file mode 100644 index 000000000000..1d7cd74ff8a5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/.gitignore @@ -0,0 +1,6 @@ +# azd environment +.azure/*/state/ +.azure/*/*.env.bak + +# Demo client runtime +.demo-session diff --git 
a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/README.md b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/README.md new file mode 100644 index 000000000000..8234b9cc139c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/README.md @@ -0,0 +1,294 @@ +# Durable Research Agent — Crash-Resilient Demo + +This sample demonstrates a **long-running research agent** that survives process +crashes and automatically resumes from its last checkpoint. It uses the +`@durable_task` decorator from `azure-ai-agentserver-core` to provide built-in +crash resilience without any manual state management. + +## What it showcases + +1. **12-stage deep research pipeline** — each stage is a distinct LLM call with real-time token streaming +2. **Crash resilience** — send `{"message": "crash"}` to kill the process; the supervisor + restarts it, and the task resumes from its last checkpoint +3. **Fire-and-forget POST** — `POST /invocations` dispatches the task and returns 202 immediately +4. **GET streaming with resume** — `GET /invocations/{id}?last_event_id=N` streams SSE, skipping already-seen events +5. **Cancel support** — `POST /invocations/{id}/cancel` stops the task gracefully +6. 
**File-backed streaming** — stream items persist to disk for replay after crashes + +## Architecture + +``` +┌────────────────────────────────────────────────────────────┐ +│ Hosted Agent Sandbox (port 8088) │ +│ │ +│ supervisor.py (PID 1 — always responds to /readiness) │ +│ └── python app.py (port 8089, restarted on crash) │ +│ │ +│ POST /invocations (fire-and-forget) │ +│ ├── {"message": "crash"} → 202, then exit 💥 │ +│ └── {"message": ""} → │ +│ deep_research.start() → 202 JSON response │ +│ { invocation_id, session_id, task_id, status } │ +│ │ +│ GET /invocations/{id}?last_event_id=N │ +│ └── Streams SSE from active task (skips first N events) │ +│ or replays from persisted file │ +│ │ +│ POST /invocations/{id}/cancel │ +│ └── Signals cancellation to running task │ +│ │ +│ Local disk: ~/.durable-tasks/ (persists across restarts) │ +└────────────────────────────────────────────────────────────┘ +``` + +## Prerequisites + +- Python 3.11+ +- Azure subscription with AI Foundry access +- [Azure Developer CLI (azd)](https://learn.microsoft.com/azure/developer/azure-developer-cli/install-azd) +- `azd` AI agents extension: `azd extension install azure.ai.agents` + +## Quick Start (Deploy to Foundry) + +```bash +# 1. Build wheels (included in Docker image) +./build.sh + +# 2. Login and deploy +azd auth login +azd up +``` + +## Demo Script — Crash Recovery & Reconnection + +This walkthrough demonstrates the full durability story. Total time: ~3 minutes. + +### Quick Demo (recommended) + +Use the included `demo-client.sh` which handles token refresh, session sharing, +auto-reconnection, and event resumption: + +```bash +# Terminal 1 — start research (auto-reconnects after crashes) +./demo-client.sh start "quantum computing" + +# Terminal 2 — crash the agent while it's running +./demo-client.sh crash + +# Watch Terminal 1 auto-reconnect and resume from where it left off! 
+# Crash again, as many times as you want: +./demo-client.sh crash + +# Terminal 3 — stream container logs (optional) +./demo-client.sh logs + +# Or cancel: +./demo-client.sh cancel + +# Reset session to start fresh: +./demo-client.sh reset +``` + +### How it works (client flow) + +1. **POST** `/invocations?agent_session_id=X` → returns 202 with `invocation_id` +2. **GET** `/invocations/{inv_id}` → streams SSE events (`id: N\ndata: {...}\n\n`) +3. Client tracks `last_event_id` (the `id:` field of the last received event) +4. On disconnect (crash): **POST** same session → new `invocation_id` → **GET** with `?last_event_id=N` +5. Server skips first N events → client sees only new content from the recovery point + +### Manual Demo (curl) + +```bash +# Get access token +TOKEN=$(az account get-access-token --resource https://ai.azure.com --query accessToken -o tsv) + +# Endpoint +ENDPOINT="https://e2e-tests-westus2-account.services.ai.azure.com/api/projects/e2e-tests-westus2/agents/durable-research-agent/endpoint/protocols" + +# Generate a unique session ID (reuse across all calls in this demo) +SESSION_ID="demo-$(uuidgen | tr '[:upper:]' '[:lower:]')" +echo "Session: $SESSION_ID" +``` + +### Step 1: Start the research task (fire-and-forget) + +```bash +# POST dispatches the task and returns immediately with IDs +curl -s -X POST "${ENDPOINT}/invocations?api-version=2025-11-15-preview&agent_session_id=${SESSION_ID}" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"message": "Research the history and future of quantum computing"}' +``` + +Response (202): +```json +{"status": "started", "invocation_id": "inv_abc123...", "session_id": "demo-..."} +``` + +Save the invocation ID: +```bash +INV_ID="inv_abc123..." 
# from response above +``` + +### Step 2: Stream results via GET + +```bash +curl -N -X GET "${ENDPOINT}/invocations/${INV_ID}?api-version=2025-11-15-preview" \ + -H "Authorization: Bearer $TOKEN" +``` + +You'll see SSE events with sequential IDs: +``` +id: 1 +data: {"type": "token", "content": "\n\n**[Stage 1/12]** Decomposing topic...\n"} + +id: 2 +data: {"type": "token", "content": "Quantum"} + +id: 3 +data: {"type": "token", "content": " computing"} +... +``` + +### Step 3: Crash the agent! 💥 + +While the research is running, send a crash trigger (same session): + +```bash +curl -s -X POST "${ENDPOINT}/invocations?api-version=2025-11-15-preview&agent_session_id=${SESSION_ID}" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"message": "crash"}' +``` + +Response (202): +```json +{"status": "crashing", "message": "💥 Process will crash now"} +``` + +The process exits. The supervisor immediately restarts it and recovers the task. + +### Step 4: Reconnect with resume + +Wait ~10 seconds, then POST again to get a new invocation ID, and GET with `last_event_id`: + +```bash +# Get new invocation ID (task is already in progress) +NEW_RESPONSE=$(curl -s -X POST "${ENDPOINT}/invocations?api-version=2025-11-15-preview&agent_session_id=${SESSION_ID}" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"message": "quantum computing"}') +NEW_INV_ID=$(echo "$NEW_RESPONSE" | python3 -c "import sys,json; print(json.loads(sys.stdin.read())['invocation_id'])") + +# Resume from where we left off (e.g., last_event_id=370) +curl -N -X GET "${ENDPOINT}/invocations/${NEW_INV_ID}?api-version=2025-11-15-preview&last_event_id=370" \ + -H "Authorization: Bearer $TOKEN" +``` + +You'll see only NEW events (stages after the crash): +``` +id: 371 +data: {"type": "token", "content": "\n\n⚡ **Recovered from crash!** Resuming from stage 5/12...\n\n"} + +id: 372 +data: {"type": "token", "content": "\n\n**[Stage 5/12]** 
Examining competing theories...\n"} +... +``` + +### Step 5: Cancel the task (optional) + +```bash +curl -X POST "${ENDPOINT}/invocations/${NEW_INV_ID}/cancel?api-version=2025-11-15-preview" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{}' +``` + +Response: +```json +{"status": "cancelled", "message": "Task cancellation requested."} +``` + +## Container Logs + +Stream real-time container logs (stdout/stderr) in a separate terminal: + +```bash +# Via demo-client.sh (uses session from .demo-session file) +./demo-client.sh logs + +# Or directly via azd: +azd ai agent monitor --session-id "$SESSION_ID" --follow + +# Recent logs (last 20 lines): +azd ai agent monitor --tail 20 + +# System events (container start/stop): +azd ai agent monitor --type system +``` + +## How it works + +The `@durable_task` decorator provides: + +- **Automatic persistence** — task state is checkpointed after each stage via + `ctx.metadata.flush()` +- **Crash recovery** — on startup, stale (in-flight) tasks are automatically + detected by lease owner and re-executed, with `ctx.metadata` containing all + previously saved progress +- **Entry mode awareness** — `ctx.entry_mode` tells the function why it was + called: `"fresh"`, `"resumed"`, or `"recovered"` +- **File-backed streaming** — stream items are persisted to disk via a custom + `FileStreamHandler` so GET can replay them after a crash +- **Event IDs** — each SSE event has a sequential `id:` field; clients use + `last_event_id` query param to skip already-seen events on reconnect + +Key code pattern: +```python +@durable_task(name="deep_research", stream_handler_factory=file_stream_factory) +async def deep_research(ctx: TaskContext[dict]) -> dict: + completed = ctx.metadata.get("completed_stages", 0) + + if ctx.entry_mode == "recovered": + await ctx.stream(json.dumps({"type": "token", "content": "⚡ Recovered!"})) + + for i in range(completed, len(STAGES)): + # Stream LLM tokens in real-time + async for event in
llm_stream: + await ctx.stream(json.dumps({"type": "token", "content": event.delta})) + + # CHECKPOINT — survives crashes + ctx.metadata["completed_stages"] = i + 1 + await ctx.metadata.flush() + + return final_result +``` + +## Environment Variables + +| Variable | Description | Default | +|---|---|---| +| `FOUNDRY_PROJECT_ENDPOINT` | AI Foundry project endpoint (set by platform) | Required | +| `AZURE_AI_MODEL_DEPLOYMENT_NAME` | Model deployment to use | `gpt-4.1-mini` | +| `FOUNDRY_TASK_API_ENABLED` | Use platform Task Storage (vs local file) | `0` (local) | +| `STAGE_DURATION` | Seconds between stages (for demo pacing) | `5` | + +## File Structure + +``` +durable-agent-demo/ +├── demo-client.sh # ⭐ Demo client (handles sessions, reconnect, crash) +├── azure.yaml # azd service config +├── build.sh # Build local wheels for Docker +├── infra/ # Bicep templates +├── src/durable-research-agent/ +│ ├── agent.py # ⭐ The durable task (12-stage research pipeline) +│ ├── app.py # HTTP handlers (POST fire-and-forget, GET stream, cancel) +│ ├── supervisor.py # PID 1 reverse proxy (keeps /readiness alive) +│ ├── agent.yaml # Agent definition for Foundry +│ ├── Dockerfile +│ ├── requirements.txt +│ └── wheels/ # Local package wheels (built by build.sh) +└── README.md +``` diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/azure.yaml b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/azure.yaml new file mode 100644 index 000000000000..f922f6bd48aa --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/azure.yaml @@ -0,0 +1,31 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/Azure/azure-dev/main/schemas/v1.0/azure.yaml.json + +requiredVersions: + extensions: + azure.ai.agents: '>=0.1.0-preview' +name: durable-research-agent-demo +services: + durable-research-agent: + project: src/durable-research-agent + host: azure.ai.agent + language: docker + 
docker: + remoteBuild: true + config: + container: + resources: + cpu: "1" + memory: 2Gi + deployments: + - model: + format: OpenAI + name: gpt-4.1-mini + version: "2025-04-14" + name: gpt-4.1-mini + sku: + capacity: 1053 + name: GlobalStandard + startupCommand: ./entrypoint.sh +infra: + provider: bicep + path: ./infra diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/build.sh b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/build.sh new file mode 100755 index 000000000000..65849baf1de2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/build.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Build local wheel packages for docker image. +# Run this BEFORE 'azd up' or 'docker build'. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../../../.." && pwd)" +WHEELS_DIR="$SCRIPT_DIR/src/durable-research-agent/wheels" + +echo "==> Building wheels from local agentserver packages..." +rm -rf "$WHEELS_DIR" +mkdir -p "$WHEELS_DIR" + +# Build core +echo " Building azure-ai-agentserver-core..." +pip wheel --no-deps --wheel-dir "$WHEELS_DIR" \ + "$REPO_ROOT/sdk/agentserver/azure-ai-agentserver-core" + +# Build invocations +echo " Building azure-ai-agentserver-invocations..." +pip wheel --no-deps --wheel-dir "$WHEELS_DIR" \ + "$REPO_ROOT/sdk/agentserver/azure-ai-agentserver-invocations" + +echo "==> Wheels built in $WHEELS_DIR:" +ls -la "$WHEELS_DIR"/*.whl + +echo "" +echo "Done! 
Now run: azd up (or docker build)" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/demo-client.sh b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/demo-client.sh new file mode 100755 index 000000000000..1abce030b998 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/demo-client.sh @@ -0,0 +1,411 @@ +#!/usr/bin/env bash +# ───────────────────────────────────────────────────────────────────────────── +# Durable Research Agent — Demo Client +# +# Usage: +# Terminal 1 (start research): ./demo-client.sh start "quantum computing" +# Terminal 2 (crash the agent): ./demo-client.sh crash +# Terminal 2 (cancel the task): ./demo-client.sh cancel +# +# The session ID is shared via a file (.demo-session) so both terminals +# operate on the same agent session. +# ───────────────────────────────────────────────────────────────────────────── + +set -euo pipefail + +# ── Config ──────────────────────────────────────────────────────────────────── + +ENDPOINT="${DEMO_AGENT_ENDPOINT:-https://e2e-tests-westus2-account.services.ai.azure.com/api/projects/e2e-tests-westus2/agents/durable-research-agent/endpoint/protocols}" # set DEMO_AGENT_ENDPOINT to target your own deployment +API_VERSION="2025-11-15-preview" +SESSION_FILE=".demo-session" + +# ── Colors ──────────────────────────────────────────────────────────────────── + +BOLD='\033[1m' +DIM='\033[2m' +GREEN='\033[32m' +YELLOW='\033[33m' +RED='\033[31m' +CYAN='\033[36m' +RESET='\033[0m' + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +get_token() { + az account get-access-token --resource https://ai.azure.com --query accessToken -o tsv 2>/dev/null || true # print nothing on failure instead of tripping set -e; ensure_token reports the empty token +} + +load_session() { + if [[ -f "$SESSION_FILE" ]]; then + source "$SESSION_FILE" + fi +} + +save_session() { + echo "SESSION_ID=\"${SESSION_ID}\"" > "$SESSION_FILE" + echo "INV_ID=\"${INV_ID}\"" >> "$SESSION_FILE" + echo "LAST_EVENT_ID=\"${LAST_EVENT_ID:-0}\"" >> "$SESSION_FILE" +} + +ensure_token() { + if [[ -z 
"${TOKEN:-}" ]]; then + echo -e "${DIM}Fetching access token...${RESET}" + TOKEN=$(get_token) + if [[ -z "$TOKEN" ]]; then + echo -e "${RED}ERROR: Failed to get token. Run 'az login' first.${RESET}" + exit 1 + fi + fi +} + +# ── SSE Stream Reader ───────────────────────────────────────────────────────── + +# Stream result: set by stream_sse to indicate how the stream ended +STREAM_RESULT="" # "complete", "crashed", "disconnected", "error" + +stream_sse() { + local url="$1" + local method="${2:-GET}" + local body="${3:-}" + local headers_file + headers_file=$(mktemp) + STREAM_RESULT="disconnected" # default: assume disconnect + + local curl_args=( + -sN + -X "$method" + -D "$headers_file" + -H "Authorization: Bearer $TOKEN" + -H "Content-Type: application/json" + -H "Accept: text/event-stream" + ) + if [[ -n "$body" ]]; then + curl_args+=(-d "$body") + fi + + # Stream and parse SSE events, writing result to a temp file + local result_file + result_file=$(mktemp) + echo "disconnected" > "$result_file" + + # Track event IDs in a temp file (subshell can't set parent vars) + local event_id_file + event_id_file=$(mktemp) + echo "${LAST_EVENT_ID:-0}" > "$event_id_file" + + # Use || true on the pipeline to prevent set -e/pipefail from killing the script + # when curl exits non-zero (e.g., connection reset by server crash) + # We also track a "current_id" to implement client-side skip of already-seen events + local skip_until + skip_until="${LAST_EVENT_ID:-0}" + + ( curl "${curl_args[@]}" "$url" || true ) | while IFS= read -r line; do + # Skip empty lines and comments + [[ -z "$line" || "$line" == $'\r' ]] && continue + [[ "$line" == :* ]] && continue + + # Parse "id: N" lines (SSE event ID for resumption) + if [[ "$line" == id:* ]]; then + local eid="${line#id: }" + eid="${eid%$'\r'}" + echo "$eid" > "$event_id_file" + continue + fi + + # Parse "data: {...}" lines + if [[ "$line" == data:* ]]; then + # Client-side skip: if current event ID ≤ last seen, suppress display + 
local current_eid + current_eid=$(cat "$event_id_file") + if [[ "$current_eid" -le "$skip_until" && "$skip_until" -gt 0 ]]; then + continue + fi + + local json="${line#data: }" + json="${json%$'\r'}" + + local type + type=$(echo "$json" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('type',''))" 2>/dev/null || echo "") + local display_content + display_content=$(echo "$json" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('content',''), end='')" 2>/dev/null || echo "") + + case "$type" in + token) + printf '%s' "$display_content" + ;; + done) + local full_text + full_text=$(echo "$json" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('full_text',''))" 2>/dev/null || echo "") + if [[ "$full_text" == *"crashing"* ]]; then + echo "crashed" > "$result_file" + else + echo "complete" > "$result_file" + fi + echo "" + break + ;; + error) + echo -e "\n${RED}ERROR: $display_content${RESET}" + echo "error" > "$result_file" + break + ;; + esac + else + # Non-SSE line — likely a JSON error response from the platform/server + echo -e "${DIM}[debug] ${line}${RESET}" >&2 + fi + done || true + + STREAM_RESULT=$(cat "$result_file") + LAST_EVENT_ID=$(cat "$event_id_file") + save_session + rm -f "$result_file" "$event_id_file" + + # Check HTTP status from response headers + if [[ -f "$headers_file" ]]; then + local http_status + http_status=$(head -1 "$headers_file" 2>/dev/null | tr -d '\r' || true) + if [[ -n "$http_status" && "$http_status" != *" 200 "* ]]; then + echo -e "${DIM}[debug] HTTP: ${http_status}${RESET}" >&2 + fi + rm -f "$headers_file" + fi +} + +# ── Commands ────────────────────────────────────────────────────────────────── + +dispatch_task() { + # POST to dispatch the task (fire-and-forget). Returns immediately. + # Captures invocation_id + session_id from the JSON response body. 
+ local topic="$1" + local url="${ENDPOINT}/invocations?api-version=${API_VERSION}&agent_session_id=${SESSION_ID}" + local body; body=$(python3 -c 'import json,sys; print(json.dumps({"message": sys.argv[1]}))' "$topic") # JSON-encode the topic so quotes, backslashes and newlines cannot break the request body + + local response + response=$(curl -s -X POST "$url" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d "$body") + + # Parse invocation_id and session_id from response JSON + local inv sess status + inv=$(echo "$response" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('invocation_id',''))" 2>/dev/null || echo "") + sess=$(echo "$response" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('session_id',''))" 2>/dev/null || echo "") + status=$(echo "$response" | python3 -c "import sys,json; d=json.loads(sys.stdin.read()); print(d.get('status',''))" 2>/dev/null || echo "") + + if [[ -n "$inv" ]]; then + INV_ID="$inv" + fi + if [[ -n "$sess" ]]; then + SESSION_ID="$sess" + fi + save_session + + echo -e "${DIM}Task ${status}: inv=${INV_ID:0:30}...${RESET}" +} + +stream_via_get() { + # GET /invocations/{inv_id} — streams SSE from the active task + if [[ -z "${INV_ID:-}" ]]; then + echo -e "${RED}No invocation ID. 
Cannot stream.${RESET}" + return 1 + fi + local url="${ENDPOINT}/invocations/${INV_ID}?api-version=${API_VERSION}" + # Append last_event_id query param for server-side skip on reconnect + if [[ -n "${LAST_EVENT_ID:-}" && "${LAST_EVENT_ID:-0}" != "0" ]]; then + url="${url}&last_event_id=${LAST_EVENT_ID}" + fi + stream_sse "$url" "GET" "" +} + +cmd_start() { + local topic="${1:-Research the history and future of quantum computing}" + + # Generate new session or reuse existing + if [[ -f "$SESSION_FILE" ]]; then + load_session + echo -e "${YELLOW}Reusing session: ${SESSION_ID}${RESET}" + else + SESSION_ID="demo-$(uuidgen | tr '[:upper:]' '[:lower:]')" + INV_ID="" + LAST_EVENT_ID="0" + save_session + echo -e "${GREEN}New session: ${SESSION_ID}${RESET}" + fi + + ensure_token + + echo -e "${BOLD}${CYAN}╔══════════════════════════════════════════════════════════╗${RESET}" + echo -e "${BOLD}${CYAN}║ Durable Research Agent — Starting ║${RESET}" + echo -e "${BOLD}${CYAN}╚══════════════════════════════════════════════════════════╝${RESET}" + echo -e "${DIM}Topic: ${topic}${RESET}" + echo -e "${DIM}Session: ${SESSION_ID}${RESET}" + echo "" + + # Step 1: POST dispatches task (fire-and-forget, returns immediately) + dispatch_task "$topic" + + # Step 2: GET streams SSE results + stream_via_get + + # Handle stream result + case "$STREAM_RESULT" in + complete) + echo -e "${GREEN}━━━ Research complete ━━━${RESET}" + ;; + crashed|disconnected) + echo -e "${YELLOW}━━━ Stream interrupted (${STREAM_RESULT}) ━━━${RESET}" + reconnect_loop "$topic" + ;; + error) + echo -e "${RED}━━━ Stream error ━━━${RESET}" + ;; + esac +} + +reconnect_loop() { + local topic="${1:-reconnect}" + + local attempt=0 + while true; do + attempt=$((attempt + 1)) + echo "" + echo -e "${YELLOW}⚡ Reconnecting (attempt ${attempt})...${RESET}" + echo -e "${DIM}Session: ${SESSION_ID}${RESET}" + sleep 5 + + ensure_token + + # POST again with same session — gets new invocation ID + # (platform preserves mapping because 
POST returned 202 immediately) + dispatch_task "$topic" + + # GET with new invocation ID to stream + stream_via_get + + case "$STREAM_RESULT" in + complete) + echo -e "${GREEN}━━━ Research complete ━━━${RESET}" + return 0 + ;; + crashed|disconnected) + echo -e "${YELLOW}━━━ Stream interrupted again (${STREAM_RESULT}). Retrying... ━━━${RESET}" + ;; + error) + echo -e "${RED}━━━ Error on reconnect. Retrying in 5s... ━━━${RESET}" + sleep 5 + ;; + esac + done +} + +cmd_crash() { + load_session + + if [[ -z "${SESSION_ID:-}" ]]; then + echo -e "${RED}No active session. Run './demo-client.sh start' first.${RESET}" + exit 1 + fi + + ensure_token + + echo -e "${RED}${BOLD}💥 Crashing the agent...${RESET}" + echo -e "${DIM}Session: ${SESSION_ID}${RESET}" + + # POST with "crash" message — server dispatches crash signal and returns 202 + local url="${ENDPOINT}/invocations?api-version=${API_VERSION}&agent_session_id=${SESSION_ID}" + local response + response=$(curl -s -X POST "$url" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"message": "crash"}') + + echo -e "${DIM}Response: ${response}${RESET}" + echo -e "\n${RED}Agent process killed (os._exit). Supervisor will restart it.${RESET}" + echo -e "${DIM}Terminal 1 will auto-reconnect when the process restarts.${RESET}" +} + +cmd_cancel() { + load_session + + if [[ -z "${INV_ID:-}" ]]; then + echo -e "${RED}No invocation ID. Run './demo-client.sh start' first.${RESET}" + exit 1 + fi + + ensure_token + + echo -e "${YELLOW}🛑 Cancelling task...${RESET}" + echo -e "${DIM}Invocation: ${INV_ID}${RESET}" + + local url="${ENDPOINT}/invocations/${INV_ID}/cancel?api-version=${API_VERSION}" + local response + response=$(curl -s -X POST "$url" \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{}') + + echo -e "${GREEN}${response}${RESET}" +} + +cmd_reset() { + rm -f "$SESSION_FILE" + echo -e "${GREEN}Session cleared. 
Next 'start' will create a fresh session.${RESET}" +} + +cmd_status() { + load_session + if [[ -f "$SESSION_FILE" ]]; then + echo -e "${CYAN}Session ID:${RESET} ${SESSION_ID:-}" + echo -e "${CYAN}Invocation ID:${RESET} ${INV_ID:-}" + else + echo -e "${DIM}No active session.${RESET}" + fi +} + +cmd_logs() { + load_session + if [[ -z "${SESSION_ID:-}" ]]; then + echo -e "${RED}No active session. Run './demo-client.sh start' first.${RESET}" + exit 1 + fi + + echo -e "${BOLD}${CYAN}╔══════════════════════════════════════════════════════════╗${RESET}" + echo -e "${BOLD}${CYAN}║ Container Logs — Streaming ║${RESET}" + echo -e "${BOLD}${CYAN}╚══════════════════════════════════════════════════════════╝${RESET}" + echo -e "${DIM}Session: ${SESSION_ID}${RESET}" + echo "" + + # Stream real-time container logs via azd ai agent monitor + azd ai agent monitor --session-id "${SESSION_ID}" --follow +} + +# ── Main ────────────────────────────────────────────────────────────────────── + +usage() { + echo -e "${BOLD}Durable Research Agent — Demo Client${RESET}" + echo "" + echo "Usage:" + echo " ./demo-client.sh start [topic] Start research (auto-reconnects on disconnect)" + echo " ./demo-client.sh crash Crash the agent (run from second terminal)" + echo " ./demo-client.sh cancel Cancel the running task" + echo " ./demo-client.sh logs Stream container logs via azd (run in third terminal)" + echo " ./demo-client.sh status Show current session info" + echo " ./demo-client.sh reset Clear session (start fresh)" + echo "" + echo "Demo workflow:" + echo " Terminal 1: ./demo-client.sh start \"quantum computing\"" + echo " Terminal 2: ./demo-client.sh crash" + echo " (Terminal 1 auto-reconnects and shows recovery)" + echo " Terminal 3: ./demo-client.sh logs (optional: watch container logs)" +} + +case "${1:-}" in + start) cmd_start "${2:-}" ;; + crash) cmd_crash ;; + cancel) cmd_cancel ;; + logs) cmd_logs ;; + status) cmd_status ;; + reset) cmd_reset ;; + *) usage ;; # missing or unknown command: show help +esac diff --git 
a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/abbreviations.json b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/abbreviations.json new file mode 100644 index 000000000000..879b2a9507b1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/abbreviations.json @@ -0,0 +1,137 @@ +{ + "aiFoundryAccounts": "aif", + "analysisServicesServers": "as", + "apiManagementService": "apim-", + "appConfigurationStores": "appcs-", + "appManagedEnvironments": "cae-", + "appContainerApps": "ca-", + "authorizationPolicyDefinitions": "policy-", + "automationAutomationAccounts": "aa-", + "blueprintBlueprints": "bp-", + "blueprintBlueprintsArtifacts": "bpa-", + "cacheRedis": "redis-", + "cdnProfiles": "cdnp-", + "cdnProfilesEndpoints": "cdne-", + "cognitiveServicesAccounts": "cog-", + "cognitiveServicesFormRecognizer": "cog-fr-", + "cognitiveServicesTextAnalytics": "cog-ta-", + "computeAvailabilitySets": "avail-", + "computeCloudServices": "cld-", + "computeDiskEncryptionSets": "des", + "computeDisks": "disk", + "computeDisksOs": "osdisk", + "computeGalleries": "gal", + "computeSnapshots": "snap-", + "computeVirtualMachines": "vm", + "computeVirtualMachineScaleSets": "vmss-", + "containerInstanceContainerGroups": "ci", + "containerRegistryRegistries": "cr", + "containerServiceManagedClusters": "aks-", + "databricksWorkspaces": "dbw-", + "dataFactoryFactories": "adf-", + "dataLakeAnalyticsAccounts": "dla", + "dataLakeStoreAccounts": "dls", + "dataMigrationServices": "dms-", + "dBforMySQLServers": "mysql-", + "dBforPostgreSQLServers": "psql-", + "devicesIotHubs": "iot-", + "devicesProvisioningServices": "provs-", + "devicesProvisioningServicesCertificates": "pcert-", + "documentDBDatabaseAccounts": "cosmos-", + "documentDBMongoDatabaseAccounts": "cosmon-", + "eventGridDomains": "evgd-", + "eventGridDomainsTopics": "evgt-", + "eventGridEventSubscriptions": "evgs-", + 
"eventHubNamespaces": "evhns-", + "eventHubNamespacesEventHubs": "evh-", + "hdInsightClustersHadoop": "hadoop-", + "hdInsightClustersHbase": "hbase-", + "hdInsightClustersKafka": "kafka-", + "hdInsightClustersMl": "mls-", + "hdInsightClustersSpark": "spark-", + "hdInsightClustersStorm": "storm-", + "hybridComputeMachines": "arcs-", + "insightsActionGroups": "ag-", + "insightsComponents": "appi-", + "keyVaultVaults": "kv-", + "kubernetesConnectedClusters": "arck", + "kustoClusters": "dec", + "kustoClustersDatabases": "dedb", + "logicIntegrationAccounts": "ia-", + "logicWorkflows": "logic-", + "machineLearningServicesWorkspaces": "mlw-", + "managedIdentityUserAssignedIdentities": "id-", + "managementManagementGroups": "mg-", + "migrateAssessmentProjects": "migr-", + "networkApplicationGateways": "agw-", + "networkApplicationSecurityGroups": "asg-", + "networkAzureFirewalls": "afw-", + "networkBastionHosts": "bas-", + "networkConnections": "con-", + "networkDnsZones": "dnsz-", + "networkExpressRouteCircuits": "erc-", + "networkFirewallPolicies": "afwp-", + "networkFirewallPoliciesWebApplication": "waf", + "networkFirewallPoliciesRuleGroups": "wafrg", + "networkFrontDoors": "fd-", + "networkFrontdoorWebApplicationFirewallPolicies": "fdfp-", + "networkLoadBalancersExternal": "lbe-", + "networkLoadBalancersInternal": "lbi-", + "networkLoadBalancersInboundNatRules": "rule-", + "networkLocalNetworkGateways": "lgw-", + "networkNatGateways": "ng-", + "networkNetworkInterfaces": "nic-", + "networkNetworkSecurityGroups": "nsg-", + "networkNetworkSecurityGroupsSecurityRules": "nsgsr-", + "networkNetworkWatchers": "nw-", + "networkPrivateDnsZones": "pdnsz-", + "networkPrivateLinkServices": "pl-", + "networkPublicIPAddresses": "pip-", + "networkPublicIPPrefixes": "ippre-", + "networkRouteFilters": "rf-", + "networkRouteTables": "rt-", + "networkRouteTablesRoutes": "udr-", + "networkTrafficManagerProfiles": "traf-", + "networkVirtualNetworkGateways": "vgw-", + 
"networkVirtualNetworks": "vnet-", + "networkVirtualNetworksSubnets": "snet-", + "networkVirtualNetworksVirtualNetworkPeerings": "peer-", + "networkVirtualWans": "vwan-", + "networkVpnGateways": "vpng-", + "networkVpnGatewaysVpnConnections": "vcn-", + "networkVpnGatewaysVpnSites": "vst-", + "notificationHubsNamespaces": "ntfns-", + "notificationHubsNamespacesNotificationHubs": "ntf-", + "operationalInsightsWorkspaces": "log-", + "portalDashboards": "dash-", + "powerBIDedicatedCapacities": "pbi-", + "purviewAccounts": "pview-", + "recoveryServicesVaults": "rsv-", + "resourcesResourceGroups": "rg-", + "searchSearchServices": "srch-", + "serviceBusNamespaces": "sb-", + "serviceBusNamespacesQueues": "sbq-", + "serviceBusNamespacesTopics": "sbt-", + "serviceEndPointPolicies": "se-", + "serviceFabricClusters": "sf-", + "signalRServiceSignalR": "sigr", + "sqlManagedInstances": "sqlmi-", + "sqlServers": "sql-", + "sqlServersDataWarehouse": "sqldw-", + "sqlServersDatabases": "sqldb-", + "sqlServersDatabasesStretch": "sqlstrdb-", + "storageStorageAccounts": "st", + "storageStorageAccountsVm": "stvm", + "storSimpleManagers": "ssimp", + "streamAnalyticsCluster": "asa-", + "synapseWorkspaces": "syn", + "synapseWorkspacesAnalyticsWorkspaces": "synw", + "synapseWorkspacesSqlPoolsDedicated": "syndp", + "synapseWorkspacesSqlPoolsSpark": "synsp", + "timeSeriesInsightsEnvironments": "tsi-", + "webServerFarms": "plan-", + "webSitesAppService": "app-", + "webSitesAppServiceEnvironment": "ase-", + "webSitesFunctions": "func-", + "webStaticSites": "stapp-" +} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/acr-role-assignment.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/acr-role-assignment.bicep new file mode 100644 index 000000000000..3e0c2b218be7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/acr-role-assignment.bicep 
@@ -0,0 +1,27 @@ +targetScope = 'resourceGroup' + +@description('Name of the existing container registry') +param acrName string + +@description('Principal ID to grant AcrPull role') +param principalId string + +@description('Full resource ID of the ACR (for generating unique GUID)') +param acrResourceId string + +// Reference the existing ACR in this resource group +resource acr 'Microsoft.ContainerRegistry/registries@2023-07-01' existing = { + name: acrName +} + +// Grant AcrPull role to the AI project's managed identity +resource acrPullRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { + scope: acr + name: guid(acrResourceId, principalId, '7f951dda-4ed3-4680-a7ca-43fe172d538d') + properties: { + principalId: principalId + principalType: 'ServicePrincipal' + // AcrPull role + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '7f951dda-4ed3-4680-a7ca-43fe172d538d') + } +} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/ai-project.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/ai-project.bicep new file mode 100644 index 000000000000..662b53c001c8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/ai-project.bicep @@ -0,0 +1,413 @@ +targetScope = 'resourceGroup' + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Main location for the resources') +param location string + +var resourceToken = uniqueString(subscription().id, resourceGroup().id, location) + +@description('Name of the project') +param aiFoundryProjectName string + +param deployments deploymentsType + +@description('Id of the user or app to assign application roles') +param principalId string + +@description('Principal type of user or app') +param principalType string + +@description('Optional. 
Name of an existing AI Services account in the current resource group. If not provided, a new one will be created.') +param existingAiAccountName string = '' + +@description('List of connections to provision') +param connections array = [] + +@secure() +@description('Map of connection name to credentials object. Kept as @secure to prevent secrets from appearing in deployment logs. Example: { "my-conn": { "key": "secret" } }') +param connectionCredentials object = {} + +@description('Also provision dependent resources and connect to the project') +param additionalDependentResources dependentResourcesType + +@description('Enable monitoring via appinsights and log analytics') +param enableMonitoring bool = true + +@description('Enable hosted agent deployment') +param enableHostedAgents bool = false + +@description('Enable the capability host for agent conversations. When false and hosted agents are enabled, the capability host is not created (v2 hosted agents handle storage automatically).') +param enableCapabilityHost bool = true + +@description('Optional. Existing container registry resource ID. If provided, a connection will be created to this ACR instead of creating a new one.') +param existingContainerRegistryResourceId string = '' + +@description('Optional. Existing container registry login server (e.g., myregistry.azurecr.io). Required if existingContainerRegistryResourceId is provided.') +param existingContainerRegistryEndpoint string = '' + +@description('Optional. Name of an existing ACR connection on the Foundry project. If provided, no new ACR or connection will be created.') +param existingAcrConnectionName string = '' + +@description('Optional. Existing Application Insights connection string. If provided, a connection will be created but no new App Insights resource.') +param existingApplicationInsightsConnectionString string = '' + +@description('Optional. Existing Application Insights resource ID. 
Used for connection metadata when providing an existing App Insights.') +param existingApplicationInsightsResourceId string = '' + +@description('Optional. Name of an existing Application Insights connection on the Foundry project. If provided, no new App Insights or connection will be created.') +param existingAppInsightsConnectionName string = '' + +// Load abbreviations +var abbrs = loadJsonContent('../../abbreviations.json') + +// Determine which resources to create based on connections +var hasStorageConnection = length(filter(additionalDependentResources, conn => conn.resource == 'storage')) > 0 +var hasAcrConnection = length(filter(additionalDependentResources, conn => conn.resource == 'registry')) > 0 +var hasExistingAcr = !empty(existingContainerRegistryResourceId) +var hasExistingAcrConnection = !empty(existingAcrConnectionName) +var hasExistingAppInsightsConnection = !empty(existingAppInsightsConnectionName) +var hasExistingAppInsightsConnectionString = !empty(existingApplicationInsightsConnectionString) +// Only create new App Insights resources if monitoring enabled and no existing connection/connection string +var shouldCreateAppInsights = enableMonitoring && !hasExistingAppInsightsConnection && !hasExistingAppInsightsConnectionString +var hasSearchConnection = length(filter(additionalDependentResources, conn => conn.resource == 'azure_ai_search')) > 0 +var hasBingConnection = length(filter(additionalDependentResources, conn => conn.resource == 'bing_grounding')) > 0 +var hasBingCustomConnection = length(filter(additionalDependentResources, conn => conn.resource == 'bing_custom_grounding')) > 0 + +// Extract connection names from ai.yaml for each resource type +var storageConnectionName = hasStorageConnection ? filter(additionalDependentResources, conn => conn.resource == 'storage')[0].connectionName : '' +var acrConnectionName = hasAcrConnection ? 
filter(additionalDependentResources, conn => conn.resource == 'registry')[0].connectionName : '' +var searchConnectionName = hasSearchConnection ? filter(additionalDependentResources, conn => conn.resource == 'azure_ai_search')[0].connectionName : '' +var bingConnectionName = hasBingConnection ? filter(additionalDependentResources, conn => conn.resource == 'bing_grounding')[0].connectionName : '' +var bingCustomConnectionName = hasBingCustomConnection ? filter(additionalDependentResources, conn => conn.resource == 'bing_custom_grounding')[0].connectionName : '' + +// Enable monitoring via Log Analytics and Application Insights +module logAnalytics '../monitor/loganalytics.bicep' = if (shouldCreateAppInsights) { + name: 'logAnalytics' + params: { + location: location + tags: tags + name: 'logs-${resourceToken}' + } +} + +module applicationInsights '../monitor/applicationinsights.bicep' = if (shouldCreateAppInsights) { + name: 'applicationInsights' + params: { + location: location + tags: tags + name: 'appi-${resourceToken}' + logAnalyticsWorkspaceId: logAnalytics.outputs.id + projectMIPrincipalId: aiAccount::project.identity.principalId + } +} + +// Always create a new AI Account for now (simplified approach) +// TODO: Add support for existing accounts in a future version +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-06-01' = { + name: !empty(existingAiAccountName) ? existingAiAccountName : 'ai-account-${resourceToken}' + location: location + tags: tags + sku: { + name: 'S0' + } + kind: 'AIServices' + identity: { + type: 'SystemAssigned' + } + properties: { + allowProjectManagement: true + customSubDomainName: !empty(existingAiAccountName) ? 
existingAiAccountName : 'ai-account-${resourceToken}' + networkAcls: { + defaultAction: 'Allow' + virtualNetworkRules: [] + ipRules: [] + } + publicNetworkAccess: 'Enabled' + disableLocalAuth: true + } + + @batchSize(1) + resource seqDeployments 'deployments' = [ + for dep in (deployments??[]): { + name: dep.name + properties: { + model: dep.model + } + sku: dep.sku + } + ] + + resource project 'projects' = { + name: aiFoundryProjectName + location: location + identity: { + type: 'SystemAssigned' + } + properties: { + description: '${aiFoundryProjectName} Project' + displayName: '${aiFoundryProjectName}Project' + } + dependsOn: [ + seqDeployments + ] + } + + resource aiFoundryAccountCapabilityHost 'capabilityHosts@2025-10-01-preview' = if (enableHostedAgents && enableCapabilityHost) { + name: 'agents' + properties: { + capabilityHostKind: 'Agents' + // IMPORTANT: this is required to enable hosted agents deployment + // if no BYO Net is provided + enablePublicHostingEnvironment: true + } + } +} + + +// Create connection towards appinsights: +// - when we create a new App Insights resource, OR +// - when the user provided an existing App Insights connection string + resource ID but no existing connection name +// Both cases are merged into a single resource to avoid duplicate ARM resource definitions (which fail deployment). +var shouldCreateExistingAppInsightsConnection = enableMonitoring && hasExistingAppInsightsConnectionString && !hasExistingAppInsightsConnection && !empty(existingApplicationInsightsResourceId) +var shouldCreateAppInsightsConnection = shouldCreateAppInsights || shouldCreateExistingAppInsightsConnection + +resource appInsightConnection 'Microsoft.CognitiveServices/accounts/projects/connections@2025-04-01-preview' = if (shouldCreateAppInsightsConnection) { + parent: aiAccount::project + name: 'appi-${resourceToken}' + properties: { + category: 'AppInsights' + target: shouldCreateAppInsights ? 
applicationInsights.outputs.id : existingApplicationInsightsResourceId + authType: 'ApiKey' + isSharedToAll: true + credentials: { + key: shouldCreateAppInsights ? applicationInsights.outputs.connectionString : existingApplicationInsightsConnectionString + } + metadata: { + ApiType: 'Azure' + ResourceId: shouldCreateAppInsights ? applicationInsights.outputs.id : existingApplicationInsightsResourceId + } + } +} + +// Create additional connections from ai.yaml configuration +module aiConnections './connection.bicep' = [for (connection, index) in connections: { + name: 'connection-${connection.name}' + params: { + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + connectionConfig: connection + credentials: connectionCredentials[?connection.name] ?? {} + } +}] + +// Azure AI User for the developer, scoped to the Foundry Project. +// Project scope is sufficient for creating/running agents and calling models via the project endpoint. +resource localUserAzureAIUserRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { + scope: aiAccount::project + name: guid(subscription().id, resourceGroup().id, principalId, '53ca6127-db72-4b80-b1b0-d745d6d5456d') + properties: { + principalId: principalId + principalType: principalType + roleDefinitionId: resourceId('Microsoft.Authorization/roleDefinitions', '53ca6127-db72-4b80-b1b0-d745d6d5456d') + } +} + + +// All connections are now created directly within their respective resource modules +// using the centralized ./connection.bicep module + +// Storage module - deploy if storage connection is defined in ai.yaml +module storage '../storage/storage.bicep' = if (hasStorageConnection) { + name: 'storage' + params: { + location: location + tags: tags + resourceName: 'st${resourceToken}' + connectionName: storageConnectionName + principalId: principalId + principalType: principalType + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + } +} + +// Azure 
Container Registry module - deploy if ACR connection is defined in ai.yaml +module acr '../host/acr.bicep' = if (hasAcrConnection) { + name: 'acr' + params: { + location: location + tags: tags + resourceName: '${abbrs.containerRegistryRegistries}${resourceToken}' + connectionName: acrConnectionName + principalId: principalId + principalType: principalType + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + } +} + +// Connection for existing ACR - create if user provided an existing ACR resource ID but no existing connection +module existingAcrConnection './connection.bicep' = if (hasExistingAcr && !hasExistingAcrConnection) { + name: 'existing-acr-connection' + params: { + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + connectionConfig: { + name: 'acr-${resourceToken}' + category: 'ContainerRegistry' + target: existingContainerRegistryEndpoint + authType: 'ManagedIdentity' + isSharedToAll: true + metadata: { + ResourceId: existingContainerRegistryResourceId + } + } + credentials: { + clientId: aiAccount::project.identity.principalId + resourceId: existingContainerRegistryResourceId + } + } +} + +// Extract resource group name from the existing ACR resource ID +// Resource ID format: /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ContainerRegistry/registries/{name} +var existingAcrResourceGroup = hasExistingAcr ? split(existingContainerRegistryResourceId, '/')[4] : '' +var existingAcrName = hasExistingAcr ? 
last(split(existingContainerRegistryResourceId, '/')) : '' + +// Grant AcrPull role to the AI project's managed identity on the existing ACR +// This allows the hosted agents to pull images from the user-provided registry +// Note: User must have permission to assign roles on the existing ACR (Owner or User Access Administrator) +// Using a module allows scoping to a different resource group if the ACR isn't in the same RG +// Skip if connection already exists (role assignment should already be in place) +module existingAcrRoleAssignment './acr-role-assignment.bicep' = if (hasExistingAcr && !hasExistingAcrConnection) { + name: 'existing-acr-role-assignment' + scope: resourceGroup(existingAcrResourceGroup) + params: { + acrName: existingAcrName + acrResourceId: existingContainerRegistryResourceId + principalId: aiAccount::project.identity.principalId + } +} + +// Bing Search grounding module - deploy if Bing connection is defined in ai.yaml or parameter is enabled +module bingGrounding '../search/bing_grounding.bicep' = if (hasBingConnection) { + name: 'bing-grounding' + params: { + tags: tags + resourceName: 'bing-${resourceToken}' + connectionName: bingConnectionName + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + } +} + +// Bing Custom Search grounding module - deploy if custom Bing connection is defined in ai.yaml or parameter is enabled +module bingCustomGrounding '../search/bing_custom_grounding.bicep' = if (hasBingCustomConnection) { + name: 'bing-custom-grounding' + params: { + tags: tags + resourceName: 'bingcustom-${resourceToken}' + connectionName: bingCustomConnectionName + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + } +} + +// Azure AI Search module - deploy if search connection is defined in ai.yaml +module azureAiSearch '../search/azure_ai_search.bicep' = if (hasSearchConnection) { + name: 'azure-ai-search' + params: { + tags: tags + resourceName: 'search-${resourceToken}' + 
connectionName: searchConnectionName + storageAccountResourceId: hasStorageConnection ? storage!.outputs.storageAccountId : '' + containerName: 'knowledge' + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + principalId: principalId + principalType: principalType + location: location + } +} + +// Outputs +output AZURE_AI_PROJECT_ENDPOINT string = aiAccount::project.properties.endpoints['AI Foundry API'] +output AZURE_OPENAI_ENDPOINT string = aiAccount.properties.endpoints['OpenAI Language Model Instance API'] +output aiServicesEndpoint string = aiAccount.properties.endpoint +output accountId string = aiAccount.id +output projectId string = aiAccount::project.id +output aiServicesAccountName string = aiAccount.name +output aiServicesProjectName string = aiAccount::project.name +output aiServicesPrincipalId string = aiAccount.identity.principalId +output projectName string = aiAccount::project.name +output APPLICATIONINSIGHTS_CONNECTION_STRING string = shouldCreateAppInsights ? applicationInsights.outputs.connectionString : (hasExistingAppInsightsConnectionString ? existingApplicationInsightsConnectionString : '') +output APPLICATIONINSIGHTS_RESOURCE_ID string = shouldCreateAppInsights ? applicationInsights.outputs.id : (hasExistingAppInsightsConnectionString ? existingApplicationInsightsResourceId : '') + +// Connection outputs from the connections array +output connectionIds array = [for (connection, index) in (connections ?? []): { + name: aiConnections[index].outputs.connectionName + id: aiConnections[index].outputs.connectionId +}] + +// Grouped dependent resources outputs +output dependentResources object = { + registry: { + name: hasAcrConnection ? acr!.outputs.containerRegistryName : '' + loginServer: hasAcrConnection ? acr!.outputs.containerRegistryLoginServer : ((hasExistingAcr || hasExistingAcrConnection) ? existingContainerRegistryEndpoint : '') + connectionName: hasAcrConnection ? 
acr!.outputs.containerRegistryConnectionName : (hasExistingAcrConnection ? existingAcrConnectionName : (hasExistingAcr ? 'acr-${resourceToken}' : '')) + } + bing_grounding: { + name: (hasBingConnection) ? bingGrounding!.outputs.bingGroundingName : '' + connectionName: (hasBingConnection) ? bingGrounding!.outputs.bingGroundingConnectionName : '' + connectionId: (hasBingConnection) ? bingGrounding!.outputs.bingGroundingConnectionId : '' + } + bing_custom_grounding: { + name: (hasBingCustomConnection) ? bingCustomGrounding!.outputs.bingCustomGroundingName : '' + connectionName: (hasBingCustomConnection) ? bingCustomGrounding!.outputs.bingCustomGroundingConnectionName : '' + connectionId: (hasBingCustomConnection) ? bingCustomGrounding!.outputs.bingCustomGroundingConnectionId : '' + } + search: { + serviceName: hasSearchConnection ? azureAiSearch!.outputs.searchServiceName : '' + connectionName: hasSearchConnection ? azureAiSearch!.outputs.searchConnectionName : '' + } + storage: { + accountName: hasStorageConnection ? storage!.outputs.storageAccountName : '' + connectionName: hasStorageConnection ? storage!.outputs.storageConnectionName : '' + } +} + +type deploymentsType = { + @description('Specify the name of cognitive service account deployment.') + name: string + + @description('Required. Properties of Cognitive Services account deployment model.') + model: { + @description('Required. The name of Cognitive Services account deployment model.') + name: string + + @description('Required. The format of Cognitive Services account deployment model.') + format: string + + @description('Required. The version of Cognitive Services account deployment model.') + version: string + } + + @description('The resource model definition representing SKU.') + sku: { + @description('Required. The name of the resource model definition representing SKU.') + name: string + + @description('The capacity of the resource model definition representing SKU.') + capacity: int + } +}[]? 
+ +type dependentResourcesType = { + @description('The type of dependent resource to create') + resource: 'storage' | 'registry' | 'azure_ai_search' | 'bing_grounding' | 'bing_custom_grounding' + + @description('The connection name for this resource') + connectionName: string +}[] diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/connection.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/connection.bicep new file mode 100644 index 000000000000..a08726645243 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/connection.bicep @@ -0,0 +1,112 @@ +targetScope = 'resourceGroup' + +@description('AI Services account name') +param aiServicesAccountName string + +@description('AI project name') +param aiProjectName string + +// Connection configuration type definition +type ConnectionConfig = { + @description('Name of the connection') + name: string + + @description('Category of the connection (e.g., ContainerRegistry, AzureStorageAccount, CognitiveSearch, AzureOpenAI)') + category: string + + @description('Target endpoint or URL for the connection') + target: string + + @description('Authentication type') + authType: 'AAD' | 'AccessKey' | 'AccountKey' | 'AgenticIdentity' | 'ApiKey' | 'CustomKeys' | 'ManagedIdentity' | 'None' | 'OAuth2' | 'PAT' | 'SAS' | 'ServicePrincipal' | 'UsernamePassword' | 'UserEntraToken' | 'ProjectManagedIdentity' + + @description('Whether the connection is shared to all users (optional, defaults to true)') + isSharedToAll: bool? + + @description('Additional metadata for the connection (optional)') + metadata: object? + + @description('Error message if the connection fails (optional)') + error: string? + + @description('Expiry time for the connection (optional)') + expiryTime: string? 
+ + @description('Private endpoint requirement: Required, NotRequired, or NotApplicable (optional)') + peRequirement: ('NotApplicable' | 'NotRequired' | 'Required')? + + @description('Private endpoint status: Active, Inactive, or NotApplicable (optional)') + peStatus: ('Active' | 'Inactive' | 'NotApplicable')? + + @description('List of users to share the connection with (optional, alternative to isSharedToAll)') + sharedUserList: string[]? + + @description('Whether to use workspace managed identity (optional)') + useWorkspaceManagedIdentity: bool? + + @description('OAuth2 authorization endpoint URL (optional, OAuth2 authType only)') + authorizationUrl: string? + + @description('OAuth2 token endpoint URL (optional, OAuth2 authType only)') + tokenUrl: string? + + @description('OAuth2 refresh token endpoint URL (optional, OAuth2 authType only)') + refreshUrl: string? + + @description('OAuth2 scopes to request (optional, OAuth2 authType only)') + scopes: string[]? + + @description('Token audience for UserEntraToken / AgenticIdentity auth types (optional)') + audience: string? + + @description('Managed connector name for OAuth2 managed connectors (optional)') + connectorName: string? +} + +@description('Connection configuration') +param connectionConfig ConnectionConfig + +@secure() +@description('Credentials for the connection. Kept as a separate @secure parameter to prevent secrets from appearing in deployment logs. Shape depends on authType — e.g. { key: "..." } for ApiKey, { clientId: "...", clientSecret: "..." 
} for OAuth2/ServicePrincipal.') +param credentials object = {} + + +// Get reference to the AI Services account and project +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = { + name: aiServicesAccountName + + resource project 'projects' existing = { + name: aiProjectName + } +} + +// Create the connection +resource connection 'Microsoft.CognitiveServices/accounts/projects/connections@2025-04-01-preview' = { + parent: aiAccount::project + name: connectionConfig.name + properties: { + category: connectionConfig.category + target: connectionConfig.target + authType: connectionConfig.authType + isSharedToAll: connectionConfig.?isSharedToAll ?? true + credentials: !empty(credentials) ? credentials : null + metadata: connectionConfig.?metadata + // Only include if they appear in the connectionConfig + ...connectionConfig.?error != null ? { error: connectionConfig.?error } : {} + ...connectionConfig.?expiryTime != null ? { expiryTime: connectionConfig.?expiryTime } : {} + ...connectionConfig.?peRequirement != null ? { peRequirement: connectionConfig.?peRequirement } : {} + ...connectionConfig.?peStatus != null ? { peStatus: connectionConfig.?peStatus } : {} + ...connectionConfig.?sharedUserList != null ? { sharedUserList: connectionConfig.?sharedUserList } : {} + ...connectionConfig.?useWorkspaceManagedIdentity != null ? { useWorkspaceManagedIdentity: connectionConfig.?useWorkspaceManagedIdentity } : {} + ...connectionConfig.?authorizationUrl != null ? { authorizationUrl: connectionConfig.?authorizationUrl } : {} + ...connectionConfig.?tokenUrl != null ? { tokenUrl: connectionConfig.?tokenUrl } : {} + ...connectionConfig.?refreshUrl != null ? { refreshUrl: connectionConfig.?refreshUrl } : {} + ...connectionConfig.?scopes != null ? { scopes: connectionConfig.?scopes } : {} + ...connectionConfig.?audience != null ? { audience: connectionConfig.?audience } : {} + ...connectionConfig.?connectorName != null ? 
{ connectorName: connectionConfig.?connectorName } : {} + } +} + +// Outputs +output connectionName string = connection.name +output connectionId string = connection.id diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/existing-ai-project.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/existing-ai-project.bicep new file mode 100644 index 000000000000..fea2782fdfa5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/ai/existing-ai-project.bicep @@ -0,0 +1,96 @@ +targetScope = 'resourceGroup' + +@description('Name of the existing AI Services account') +param aiServicesAccountName string + +@description('Name of the existing AI Foundry project') +param aiFoundryProjectName string + +@description('Existing ACR connection name (already set in the environment)') +param existingAcrConnectionName string = '' + +@description('Existing container registry endpoint (already set in the environment)') +param existingContainerRegistryEndpoint string = '' + +@description('Existing Application Insights connection string (already set in the environment)') +param existingApplicationInsightsConnectionString string = '' + +@description('Existing Application Insights resource ID (already set in the environment)') +param existingApplicationInsightsResourceId string = '' + +@description('List of connections to provision on the existing project') +param connections array = [] + +@secure() +@description('Map of connection name to credentials object. Kept as @secure to prevent secrets from appearing in deployment logs. Example: { "my-conn": { "key": "secret" } }') +param connectionCredentials object = {} + +// Reference the existing account and project — read-only except for the +// additional connections provisioned below from the agent manifest. 
+resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-06-01' existing = { + name: aiServicesAccountName + + resource project 'projects' existing = { + name: aiFoundryProjectName + } +} + +// Create additional connections from ai.yaml / agent manifest configuration on +// the existing project. Mirrors the loop in ai-project.bicep so manifest-declared +// connections are provisioned regardless of whether the project itself is new or +// pre-existing. +module aiConnections './connection.bicep' = [for (connection, index) in connections: { + name: 'existing-connection-${connection.name}' + params: { + aiServicesAccountName: aiAccount.name + aiProjectName: aiAccount::project.name + connectionConfig: connection + credentials: connectionCredentials[?connection.name] ?? {} + } +}] + +// Outputs — same shape as ai-project.bicep so main.bicep can use either interchangeably +output AZURE_AI_PROJECT_ENDPOINT string = aiAccount::project.properties.endpoints['AI Foundry API'] +output AZURE_OPENAI_ENDPOINT string = aiAccount.properties.endpoints['OpenAI Language Model Instance API'] +output aiServicesEndpoint string = aiAccount.properties.endpoint +output accountId string = aiAccount.id +output projectId string = aiAccount::project.id +output aiServicesAccountName string = aiAccount.name +output aiServicesProjectName string = aiAccount::project.name +output aiServicesPrincipalId string = aiAccount.identity.principalId +output projectName string = aiAccount::project.name +output APPLICATIONINSIGHTS_CONNECTION_STRING string = existingApplicationInsightsConnectionString +output APPLICATIONINSIGHTS_RESOURCE_ID string = existingApplicationInsightsResourceId + +// Connection outputs from the connections array (provisioned above): one { name, id } +// entry per connection module, or an empty array when no connections were requested. +output connectionIds array = [for (connection, index) in (connections ??
[]): { + name: aiConnections[index].outputs.connectionName + id: aiConnections[index].outputs.connectionId +}] + +output dependentResources object = { + registry: { + name: '' + loginServer: existingContainerRegistryEndpoint + connectionName: existingAcrConnectionName + } + bing_grounding: { + name: '' + connectionName: '' + connectionId: '' + } + bing_custom_grounding: { + name: '' + connectionName: '' + connectionId: '' + } + search: { + serviceName: '' + connectionName: '' + } + storage: { + accountName: '' + connectionName: '' + } +} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/host/acr.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/host/acr.bicep new file mode 100644 index 000000000000..f1893d8ff312 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/host/acr.bicep @@ -0,0 +1,88 @@ +targetScope = 'resourceGroup' + +@description('The location used for all deployed resources') +param location string = resourceGroup().location + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Resource name for the container registry') +param resourceName string + +@description('Id of the user or app to assign application roles') +param principalId string + +@description('Principal type of user or app') +param principalType string + +@description('AI Services account name for the project parent') +param aiServicesAccountName string = '' + +@description('AI project name for creating the connection') +param aiProjectName string = '' + +@description('Name for the AI Foundry ACR connection') +param connectionName string + +// Get reference to the AI Services account and project to access their managed identities +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 
aiServicesAccountName + + resource aiProject 'projects' existing = { + name: aiProjectName + } +} + +// Create the Container Registry +// The project's AcrPull assignment is only added when both an AI Services account name and a +// project name were supplied; otherwise the conditional `aiAccount` reference does not resolve +// and reading its identity would fail the deployment. +module containerRegistry 'br/public:avm/res/container-registry/registry:0.1.1' = { + name: 'registry' + params: { + name: resourceName + location: location + tags: tags + publicNetworkAccess: 'Enabled' + roleAssignments: concat([ + { + principalId: principalId + principalType: principalType + // Container Registry Tasks Contributor — build images with ACR tasks and push container images + roleDefinitionIdOrName: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'fb382eab-e894-4461-af04-94435c366c3f') + } + ], (!empty(aiServicesAccountName) && !empty(aiProjectName)) ? [ + // TODO SEPARATELY + { + // the foundry project itself can pull from the ACR (AcrPull) + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + roleDefinitionIdOrName: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '7f951dda-4ed3-4680-a7ca-43fe172d538d') + } + ] : []) + } +} + +// Create the ACR connection using the centralized connection module +module acrConnection '../ai/connection.bicep' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 'acr-connection-creation' + params: { + aiServicesAccountName: aiServicesAccountName + aiProjectName: aiProjectName + connectionConfig: { + name: connectionName + category: 'ContainerRegistry' + target: containerRegistry.outputs.loginServer + authType: 'ManagedIdentity' + isSharedToAll: true + metadata: { + ResourceId: containerRegistry.outputs.resourceId + } + } + credentials: { + clientId: aiAccount::aiProject.identity.principalId + resourceId: containerRegistry.outputs.resourceId + } + } +} + +output containerRegistryName string = containerRegistry.outputs.name +output containerRegistryLoginServer string = containerRegistry.outputs.loginServer +output containerRegistryResourceId string = containerRegistry.outputs.resourceId +// `acrConnection` is conditional, so its output is only read when the module was deployed; +// callers receive an empty connection name otherwise. +output containerRegistryConnectionName string = (!empty(aiServicesAccountName) && !empty(aiProjectName)) ? acrConnection!.outputs.connectionName : '' diff --git
a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights-dashboard.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights-dashboard.bicep new file mode 100644 index 000000000000..d082e668ed9f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights-dashboard.bicep @@ -0,0 +1,1236 @@ +metadata description = 'Creates a dashboard for an Application Insights instance.' +param name string +param applicationInsightsName string +param location string = resourceGroup().location +param tags object = {} + +// 2020-09-01-preview because that is the latest valid version +resource applicationInsightsDashboard 'Microsoft.Portal/dashboards@2020-09-01-preview' = { + name: name + location: location + tags: tags + properties: { + lenses: [ + { + order: 0 + parts: [ + { + position: { + x: 0 + y: 0 + colSpan: 2 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'id' + value: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + { + name: 'Version' + value: '1.0' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/AspNetOverviewPinnedPart' + asset: { + idInputName: 'id' + type: 'ApplicationInsights' + } + defaultMenuItemId: 'overview' + } + } + { + position: { + x: 2 + y: 0 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'Version' + value: '1.0' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/ProactiveDetectionAsyncPart' + asset: { + idInputName: 'ComponentId' + type: 'ApplicationInsights' + } + defaultMenuItemId: 
'ProactiveDetection' + } + } + { + position: { + x: 3 + y: 0 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'ResourceId' + value: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/QuickPulseButtonSmallPart' + asset: { + idInputName: 'ComponentId' + type: 'ApplicationInsights' + } + } + } + { + position: { + x: 4 + y: 0 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'TimeContext' + value: { + durationMs: 86400000 + endTime: null + createdTime: '2018-05-04T01:20:33.345Z' + isInitialTime: true + grain: 1 + useDashboardTimeRange: false + } + } + { + name: 'Version' + value: '1.0' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/AvailabilityNavButtonPart' + asset: { + idInputName: 'ComponentId' + type: 'ApplicationInsights' + } + } + } + { + position: { + x: 5 + y: 0 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'TimeContext' + value: { + durationMs: 86400000 + endTime: null + createdTime: '2018-05-08T18:47:35.237Z' + isInitialTime: true + grain: 1 + useDashboardTimeRange: false + } + } + { + name: 'ConfigurationId' + value: '78ce933e-e864-4b05-a27b-71fd55a6afad' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/AppMapButtonPart' + asset: { + idInputName: 
'ComponentId' + type: 'ApplicationInsights' + } + } + } + { + position: { + x: 0 + y: 1 + colSpan: 3 + rowSpan: 1 + } + metadata: { + inputs: [] + type: 'Extension/HubsExtension/PartType/MarkdownPart' + settings: { + content: { + settings: { + content: '# Usage' + title: '' + subtitle: '' + } + } + } + } + } + { + position: { + x: 3 + y: 1 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'TimeContext' + value: { + durationMs: 86400000 + endTime: null + createdTime: '2018-05-04T01:22:35.782Z' + isInitialTime: true + grain: 1 + useDashboardTimeRange: false + } + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/UsageUsersOverviewPart' + asset: { + idInputName: 'ComponentId' + type: 'ApplicationInsights' + } + } + } + { + position: { + x: 4 + y: 1 + colSpan: 3 + rowSpan: 1 + } + metadata: { + inputs: [] + type: 'Extension/HubsExtension/PartType/MarkdownPart' + settings: { + content: { + settings: { + content: '# Reliability' + title: '' + subtitle: '' + } + } + } + } + } + { + position: { + x: 7 + y: 1 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ResourceId' + value: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + { + name: 'DataModel' + value: { + version: '1.0.0' + timeContext: { + durationMs: 86400000 + createdTime: '2018-05-04T23:42:40.072Z' + isInitialTime: false + grain: 1 + useDashboardTimeRange: false + } + } + isOptional: true + } + { + name: 'ConfigurationId' + value: '8a02f7bf-ac0f-40e1-afe9-f0e72cfee77f' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/CuratedBladeFailuresPinnedPart' + isAdapter: true + asset: { + idInputName: 'ResourceId' + type: 
'ApplicationInsights' + } + defaultMenuItemId: 'failures' + } + } + { + position: { + x: 8 + y: 1 + colSpan: 3 + rowSpan: 1 + } + metadata: { + inputs: [] + type: 'Extension/HubsExtension/PartType/MarkdownPart' + settings: { + content: { + settings: { + content: '# Responsiveness\r\n' + title: '' + subtitle: '' + } + } + } + } + } + { + position: { + x: 11 + y: 1 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ResourceId' + value: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + { + name: 'DataModel' + value: { + version: '1.0.0' + timeContext: { + durationMs: 86400000 + createdTime: '2018-05-04T23:43:37.804Z' + isInitialTime: false + grain: 1 + useDashboardTimeRange: false + } + } + isOptional: true + } + { + name: 'ConfigurationId' + value: '2a8ede4f-2bee-4b9c-aed9-2db0e8a01865' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/CuratedBladePerformancePinnedPart' + isAdapter: true + asset: { + idInputName: 'ResourceId' + type: 'ApplicationInsights' + } + defaultMenuItemId: 'performance' + } + } + { + position: { + x: 12 + y: 1 + colSpan: 3 + rowSpan: 1 + } + metadata: { + inputs: [] + type: 'Extension/HubsExtension/PartType/MarkdownPart' + settings: { + content: { + settings: { + content: '# Browser' + title: '' + subtitle: '' + } + } + } + } + } + { + position: { + x: 15 + y: 1 + colSpan: 1 + rowSpan: 1 + } + metadata: { + inputs: [ + { + name: 'ComponentId' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'MetricsExplorerJsonDefinitionId' + value: 'BrowserPerformanceTimelineMetrics' + } + { + name: 'TimeContext' + value: { + durationMs: 86400000 + createdTime: '2018-05-08T12:16:27.534Z' + isInitialTime: false + grain: 1 + useDashboardTimeRange: false + } + } + { + name: 
'CurrentFilter' + value: { + eventTypes: [ + 4 + 1 + 3 + 5 + 2 + 6 + 13 + ] + typeFacets: {} + isPermissive: false + } + } + { + name: 'id' + value: { + Name: applicationInsights.name + SubscriptionId: subscription().subscriptionId + ResourceGroup: resourceGroup().name + } + } + { + name: 'Version' + value: '1.0' + } + ] + #disable-next-line BCP036 + type: 'Extension/AppInsightsExtension/PartType/MetricsExplorerBladePinnedPart' + asset: { + idInputName: 'ComponentId' + type: 'ApplicationInsights' + } + defaultMenuItemId: 'browser' + } + } + { + position: { + x: 0 + y: 2 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'sessions/count' + aggregationType: 5 + namespace: 'microsoft.insights/components/kusto' + metricVisualization: { + displayName: 'Sessions' + color: '#47BDF5' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'users/count' + aggregationType: 5 + namespace: 'microsoft.insights/components/kusto' + metricVisualization: { + displayName: 'Users' + color: '#7E58FF' + } + } + ] + title: 'Unique sessions and users' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + openBladeOnClick: { + openBlade: true + destinationBlade: { + extensionName: 'HubsExtension' + bladeName: 'ResourceMenuBlade' + parameters: { + id: 
'/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + menuid: 'segmentationUsers' + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 4 + y: 2 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'requests/failed' + aggregationType: 7 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Failed requests' + color: '#EC008C' + } + } + ] + title: 'Failed requests' + visualization: { + chartType: 3 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + openBladeOnClick: { + openBlade: true + destinationBlade: { + extensionName: 'HubsExtension' + bladeName: 'ResourceMenuBlade' + parameters: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + menuid: 'failures' + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 8 + y: 2 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: 
'/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'requests/duration' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Server response time' + color: '#00BCF2' + } + } + ] + title: 'Server response time' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + openBladeOnClick: { + openBlade: true + destinationBlade: { + extensionName: 'HubsExtension' + bladeName: 'ResourceMenuBlade' + parameters: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + menuid: 'performance' + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 12 + y: 2 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'browserTimings/networkDuration' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Page load network connect time' + color: '#7E58FF' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'browserTimings/processingDuration' + aggregationType: 4 + namespace: 'microsoft.insights/components' + 
metricVisualization: { + displayName: 'Client processing time' + color: '#44F1C8' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'browserTimings/sendDuration' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Send request time' + color: '#EB9371' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'browserTimings/receiveDuration' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Receiving response time' + color: '#0672F1' + } + } + ] + title: 'Average page load time breakdown' + visualization: { + chartType: 3 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 0 + y: 5 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'availabilityResults/availabilityPercentage' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Availability' + color: '#47BDF5' + } + } + ] + title: 'Average availability' + visualization: { + chartType: 3 + legendVisualization: { + isVisible: true + position: 2 + 
hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + openBladeOnClick: { + openBlade: true + destinationBlade: { + extensionName: 'HubsExtension' + bladeName: 'ResourceMenuBlade' + parameters: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + menuid: 'availability' + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 4 + y: 5 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'exceptions/server' + aggregationType: 7 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Server exceptions' + color: '#47BDF5' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'dependencies/failed' + aggregationType: 7 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Dependency failures' + color: '#7E58FF' + } + } + ] + title: 'Server exceptions and Dependency failures' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 
'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 8 + y: 5 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'performanceCounters/processorCpuPercentage' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Processor time' + color: '#47BDF5' + } + } + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'performanceCounters/processCpuPercentage' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Process CPU' + color: '#7E58FF' + } + } + ] + title: 'Average processor and process CPU utilization' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 12 + y: 5 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'exceptions/browser' + aggregationType: 7 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Browser exceptions' + color: '#47BDF5' + } 
+ } + ] + title: 'Browser exceptions' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 0 + y: 8 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'availabilityResults/count' + aggregationType: 7 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Availability test results count' + color: '#47BDF5' + } + } + ] + title: 'Availability test results count' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 4 + y: 8 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'performanceCounters/processIOBytesPerSecond' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Process IO rate' + color: '#47BDF5' + } + } + 
] + title: 'Average process I/O rate' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + { + position: { + x: 8 + y: 8 + colSpan: 4 + rowSpan: 3 + } + metadata: { + inputs: [ + { + name: 'options' + value: { + chart: { + metrics: [ + { + resourceMetadata: { + id: '/subscriptions/${subscription().subscriptionId}/resourceGroups/${resourceGroup().name}/providers/Microsoft.Insights/components/${applicationInsights.name}' + } + name: 'performanceCounters/memoryAvailableBytes' + aggregationType: 4 + namespace: 'microsoft.insights/components' + metricVisualization: { + displayName: 'Available memory' + color: '#47BDF5' + } + } + ] + title: 'Average available memory' + visualization: { + chartType: 2 + legendVisualization: { + isVisible: true + position: 2 + hideSubtitle: false + } + axisVisualization: { + x: { + isVisible: true + axisType: 2 + } + y: { + isVisible: true + axisType: 1 + } + } + } + } + } + } + { + name: 'sharedTimeRange' + isOptional: true + } + ] + #disable-next-line BCP036 + type: 'Extension/HubsExtension/PartType/MonitorChartPart' + settings: {} + } + } + ] + } + ] + } +} + +resource applicationInsights 'Microsoft.Insights/components@2020-02-02' existing = { + name: applicationInsightsName +} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights.bicep new file mode 100644 index 000000000000..73240d1b1c9a --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/applicationinsights.bicep @@ -0,0 +1,47 @@ +metadata description = 'Creates an Application Insights instance based on an existing Log Analytics workspace.' +param name string +param dashboardName string = '' +param location string = resourceGroup().location +param tags object = {} +param logAnalyticsWorkspaceId string + +@description('Optional. Principal ID of the Foundry Project managed identity to grant Log Analytics Reader.') +param projectMIPrincipalId string = '' + +resource applicationInsights 'Microsoft.Insights/components@2020-02-02' = { + name: name + location: location + tags: tags + kind: 'web' + properties: { + Application_Type: 'web' + WorkspaceResourceId: logAnalyticsWorkspaceId + } +} + +module applicationInsightsDashboard 'applicationinsights-dashboard.bicep' = if (!empty(dashboardName)) { + name: 'application-insights-dashboard' + params: { + name: dashboardName + location: location + applicationInsightsName: applicationInsights.name + } +} + +// Log Analytics Reader for the Foundry Project managed identity. +// Required for running evaluations on traces generated by agents. 
+resource logAnalyticsReaderRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(projectMIPrincipalId)) { + scope: applicationInsights + name: guid(applicationInsights.id, projectMIPrincipalId, '73c42c96-874c-492b-b04d-ab87d138a893') + properties: { + principalId: projectMIPrincipalId + principalType: 'ServicePrincipal' + // Log Analytics Reader + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '73c42c96-874c-492b-b04d-ab87d138a893') + } +} + +output connectionString string = applicationInsights.properties.ConnectionString +output id string = applicationInsights.id +output instrumentationKey string = applicationInsights.properties.InstrumentationKey +output name string = applicationInsights.name diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/loganalytics.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/loganalytics.bicep new file mode 100644 index 000000000000..33f9dc29443a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/monitor/loganalytics.bicep @@ -0,0 +1,22 @@ +metadata description = 'Creates a Log Analytics workspace.' 
+param name string +param location string = resourceGroup().location +param tags object = {} + +resource logAnalytics 'Microsoft.OperationalInsights/workspaces@2021-12-01-preview' = { + name: name + location: location + tags: tags + properties: any({ + retentionInDays: 30 + features: { + searchVersion: 1 + } + sku: { + name: 'PerGB2018' + } + }) +} + +output id string = logAnalytics.id +output name string = logAnalytics.name diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/azure_ai_search.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/azure_ai_search.bicep new file mode 100644 index 000000000000..7bb8e6350025 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/azure_ai_search.bicep @@ -0,0 +1,211 @@ +targetScope = 'resourceGroup' + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Azure Search resource name') +param resourceName string + +@description('Azure Search SKU name') +param azureSearchSkuName string = 'basic' + +@description('Azure storage account resource ID') +param storageAccountResourceId string + +@description('container name') +param containerName string = 'knowledgebase' + +@description('AI Services account name for the project parent') +param aiServicesAccountName string = '' + +@description('AI project name for creating the connection') +param aiProjectName string = '' + +@description('Id of the user or app to assign application roles') +param principalId string + +@description('Principal type of user or app') +param principalType string + +@description('Name for the AI Foundry search connection') +param connectionName string + +@description('Location for all resources') +param location string = resourceGroup().location + +// Get reference to the AI Services account and project to access their managed identities +resource 
aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: aiServicesAccountName + + resource aiProject 'projects' existing = { + name: aiProjectName + } +} + +// Azure Search Service +resource searchService 'Microsoft.Search/searchServices@2024-06-01-preview' = { + name: resourceName + location: location + tags: tags + sku: { + name: azureSearchSkuName + } + identity: { + type: 'SystemAssigned' + } + properties: { + replicaCount: 1 + partitionCount: 1 + hostingMode: 'default' + authOptions: { + aadOrApiKey: { + aadAuthFailureMode: 'http401WithBearerChallenge' + } + } + disableLocalAuth: false + encryptionWithCmk: { + enforcement: 'Unspecified' + } + publicNetworkAccess: 'enabled' + } +} + +// Reference to existing Storage Account +resource storageAccount 'Microsoft.Storage/storageAccounts@2023-05-01' existing = { + name: last(split(storageAccountResourceId, '/')) +} + +// Reference to existing Blob Service +resource blobService 'Microsoft.Storage/storageAccounts/blobServices@2023-05-01' existing = { + parent: storageAccount + name: 'default' +} + +// Storage Container (create if it doesn't exist) +resource storageContainer 'Microsoft.Storage/storageAccounts/blobServices/containers@2023-05-01' = { + parent: blobService + name: containerName + properties: { + publicAccess: 'None' + } +} + +// RBAC Assignments + +// Search needs to read from Storage +resource searchToStorageRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { + name: guid(storageAccount.id, searchService.id, 'Storage Blob Data Reader', uniqueString(deployment().name)) + scope: storageAccount + properties: { + // GOOD + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '2a2b9908-6ea1-4ae2-8e65-a410df84e7d1') // Storage Blob Data Reader + principalId: searchService.identity.principalId + principalType: 'ServicePrincipal' + } +} + +// Search needs OpenAI access 
(AI Services account) +resource searchToAIServicesRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName)) { + name: guid(aiServicesAccountName, searchService.id, 'Cognitive Services OpenAI User', uniqueString(deployment().name)) + properties: { + // GOOD + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '5e0bd9bd-7b93-4f28-af87-19fc36ad61bd') // Cognitive Services OpenAI User + principalId: searchService.identity.principalId + principalType: 'ServicePrincipal' + } +} + +// AI Project needs Search access - Service Contributor +resource aiServicesToSearchServiceRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: guid(searchService.id, aiServicesAccountName, aiProjectName, 'Search Service Contributor', uniqueString(deployment().name)) + scope: searchService + properties: { + // GOOD + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '7ca78c08-252a-4471-8644-bb5ff32d4ba0') // Search Service Contributor + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + } +} + +// AI Project needs Search access - Index Data Contributor +resource aiServicesToSearchDataRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: guid(searchService.id, aiServicesAccountName, aiProjectName, 'Search Index Data Contributor', uniqueString(deployment().name)) + scope: searchService + properties: { + // GOOD + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '8ebe5a00-799e-43f5-93ac-243d3dce84a7') // Search Index Data Contributor + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + } +} + +// User permissions - Search Index Data Contributor +resource userToSearchRoleAssignment 
'Microsoft.Authorization/roleAssignments@2022-04-01' = { + name: guid(searchService.id, principalId, 'Search Index Data Contributor', uniqueString(deployment().name)) + scope: searchService + properties: { + // GOOD + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '8ebe5a00-799e-43f5-93ac-243d3dce84a7') // Search Index Data Contributor + principalId: principalId + principalType: principalType + } +} + +// // User permissions - Storage Blob Data Contributor +// resource userToStorageRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { +// name: guid(storageAccount.id, principalId, 'Storage Blob Data Contributor', uniqueString(deployment().name)) +// scope: storageAccount +// properties: { +// roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'ba92f5b4-2d11-453d-a403-e96b0029c9fe') // Storage Blob Data Contributor +// principalId: principalId +// principalType: principalType +// } +// } + +// // Project needs Search access - Index Data Contributor +// resource projectToSearchRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { +// name: guid(searchService.id, aiProjectName, 'Search Index Data Contributor', uniqueString(deployment().name)) +// scope: searchService +// properties: { +// roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', '8ebe5a00-799e-43f5-93ac-243d3dce84a7') // Search Index Data Contributor +// principalId: aiAccountPrincipalId // Using AI account principal ID as project identity +// principalType: 'ServicePrincipal' +// } +// } + +// Create the AI Search connection using the centralized connection module +module aiSearchConnection '../ai/connection.bicep' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 'ai-search-connection-creation' + params: { + aiServicesAccountName: aiServicesAccountName + aiProjectName: aiProjectName + connectionConfig: { + name: connectionName + category: 
'CognitiveSearch' + target: 'https://${searchService.name}.search.windows.net' + authType: 'AAD' + isSharedToAll: true + metadata: { + ApiVersion: '2024-07-01' + ResourceId: searchService.id + ApiType: 'Azure' + type: 'azure_ai_search' + } + } + } + dependsOn: [ + aiServicesToSearchDataRoleAssignment + ] +} + +// Outputs +output searchServiceName string = searchService.name +output searchServiceId string = searchService.id +output searchServicePrincipalId string = searchService.identity.principalId +output storageAccountName string = storageAccount.name +output storageAccountId string = storageAccount.id +output containerName string = storageContainer.name +output storageAccountPrincipalId string = storageAccount.identity.principalId +output searchConnectionName string = (!empty(aiServicesAccountName) && !empty(aiProjectName)) ? aiSearchConnection!.outputs.connectionName : '' +output searchConnectionId string = (!empty(aiServicesAccountName) && !empty(aiProjectName)) ? aiSearchConnection!.outputs.connectionId : '' + diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_custom_grounding.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_custom_grounding.bicep new file mode 100644 index 000000000000..1fddea079e2e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_custom_grounding.bicep @@ -0,0 +1,84 @@ +targetScope = 'resourceGroup' + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Bing custom grounding resource name') +param resourceName string + +@description('AI Services account name for the project parent') +param aiServicesAccountName string = '' + +@description('AI project name for creating the connection') +param aiProjectName string = '' + +@description('Name for the AI Foundry Bing Custom Search connection') +param 
connectionName string + +// Get reference to the AI Services account and project to access their managed identities +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: aiServicesAccountName + + resource aiProject 'projects' existing = { + name: aiProjectName + } +} + +// Bing Search resource for grounding capability +resource bingCustomSearch 'Microsoft.Bing/accounts@2020-06-10' = { + name: resourceName + location: 'global' + tags: tags + sku: { + name: 'G1' + } + properties: { + statisticsEnabled: false + } + kind: 'Bing.CustomGrounding' +} + +// Role assignment to allow AI project to use Bing Search +resource bingCustomSearchRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + scope: bingCustomSearch + name: guid(subscription().id, resourceGroup().id, 'bing-search-role', aiServicesAccountName, aiProjectName) + properties: { + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + roleDefinitionId: resourceId('Microsoft.Authorization/roleDefinitions', 'a97b65f3-24c7-4388-baec-2e87135dc908') // Cognitive Services User + } +} + +// Create the Bing Custom Search connection using the centralized connection module +module aiSearchConnection '../ai/connection.bicep' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 'bing-custom-search-connection-creation' + params: { + aiServicesAccountName: aiServicesAccountName + aiProjectName: aiProjectName + connectionConfig: { + name: connectionName + category: 'GroundingWithCustomSearch' + target: bingCustomSearch.properties.endpoint + authType: 'ApiKey' + isSharedToAll: true + metadata: { + Location: 'global' + ResourceId: bingCustomSearch.id + ApiType: 'Azure' + type: 'bing_custom_search' + } + } + credentials: { + key: bingCustomSearch.listKeys().key1 + } + } + dependsOn: [ + 
bingCustomSearchRoleAssignment + ] +} + +// Outputs +output bingCustomGroundingName string = bingCustomSearch.name +output bingCustomGroundingConnectionName string = aiSearchConnection.outputs.connectionName +output bingCustomGroundingResourceId string = bingCustomSearch.id +output bingCustomGroundingConnectionId string = aiSearchConnection.outputs.connectionId diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_grounding.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_grounding.bicep new file mode 100644 index 000000000000..20ea5e9f160a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/search/bing_grounding.bicep @@ -0,0 +1,83 @@ +targetScope = 'resourceGroup' + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Bing grounding resource name') +param resourceName string + +@description('AI Services account name for the project parent') +param aiServicesAccountName string = '' + +@description('AI project name for creating the connection') +param aiProjectName string = '' + +@description('Name for the AI Foundry Bing Search connection') +param connectionName string + +// Get reference to the AI Services account and project to access their managed identities +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: aiServicesAccountName + + resource aiProject 'projects' existing = { + name: aiProjectName + } +} + +// Bing Search resource for grounding capability +resource bingSearch 'Microsoft.Bing/accounts@2020-06-10' = { + name: resourceName + location: 'global' + tags: tags + sku: { + name: 'G1' + } + properties: { + statisticsEnabled: false + } + kind: 'Bing.Grounding' +} + +// Role assignment to allow AI project to use Bing Search +resource 
bingSearchRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + scope: bingSearch + name: guid(subscription().id, resourceGroup().id, 'bing-search-role', aiServicesAccountName, aiProjectName) + properties: { + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + roleDefinitionId: resourceId('Microsoft.Authorization/roleDefinitions', 'a97b65f3-24c7-4388-baec-2e87135dc908') // Cognitive Services User + } +} + +// Create the Bing Search connection using the centralized connection module +module bingSearchConnection '../ai/connection.bicep' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 'bing-search-connection-creation' + params: { + aiServicesAccountName: aiServicesAccountName + aiProjectName: aiProjectName + connectionConfig: { + name: connectionName + category: 'GroundingWithBingSearch' + target: bingSearch.properties.endpoint + authType: 'ApiKey' + isSharedToAll: true + metadata: { + Location: 'global' + ResourceId: bingSearch.id + ApiType: 'Azure' + type: 'bing_grounding' + } + } + credentials: { + key: bingSearch.listKeys().key1 + } + } + dependsOn: [ + bingSearchRoleAssignment + ] +} + +output bingGroundingName string = bingSearch.name +output bingGroundingConnectionName string = bingSearchConnection.outputs.connectionName +output bingGroundingResourceId string = bingSearch.id +output bingGroundingConnectionId string = bingSearchConnection.outputs.connectionId diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/storage/storage.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/storage/storage.bicep new file mode 100644 index 000000000000..18d9535dcd0b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/core/storage/storage.bicep @@ -0,0 +1,113 @@ +targetScope = 'resourceGroup' + 
+@description('The location used for all deployed resources') +param location string = resourceGroup().location + +@description('Tags that will be applied to all resources') +param tags object = {} + +@description('Storage account resource name') +param resourceName string + +@description('Id of the user or app to assign application roles') +param principalId string + +@description('Principal type of user or app') +param principalType string + +@description('AI Services account name for the project parent') +param aiServicesAccountName string = '' + +@description('AI project name for creating the connection') +param aiProjectName string = '' + +@description('Name for the AI Foundry storage connection') +param connectionName string + +// Storage Account for the AI Services account +resource storageAccount 'Microsoft.Storage/storageAccounts@2023-05-01' = { + name: resourceName + location: location + tags: tags + sku: { + name: 'Standard_LRS' + } + kind: 'StorageV2' + identity: { + type: 'SystemAssigned' + } + properties: { + supportsHttpsTrafficOnly: true + allowBlobPublicAccess: false + minimumTlsVersion: 'TLS1_2' + accessTier: 'Hot' + encryption: { + services: { + blob: { + enabled: true + } + file: { + enabled: true + } + } + keySource: 'Microsoft.Storage' + } + } +} + +// Get reference to the AI Services account and project to access their managed identities +resource aiAccount 'Microsoft.CognitiveServices/accounts@2025-04-01-preview' existing = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: aiServicesAccountName + + resource aiProject 'projects' existing = { + name: aiProjectName + } +} + +// Role assignment for AI Services to access the storage account +resource storageRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: guid(storageAccount.id, aiAccount.id, 'ai-storage-contributor') + scope: storageAccount + properties: { + roleDefinitionId: 
subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'ba92f5b4-2d11-453d-a403-e96b0029c9fe') // Storage Blob Data Contributor + principalId: aiAccount::aiProject.identity.principalId + principalType: 'ServicePrincipal' + } +} + +// User permissions - Storage Blob Data Contributor +resource userStorageRoleAssignment 'Microsoft.Authorization/roleAssignments@2022-04-01' = { + name: guid(storageAccount.id, principalId, 'Storage Blob Data Contributor') + scope: storageAccount + properties: { + roleDefinitionId: subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'ba92f5b4-2d11-453d-a403-e96b0029c9fe') // Storage Blob Data Contributor + principalId: principalId + principalType: principalType + } +} + +// Create the storage connection using the centralized connection module +module storageConnection '../ai/connection.bicep' = if (!empty(aiServicesAccountName) && !empty(aiProjectName)) { + name: 'storage-connection-creation' + params: { + aiServicesAccountName: aiServicesAccountName + aiProjectName: aiProjectName + connectionConfig: { + name: connectionName + category: 'AzureStorageAccount' + target: storageAccount.properties.primaryEndpoints.blob + authType: 'AAD' + isSharedToAll: true + metadata: { + ApiType: 'Azure' + ResourceId: storageAccount.id + location: storageAccount.location + } + } + } +} + +output storageAccountName string = storageAccount.name +output storageAccountId string = storageAccount.id +output storageAccountPrincipalId string = storageAccount.identity.principalId +output storageConnectionName string = storageConnection.outputs.connectionName diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.bicep b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.bicep new file mode 100644 index 000000000000..df29abd59bf6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.bicep @@ -0,0 +1,239 @@ 
+targetScope = 'subscription' +// targetScope = 'resourceGroup' + +@minLength(1) +@maxLength(64) +@description('Name of the environment that can be used as part of naming resource convention') +param environmentName string + +@minLength(1) +@maxLength(90) +@description('Name of the resource group to use or create') +param resourceGroupName string = 'rg-${environmentName}' + +// Restricted locations to match list from +// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#region-availability +@minLength(1) +@description('Primary location for all resources') +@allowed([ + 'australiaeast' + 'brazilsouth' + 'canadacentral' + 'canadaeast' + 'eastus' + 'eastus2' + 'francecentral' + 'germanywestcentral' + 'italynorth' + 'japaneast' + 'koreacentral' + 'northcentralus' + 'norwayeast' + 'polandcentral' + 'southafricanorth' + 'southcentralus' + 'southeastasia' + 'southindia' + 'spaincentral' + 'swedencentral' + 'switzerlandnorth' + 'uaenorth' + 'uksouth' + 'westus' + 'westus2' + 'westus3' +]) +param location string + +param aiDeploymentsLocation string + +@description('Id of the user or app to assign application roles') +param principalId string + +@description('Principal type of user or app') +param principalType string + +@description('Optional. Name of an existing AI Services account within the resource group. If not provided, a new one will be created.') +param aiFoundryResourceName string = '' + +@description('Optional. Name of the AI Foundry project. If not provided, a default name will be used.') +param aiFoundryProjectName string = 'ai-project-${environmentName}' + +@description('List of model deployments') +param aiProjectDeploymentsJson string = '[]' + +@description('List of connections') +param aiProjectConnectionsJson string = '[]' + +@secure() +@description('JSON map of connection name to credentials object. 
Example: {"my-conn":{"key":"secret"}}') +param aiProjectConnectionCredentialsJson string = '{}' + +@description('List of resources to create and connect to the AI project') +param aiProjectDependentResourcesJson string = '[]' + +var aiProjectDeployments = json(aiProjectDeploymentsJson) +var aiProjectConnections = json(aiProjectConnectionsJson) +var aiProjectConnectionCreds = json(aiProjectConnectionCredentialsJson) +var aiProjectDependentResources = json(aiProjectDependentResourcesJson) + +@description('Enable hosted agent deployment') +param enableHostedAgents bool + +@description('Enable the capability host for supporting BYO storage of agent conversations. When false and hosted agents are enabled, the capability host is not created.') +param enableCapabilityHost bool + +@description('Enable monitoring for the AI project') +param enableMonitoring bool + +@description('When true, skip Foundry project/role/connection provisioning and reference the existing project read-only. Use when pointing at an existing Foundry project via --project-id.') +param useExistingAiProject bool = false + +@description('Optional. Existing container registry resource ID. If provided, no new ACR will be created and a connection to this ACR will be established.') +param existingContainerRegistryResourceId string = '' + +@description('Optional. Existing container registry endpoint (login server). Required if existingContainerRegistryResourceId is provided.') +param existingContainerRegistryEndpoint string = '' + +@description('Optional. Name of an existing ACR connection on the Foundry project. If provided, no new ACR or connection will be created.') +param existingAcrConnectionName string = '' + +@description('Optional. Existing Application Insights connection string. If provided, a connection will be created but no new App Insights resource.') +param existingApplicationInsightsConnectionString string = '' + +@description('Optional. Existing Application Insights resource ID. 
Used for connection metadata when providing an existing App Insights.') +param existingApplicationInsightsResourceId string = '' + +@description('Optional. Name of an existing Application Insights connection on the Foundry project. If provided, no new App Insights or connection will be created.') +param existingAppInsightsConnectionName string = '' + +// Tags that should be applied to all resources. +// +// Note that 'azd-service-name' tags should be applied separately to service host resources. +// Example usage: +// tags: union(tags, { 'azd-service-name': }) +var tags = { + 'azd-env-name': environmentName +} + +// Check if resource group exists and create it if it doesn't +resource rg 'Microsoft.Resources/resourceGroups@2021-04-01' = { + name: resourceGroupName + location: location + tags: tags +} + +// Build dependent resources array conditionally +// Check if ACR already exists in the user-provided array to avoid duplicates +// Also skip if user provided an existing container registry endpoint or connection name +var hasAcr = contains(map(aiProjectDependentResources, r => r.resource), 'registry') +var shouldCreateAcr = enableHostedAgents && !hasAcr && empty(existingContainerRegistryResourceId) && empty(existingAcrConnectionName) +var dependentResources = shouldCreateAcr ? 
union(aiProjectDependentResources, [ + { + resource: 'registry' + connectionName: 'acr-${uniqueString(subscription().id, resourceGroupName, location)}' + } +]) : aiProjectDependentResources + +// AI Project module — only when creating new resources +module aiProject 'core/ai/ai-project.bicep' = if (!useExistingAiProject) { + scope: rg + name: 'ai-project' + params: { + tags: tags + location: aiDeploymentsLocation + aiFoundryProjectName: aiFoundryProjectName + principalId: principalId + principalType: principalType + existingAiAccountName: aiFoundryResourceName + deployments: aiProjectDeployments + connections: aiProjectConnections + connectionCredentials: aiProjectConnectionCreds + additionalDependentResources: dependentResources + enableMonitoring: enableMonitoring + enableHostedAgents: enableHostedAgents + enableCapabilityHost: enableCapabilityHost + existingContainerRegistryResourceId: existingContainerRegistryResourceId + existingContainerRegistryEndpoint: existingContainerRegistryEndpoint + existingAcrConnectionName: existingAcrConnectionName + existingApplicationInsightsConnectionString: existingApplicationInsightsConnectionString + existingApplicationInsightsResourceId: existingApplicationInsightsResourceId + existingAppInsightsConnectionName: existingAppInsightsConnectionName + } +} + +// Existing project module — read-only reference when reusing an existing Foundry project +module existingAiProject 'core/ai/existing-ai-project.bicep' = if (useExistingAiProject) { + scope: rg + name: 'existing-ai-project' + params: { + aiServicesAccountName: aiFoundryResourceName + aiFoundryProjectName: aiFoundryProjectName + existingAcrConnectionName: existingAcrConnectionName + existingContainerRegistryEndpoint: existingContainerRegistryEndpoint + existingApplicationInsightsConnectionString: existingApplicationInsightsConnectionString + existingApplicationInsightsResourceId: existingApplicationInsightsResourceId + connections: aiProjectConnections + connectionCredentials: 
aiProjectConnectionCreds + } +} + +// ACR for existing project — create when hosted agents need a registry but the existing project has none +var shouldCreateAcrForExistingProject = useExistingAiProject && shouldCreateAcr +var acrConnectionName = 'acr-${uniqueString(subscription().id, resourceGroupName, location)}' + +module acrForExistingProject 'core/host/acr.bicep' = if (shouldCreateAcrForExistingProject) { + scope: rg + name: 'acr-for-existing-project' + params: { + location: location + tags: tags + resourceName: 'cr${uniqueString(subscription().id, resourceGroupName, location)}' + connectionName: acrConnectionName + principalId: principalId + principalType: principalType + aiServicesAccountName: aiFoundryResourceName + aiProjectName: aiFoundryProjectName + } +} + +// Resources +output AZURE_RESOURCE_GROUP string = resourceGroupName +output AZURE_AI_ACCOUNT_ID string = useExistingAiProject ? existingAiProject.outputs.accountId : aiProject.outputs.accountId +output AZURE_AI_PROJECT_ID string = useExistingAiProject ? existingAiProject.outputs.projectId : aiProject.outputs.projectId +output AZURE_AI_FOUNDRY_PROJECT_ID string = useExistingAiProject ? existingAiProject.outputs.projectId : aiProject.outputs.projectId +output AZURE_AI_ACCOUNT_NAME string = useExistingAiProject ? existingAiProject.outputs.aiServicesAccountName : aiProject.outputs.aiServicesAccountName +output AZURE_AI_PROJECT_NAME string = useExistingAiProject ? existingAiProject.outputs.projectName : aiProject.outputs.projectName + +// Endpoints +output AZURE_AI_PROJECT_ENDPOINT string = useExistingAiProject ? existingAiProject.outputs.AZURE_AI_PROJECT_ENDPOINT : aiProject.outputs.AZURE_AI_PROJECT_ENDPOINT +output AZURE_OPENAI_ENDPOINT string = useExistingAiProject ? existingAiProject.outputs.AZURE_OPENAI_ENDPOINT : aiProject.outputs.AZURE_OPENAI_ENDPOINT +output APPLICATIONINSIGHTS_CONNECTION_STRING string = useExistingAiProject ? 
existingAiProject.outputs.APPLICATIONINSIGHTS_CONNECTION_STRING : aiProject.outputs.APPLICATIONINSIGHTS_CONNECTION_STRING +output APPLICATIONINSIGHTS_RESOURCE_ID string = useExistingAiProject ? existingAiProject.outputs.APPLICATIONINSIGHTS_RESOURCE_ID : aiProject.outputs.APPLICATIONINSIGHTS_RESOURCE_ID + +// Dependent Resources and Connections + +// ACR +output AZURE_AI_PROJECT_ACR_CONNECTION_NAME string = shouldCreateAcrForExistingProject ? acrForExistingProject.outputs.containerRegistryConnectionName : (useExistingAiProject ? existingAiProject.outputs.dependentResources.registry.connectionName : aiProject.outputs.dependentResources.registry.connectionName) +output AZURE_CONTAINER_REGISTRY_ENDPOINT string = shouldCreateAcrForExistingProject ? acrForExistingProject.outputs.containerRegistryLoginServer : (useExistingAiProject ? existingAiProject.outputs.dependentResources.registry.loginServer : aiProject.outputs.dependentResources.registry.loginServer) + +// Bing Search +output BING_GROUNDING_CONNECTION_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.bing_grounding.connectionName : aiProject.outputs.dependentResources.bing_grounding.connectionName +output BING_GROUNDING_RESOURCE_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.bing_grounding.name : aiProject.outputs.dependentResources.bing_grounding.name +output BING_GROUNDING_CONNECTION_ID string = useExistingAiProject ? existingAiProject.outputs.dependentResources.bing_grounding.connectionId : aiProject.outputs.dependentResources.bing_grounding.connectionId + +// Bing Custom Search +output BING_CUSTOM_GROUNDING_CONNECTION_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.bing_custom_grounding.connectionName : aiProject.outputs.dependentResources.bing_custom_grounding.connectionName +output BING_CUSTOM_GROUNDING_NAME string = useExistingAiProject ? 
existingAiProject.outputs.dependentResources.bing_custom_grounding.name : aiProject.outputs.dependentResources.bing_custom_grounding.name +output BING_CUSTOM_GROUNDING_CONNECTION_ID string = useExistingAiProject ? existingAiProject.outputs.dependentResources.bing_custom_grounding.connectionId : aiProject.outputs.dependentResources.bing_custom_grounding.connectionId + +// Azure AI Search +output AZURE_AI_SEARCH_CONNECTION_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.search.connectionName : aiProject.outputs.dependentResources.search.connectionName +output AZURE_AI_SEARCH_SERVICE_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.search.serviceName : aiProject.outputs.dependentResources.search.serviceName + +// Azure Storage +output AZURE_STORAGE_CONNECTION_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.storage.connectionName : aiProject.outputs.dependentResources.storage.connectionName +output AZURE_STORAGE_ACCOUNT_NAME string = useExistingAiProject ? existingAiProject.outputs.dependentResources.storage.accountName : aiProject.outputs.dependentResources.storage.accountName + +// Connections +output AI_PROJECT_CONNECTION_IDS_JSON string = useExistingAiProject ? 
string(existingAiProject.outputs.connectionIds) : string(aiProject.outputs.connectionIds) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.parameters.json b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.parameters.json new file mode 100644 index 000000000000..dbf643f3f48f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/infra/main.parameters.json @@ -0,0 +1,72 @@ +{ + "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentParameters.json#", + "contentVersion": "1.0.0.0", + "parameters": { + "resourceGroupName": { + "value": "${AZURE_RESOURCE_GROUP}" + }, + "environmentName": { + "value": "${AZURE_ENV_NAME}" + }, + "location": { + "value": "${AZURE_LOCATION}" + }, + "aiFoundryResourceName": { + "value": "${AZURE_AI_ACCOUNT_NAME}" + }, + "aiFoundryProjectName": { + "value": "${AZURE_AI_PROJECT_NAME}" + }, + "aiDeploymentsLocation": { + "value": "${AZURE_LOCATION}" + }, + "principalId": { + "value": "${AZURE_PRINCIPAL_ID}" + }, + "principalType": { + "value": "${AZURE_PRINCIPAL_TYPE}" + }, + "aiProjectDeploymentsJson": { + "value": "${AI_PROJECT_DEPLOYMENTS=[]}" + }, + "aiProjectConnectionsJson": { + "value": "${AI_PROJECT_CONNECTIONS=[]}" + }, + "aiProjectConnectionCredentialsJson": { + "value": "${AI_PROJECT_CONNECTION_CREDENTIALS}" + }, + "aiProjectDependentResourcesJson": { + "value": "${AI_PROJECT_DEPENDENT_RESOURCES=[]}" + }, + "enableMonitoring": { + "value": "${ENABLE_MONITORING=true}" + }, + "enableHostedAgents": { + "value": "${ENABLE_HOSTED_AGENTS=false}" + }, + "enableCapabilityHost": { + "value": "${ENABLE_CAPABILITY_HOST=true}" + }, + "useExistingAiProject": { + "value": "${USE_EXISTING_AI_PROJECT=false}" + }, + "existingContainerRegistryResourceId": { + "value": "${AZURE_CONTAINER_REGISTRY_RESOURCE_ID=}" + }, + "existingContainerRegistryEndpoint": { + "value": 
"${AZURE_CONTAINER_REGISTRY_ENDPOINT=}" + }, + "existingAcrConnectionName": { + "value": "${AZURE_AI_PROJECT_ACR_CONNECTION_NAME=}" + }, + "existingApplicationInsightsConnectionString": { + "value": "${APPLICATIONINSIGHTS_CONNECTION_STRING=}" + }, + "existingApplicationInsightsResourceId": { + "value": "${APPLICATIONINSIGHTS_RESOURCE_ID=}" + }, + "existingAppInsightsConnectionName": { + "value": "${APPLICATIONINSIGHTS_CONNECTION_NAME=}" + } + } +} diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/Dockerfile b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/Dockerfile new file mode 100644 index 000000000000..e26b87bb7cae --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install local wheel packages first (built by build.sh before docker build) +COPY wheels/ /tmp/wheels/ +RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels + +# Install remaining dependencies +COPY requirements.txt . 
+RUN pip install --no-cache-dir -r requirements.txt + +COPY app.py agent.py supervisor.py entrypoint.sh ./ +RUN chmod +x entrypoint.sh + +# Use local file-based durability (Task Storage API not yet available in this region) +ENV FOUNDRY_TASK_API_ENABLED=0 +ENV STAGE_DURATION=5 + +EXPOSE 8088 + +# supervisor.py is PID 1: keeps /readiness alive, proxies to app, restarts on crash +CMD ["python", "supervisor.py"] + diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/agent.py new file mode 100644 index 000000000000..25467c3bdf51 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/agent.py @@ -0,0 +1,221 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""The durable research task — this is what makes the agent crash-resilient. + +The ONLY things you need for durability: + 1. ``@durable_task`` decorator + 2. ``ctx.metadata[...] = value`` + ``await ctx.metadata.flush()`` to checkpoint + +That's it. Everything else here is just normal agent logic. 
# ── File-backed stream handler ────────────────────────────────────────────────
# Stores stream items to disk so consumers can reconnect after a crash/disconnect
# and replay from where they left off.

_STREAM_DIR = Path.home() / ".durable-tasks" / "_streams"


class FileStreamHandler:
    """Stream handler that persists items to a JSONL file for replay.

    Every item passed to ``put()`` is appended to ``<dir>/<task_id>/stream.jsonl``
    as one JSON document per line and then enqueued for the live consumer.
    When a handler is created for a task whose file already exists, the
    persisted items are loaded back into the queue first, so a consumer
    calling ``get()`` sees the recorded history followed by new items.

    ``close()`` appends a ``{"__done__": true}`` marker line; that marker is
    bookkeeping and is never replayed as an item.

    Args:
        task_id: Identifier used as the per-task sub-directory name.
        stream_dir: Optional base directory for stream files. Defaults to the
            module-level ``_STREAM_DIR``.
    """

    def __init__(self, task_id: str, *, stream_dir: Path | None = None) -> None:
        self._task_id = task_id
        base_dir = _STREAM_DIR if stream_dir is None else stream_dir
        self._dir = base_dir / task_id
        self._dir.mkdir(parents=True, exist_ok=True)
        self._file = self._dir / "stream.jsonl"
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._closed = False
        self._SENTINEL = object()

        self._replay_persisted()

    def _replay_persisted(self) -> None:
        """Load previously persisted items into the queue, tolerating damage."""
        if not self._file.exists():
            return

        text = self._file.read_text(encoding="utf-8")
        for line_no, line in enumerate(text.splitlines(), start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                # A process killed in the middle of a write leaves a truncated
                # line behind; dropping it must not prevent recovery.
                logging.getLogger(__name__).warning(
                    "Skipping unreadable stream record %s:%d", self._file, line_no
                )
                continue
            # Only a dict carrying the marker key is the end-of-stream record.
            # A plain ``in`` test would also match string items that merely
            # contain the text "__done__" and would raise on numbers.
            if isinstance(item, dict) and "__done__" in item:
                continue
            self._queue.put_nowait(item)

        if text and not text.endswith("\n"):
            # Terminate a truncated tail so the next append starts a new line.
            with open(self._file, "a", encoding="utf-8") as f:
                f.write("\n")

    async def put(self, item: Any) -> None:
        """Persist *item* to disk, then enqueue it for the live consumer."""
        with open(self._file, "a", encoding="utf-8") as f:
            f.write(json.dumps(item) + "\n")
        await self._queue.put(item)

    async def get(self) -> Any:
        """Return the next item.

        Raises:
            StopAsyncIteration: When the end-of-stream sentinel queued by
                ``close()`` is reached.
        """
        item = await self._queue.get()
        if item is self._SENTINEL:
            raise StopAsyncIteration
        return item

    async def close(self) -> None:
        """Record the done marker on disk and wake the consumer with the sentinel."""
        self._closed = True
        with open(self._file, "a", encoding="utf-8") as f:
            f.write(json.dumps({"__done__": True}) + "\n")
        await self._queue.put(self._SENTINEL)


def file_stream_factory(task_id: str) -> FileStreamHandler:
    """Factory for creating file-backed stream handlers."""
    return FileStreamHandler(task_id)


# ── Research stages ───────────────────────────────────────────────────────────
# A realistic deep-research pipeline — each stage is a distinct step that
# naturally takes time (LLM call + processing delay).
STAGES = [
    "Decomposing topic into focused research questions",
    "Surveying foundational literature and key concepts",
    "Identifying leading researchers and institutions",
    "Analyzing recent breakthroughs and publications",
    "Examining competing theories and approaches",
    "Evaluating experimental evidence and data quality",
    "Mapping connections to adjacent fields",
    "Identifying open problems and knowledge gaps",
    "Assessing real-world applications and impact",
    "Analyzing funding landscape and research trends",
    "Synthesizing findings into a coherent narrative",
    "Generating key insights and recommendations",
]

# Seconds slept before every stage except the first (see _run_stage_streaming).
STAGE_DURATION = int(os.environ.get("STAGE_DURATION", "5"))


def _token_event(text: str) -> str:
    """Serialize *text* as the JSON ``token`` event sent on the task stream."""
    return json.dumps({"type": "token", "content": text})


# ── The durable task ──────────────────────────────────────────────────────────

@durable_task(name="deep_research", stream_handler_factory=file_stream_factory)
async def deep_research(ctx: TaskContext[dict]) -> dict[str, Any]:
    """Run the staged research pipeline, checkpointing after every stage.

    Progress lives in ``ctx.metadata`` under ``completed_stages`` and
    ``results``; both are read on entry, so a re-entered run continues with
    the first stage that was not flushed. ``ctx.cancel`` is checked at the
    start of each stage.

    Returns:
        On completion, a dict with ``topic``, ``report`` (the last stage's
        text, or ``""`` when there are no results) and ``stages_completed``.
        On cancellation, a dict with ``topic``, ``stages_completed`` and
        ``cancelled`` set to ``True``.
    """
    topic: str = ctx.input["topic"]
    already_done: int = ctx.metadata.get("completed_stages", 0)
    findings: list = ctx.metadata.get("results", [])
    stage_count = len(STAGES)

    if ctx.entry_mode == "recovered":
        logger.warning("⚡ Recovered! Resuming from stage %d/%d", already_done + 1, stage_count)
        await ctx.stream(_token_event(
            f"\n\n⚡ **Recovered from crash!** Resuming from stage {already_done + 1}/{stage_count}...\n\n"
        ))

    for index in range(already_done, stage_count):
        # Cooperative cancellation: honoured only between stages.
        if ctx.cancel.is_set():
            await ctx.stream(_token_event("\n\n---\n🛑 **Research cancelled.**\n"))
            return {"topic": topic, "stages_completed": index, "cancelled": True}

        label = STAGES[index]
        ordinal = index + 1

        await ctx.stream(_token_event(f"\n\n**[Stage {ordinal}/{stage_count}]** {label}...\n"))

        # The stage itself streams LLM tokens to the consumer as they arrive.
        stage_text = await _run_stage_streaming(
            ctx, topic, label, prior_results=findings[-3:], stage_idx=index
        )
        findings.append({"stage": label, "result": stage_text})

        # Checkpoint: everything up to and including this stage is now durable.
        ctx.metadata["completed_stages"] = ordinal
        ctx.metadata["results"] = findings
        await ctx.metadata.flush()

        await ctx.stream(_token_event(f"\n✅ Stage {ordinal}/{stage_count} complete.\n"))

    await ctx.stream(_token_event("\n\n---\n✅ **Research complete!**\n"))
    final_report = findings[-1]["result"] if findings else ""
    return {"topic": topic, "report": final_report, "stages_completed": stage_count}


# ── LLM helpers ───────────────────────────────────────────────────────────────

async def _run_stage_streaming(
    ctx: TaskContext, topic: str, stage: str, *, prior_results: list, stage_idx: int = 0
) -> str:
    """Run one research stage against the LLM and forward its tokens.

    Sleeps ``STAGE_DURATION`` seconds first unless this is stage 0, builds the
    instructions (including up to three prior findings, each truncated to 80
    characters), then relays every ``response.output_text.delta`` event to
    ``ctx.stream``.

    Returns:
        The concatenation of all text deltas received for the stage.
    """
    # The first stage skips the artificial delay so the demo feels responsive.
    if stage_idx > 0:
        await asyncio.sleep(STAGE_DURATION)

    role = f"You are a research assistant performing: '{stage}'. "
    if prior_results:
        findings = "\n".join(
            f"- {r['stage']}: {r['result'][:80]}" for r in prior_results[-3:]
        )
        instructions = (
            role
            + f"Build on these prior findings:\n{findings}\n\n"
            + "Provide 3-4 sentences of new, specific, detailed findings. Be informative."
        )
    else:
        instructions = (
            role
            + "Provide 3-4 sentences of specific, detailed findings. Be informative and engaging."
        )
    input_text = f"Research topic: {topic}"

    response_stream = await _openai_client.responses.create(
        model=_model,
        instructions=instructions,
        input=input_text,
        store=False,
        stream=True,
    )

    pieces = []
    async for event in response_stream:
        if event.type != "response.output_text.delta":
            continue
        pieces.append(event.delta)
        await ctx.stream(_token_event(event.delta))

    return "".join(pieces)
@app.invoke_handler
async def handle_invoke(request: Request) -> Response:
    """Dispatch a research task (fire-and-forget).

    Returns immediately with 202 + invocation/session IDs.
    The client then calls GET /invocations/{id} to stream results.
    Send ``{"message": "crash"}`` to trigger a deliberate crash for demo.

    The topic is taken from ``message``; if that is empty it falls back to
    ``input``, either the ``content`` of the last message in a list or a
    plain string. A body that is not valid JSON is treated as the raw
    message text. A missing, blank or non-string topic yields a 400.
    """
    body = await request.body()
    try:
        data = json.loads(body) if body else {}
    except ValueError:
        # ValueError covers both json.JSONDecodeError and the
        # UnicodeDecodeError json.loads raises for undecodable bytes.
        data = {"message": body.decode("utf-8", errors="replace").strip()}

    # Valid JSON is not necessarily an object: accept a bare string as the
    # message and let any other shape fall through to the 400 below.
    if isinstance(data, str):
        data = {"message": data}
    elif not isinstance(data, dict):
        data = {}

    topic = data.get("message") or ""
    if not topic:
        raw_input = data.get("input")
        # Foundry sends input as a list of messages
        if isinstance(raw_input, list):
            if raw_input and isinstance(raw_input[-1], dict):
                topic = raw_input[-1].get("content", "")
        elif isinstance(raw_input, str):
            topic = raw_input

    if not isinstance(topic, str) or not topic.strip():
        return JSONResponse({"error": "Provide a 'message' field"}, status_code=400)

    # Deliberate crash trigger for demo — return 202, then crash asynchronously
    if topic.strip().lower() in ("crash", "💥", "kill"):
        logger.critical("💥 CRASH triggered via API — will exit shortly")
        # The short delay gives the response time to flush. call_later is used
        # because the loop owns the timer handle, so nothing needs to keep a
        # reference alive as it would for a bare create_task().
        asyncio.get_running_loop().call_later(0.3, os._exit, 137)
        return JSONResponse(
            {"status": "crashing", "message": "💥 Process will crash now"},
            status_code=202,
        )

    invocation_id: str = request.state.invocation_id
    session_id: str = request.state.session_id
    task_id = f"research-{session_id}"
    logger.info("POST handler: session_id=%r, task_id=%r", session_id, task_id)

    status = "started"
    try:
        await deep_research.start(
            task_id=task_id,
            input={"topic": topic, "invocation_id": invocation_id},
            session_id=session_id,
        )
    except TaskConflictError:
        # Task already running (recovered after crash)
        status = "in_progress"
        logger.info("POST handler: TaskConflictError — task already running")

    # Return immediately — platform sees 202 and preserves invocation mapping
    return JSONResponse(
        {
            "status": status,
            "invocation_id": invocation_id,
            "session_id": session_id,
        },
        status_code=202,
    )
from persisted file. + + The platform routes GET /invocations/{id} to this container based on + the invocation→session mapping preserved from the fire-and-forget POST. + Session ID is derived from the framework config (FOUNDRY_AGENT_SESSION_ID). + + Supports ``last_event_id`` query param to skip already-received events + on reconnect (platform strips non x-client- headers, so we use a param). + """ + session_id = request.state.session_id if hasattr(request.state, "session_id") and request.state.session_id else app.config.session_id + task_id = f"research-{session_id}" + + # Skip already-seen events: client passes last_event_id query param on reconnect + last_event_id = request.query_params.get("last_event_id", "") + skip_count = int(last_event_id) if last_event_id.isdigit() else 0 + logger.info(f"GET handler: session_id={session_id!r}, task_id={task_id!r}, skip={skip_count}") + + run = deep_research.get_active_run(task_id) + logger.info(f"GET handler: get_active_run({task_id!r}) -> {run}") + + if run is not None: + # Live task — stream from it, skipping already-seen events + async def live_stream(): + event_id = 0 + try: + async for chunk in run: + event_id += 1 + if event_id <= skip_count: + continue + yield f"id: {event_id}\ndata: {chunk}\n\n" + result = await run.result() + event_id += 1 + yield f"id: {event_id}\ndata: {json.dumps({'type': 'done', 'full_text': result.output.get('report', '')})}\n\n" + except (TaskCancelled, TaskTerminated): + event_id += 1 + yield f"id: {event_id}\ndata: {json.dumps({'type': 'done', 'full_text': '[Task was cancelled]'})}\n\n" + except TaskFailed as exc: + event_id += 1 + yield f"id: {event_id}\ndata: {json.dumps({'type': 'done', 'full_text': f'[Error: {exc}]'})}\n\n" + + return StreamingResponse( + live_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache"}, + ) + + # Fallback: replay from persisted stream file + from pathlib import Path + + stream_file = Path.home() / ".durable-tasks" / "_streams" / 
@app.cancel_invocation_handler
async def handle_cancel(request: Request) -> Response:
    """Request cancellation of this session's research task.

    The session ID is taken from ``request.state`` when present and non-empty,
    otherwise from ``app.config.session_id``. Replies with a ``not_found``
    payload when no run is active for the derived task ID.
    """
    # Falls back to the configured session when the request carries none.
    session_id = getattr(request.state, "session_id", None) or app.config.session_id
    task_id = f"research-{session_id}"
    logger.info(f"CANCEL handler: session_id={session_id!r}, task_id={task_id!r}")

    active_run = deep_research.get_active_run(task_id)
    if active_run is not None:
        await active_run.cancel()
        return JSONResponse({"status": "cancelled", "message": "Task cancellation requested."})

    return JSONResponse({"status": "not_found", "message": "No active task to cancel."})
index 000000000000..e35dce4f496c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/entrypoint.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# Auto-restart wrapper. Restarts immediately on crash. +set -u + +while true; do + echo "$(date -Iseconds) [entrypoint] Starting agent..." + python app.py + exit_code=$? + + if [ $exit_code -eq 0 ]; then + echo "$(date -Iseconds) [entrypoint] Agent exited cleanly. Stopping." + exit 0 + fi + + echo "$(date -Iseconds) [entrypoint] 💥 Crashed (exit $exit_code). Restarting immediately..." +done diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/requirements.txt new file mode 100644 index 000000000000..9ecdeba478bd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/requirements.txt @@ -0,0 +1,10 @@ +# Azure AI packages (installed from local wheels during build) +azure-ai-agentserver-core +azure-ai-agentserver-invocations + +# Azure SDKs +azure-ai-projects>=1.0.0b10 +azure-identity>=1.17.0 + +# Supervisor proxy (also a transitive dep of core) +aiohttp>=3.9.0 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/supervisor.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/supervisor.py new file mode 100644 index 000000000000..21ed281205e9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/supervisor.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft. All rights reserved. + +"""Supervisor: keeps /readiness alive while app crashes and restarts. + +This is PID 1 in the container. It: + 1. 
def _start_app() -> subprocess.Popen:
    """Launch ``app.py`` as a child process bound to the internal port.

    The child gets a copy of the current environment with ``PORT``
    overridden, and writes to this process's stdout/stderr.

    :returns: Handle of the newly started child process.
    """
    child_env = {**os.environ, "PORT": str(INTERNAL_PORT)}
    logger.info("Starting agent on port %d...", INTERNAL_PORT)

    command = [sys.executable, "app.py"]
    child = subprocess.Popen(
        command,
        env=child_env,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )

    logger.info("Agent PID %d", child.pid)
    return child
async def handle_proxy(request: web.Request) -> web.StreamResponse:
    """Proxy everything else to the app, waiting for it to be ready first.

    Polls the app's ``/readiness`` endpoint for up to ~6 seconds, then
    forwards the request. ``text/event-stream`` replies are relayed chunk by
    chunk; any other reply is buffered and returned whole.

    :param request: Incoming platform-facing request.
    :returns: The relayed response, or ``503`` JSON if the app could not be
        reached before any response bytes were sent to the client.
    """
    session: ClientSession = request.app["client_session"]

    # Wait for the app to be ready (poll /readiness on internal port)
    for _ in range(30):  # up to ~6 seconds
        try:
            async with session.get(f"{APP_BASE}/readiness") as check:
                if check.status == 200:
                    break
        except Exception:  # pylint: disable=broad-except
            # Deliberate best-effort: the app is expected to be unreachable
            # while it restarts, so any failure just means "try again".
            pass
        await asyncio.sleep(0.2)

    url = f"{APP_BASE}{request.path_qs}"
    # Drop Host regardless of how the client spelled it, so the upstream
    # request carries the internal host instead.
    headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() != "host"
    }
    body = await request.read()

    proxy_resp: web.StreamResponse | None = None
    try:
        async with session.request(
            request.method, url, headers=headers, data=body
        ) as resp:
            if "text/event-stream" not in resp.content_type:
                return web.Response(
                    body=await resp.read(),
                    status=resp.status,
                    content_type=resp.content_type,
                )

            # SSE — stream it back as it arrives
            proxy_resp = web.StreamResponse(
                status=resp.status,
                headers={
                    "Content-Type": "text/event-stream",
                    "Cache-Control": "no-cache",
                },
            )
            await proxy_resp.prepare(request)
            async for chunk in resp.content.iter_any():
                await proxy_resp.write(chunk)
            await proxy_resp.write_eof()
            return proxy_resp
    except Exception as exc:  # pylint: disable=broad-except
        if proxy_resp is not None and proxy_resp.prepared:
            # Headers (and possibly events) are already on the wire, so a
            # fresh 503 cannot be sent on this connection. End the stream
            # and let the client reconnect.
            logger.warning("Upstream stream interrupted: %s", exc)
            return proxy_resp
        logger.warning("Upstream request failed: %s", exc)
        return web.json_response(
            {"error": "Agent is restarting. Retry in a moment."},
            status=503,
        )
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/wheels/azure_ai_agentserver_invocations-1.0.0b4-py3-none-any.whl new file mode 100644 index 000000000000..219a50caa16d Binary files /dev/null and b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable-agent-demo/src/durable-research-agent/wheels/azure_ai_agentserver_invocations-1.0.0b4-py3-none-any.whl differ diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/__init__.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py new file mode 100644 index 000000000000..e400cd9b5827 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_claude/agent.py @@ -0,0 +1,152 @@ +"""Steerable durable Claude conversation agent. + +Wraps the Anthropic streaming API in a steerable durable task. +Demonstrates the **three-phase cancel pattern**: + +1. Pre-entry check — short-circuit if a newer input is already queued +2. Mid-stream check — break out of the SSE chunk loop +3. Post-completion — catch late arrivals after the reply finished + +Conversation history is stored in an external ``FileStore`` (not in task +metadata, which has a < 1 MB limit). In production, replace ``FileStore`` +with Redis, Cosmos DB, etc. 
@durable_task(name="claude_session", steerable=True)
async def claude_session(ctx: TaskContext[dict]) -> dict[str, Any]:
    """Run one Claude conversation turn with streaming and steering support.

    Input schema: ``{"session_id": str, "message": str, "invocation_id": str}``

    Every exit path records a per-invocation status in ``invocation_store``
    and then suspends the task via ``ctx.suspend``:

    * ``cancelled`` — cancel was already set on entry; the user message is
      saved to history but Claude is not called.
    * ``superseded`` — cancel was observed mid-stream or right after the
      reply finished; the (possibly partial) reply is kept in history.
    * ``completed`` — the turn finished; the output is passed to
      ``ctx.suspend`` with reason ``awaiting_user_input``.

    :param ctx: Task context supplying the input dict, the cancel event,
        the stream sink and the suspend call used below.
    :returns: Whatever ``ctx.suspend`` returns.
    """
    session_id: str = ctx.input["session_id"]
    message: str = ctx.input["message"]
    invocation_id: str = ctx.input["invocation_id"]

    # Mark the invocation as running in the polling store and on the live stream.
    invocation_store.save(invocation_id, {"status": "running"})
    await ctx.stream({"type": "lifecycle", "status": "running"})

    logger.info(
        "Claude session %s gen=%d invocation=%s entry=%s",
        session_id,
        ctx.generation,
        invocation_id,
        ctx.entry_mode,
    )

    # Load history from external store (not task metadata)
    history = _load_history(session_id)
    history.append({"role": "user", "content": message})

    # ── Phase 1: Pre-entry cancel (rapid-fire steering) ─────────────
    # Cancel already set on entry: skip the model call, but persist the
    # user message so it stays in the conversation history.
    if ctx.cancel.is_set():
        logger.info("Skipping gen=%d — cancel pre-set", ctx.generation)
        _save_history(session_id, history)
        invocation_store.save(
            invocation_id,
            {
                "status": "cancelled",
                "reason": "steered",
                "message_preserved": True,
            },
        )
        return await ctx.suspend(reason="steered")

    # ── Phase 2: Stream Claude response, checking cancel ────────────
    # Imported here rather than at module top, so the dependency is only
    # needed when a turn actually reaches the model call.
    import anthropic  # pylint: disable=import-outside-toplevel

    reply = ""
    was_aborted = False

    client = anthropic.AsyncAnthropic()
    async with client.messages.stream(
        model="claude-sonnet-4-20250514",
        max_tokens=1024,
        messages=history,
    ) as stream:
        async for text in stream.text_stream:
            reply += text
            # Live stream: push delta to any SSE subscriber on the POST response
            await ctx.stream({"type": "text_delta", "delta": text})
            # Durable snapshot: GET polling always returns the full text so far
            invocation_store.save(
                invocation_id,
                {
                    "status": "streaming",
                    "text": reply,
                },
            )
            # Checked after each delta, so an abort keeps everything
            # received up to and including this chunk.
            if ctx.cancel.is_set():
                was_aborted = True
                logger.info("Stream aborted mid-generation at %d chars", len(reply))
                break

    # ── Phase 3: Save result ────────────────────────────────────────
    # Save history to external store (including partial text)
    # An empty reply is not appended, so history then ends on the user turn.
    if reply:
        history.append({"role": "assistant", "content": reply})
    _save_history(session_id, history)

    # Turn number = count of user messages in the history so far.
    user_turns = len([m for m in history if m["role"] == "user"])
    output = {
        "invocation_id": invocation_id,
        "reply": reply,
        "turn": user_turns,
        "partial": was_aborted,
    }

    # Aborted mid-stream: record the partial output, suspend without output.
    if was_aborted:
        invocation_store.save(
            invocation_id,
            {
                "status": "superseded",
                "reason": "steered_mid_stream",
                "output": output,
            },
        )
        return await ctx.suspend(reason="steered")

    # Reply finished, but cancel was set before the result was recorded.
    if ctx.cancel.is_set():
        invocation_store.save(
            invocation_id,
            {
                "status": "superseded",
                "reason": "steered_post_completion",
                "output": output,
            },
        )
        return await ctx.suspend(reason="steered")

    # Normal completion
    invocation_store.save(invocation_id, {"status": "completed", "output": output})
    return await ctx.suspend(reason="awaiting_user_input", output=output)
+ + python -m durable_claude.app + + # Turn 1 (async — poll for result) + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Tell me about quantum computing"}' + # → 202 {"invocation_id": "...", "status": "running"} + + # Turn 1 (streaming — live text deltas) + curl -N -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -H "Accept: text/event-stream" \\ + -d '{"message": "Tell me about quantum computing"}' + # → 200 data: {"type": "text_delta", "delta": "Quantum"} + # data: {"type": "text_delta", "delta": " computing"} + # ... + # event: done + # data: {"type": "done", ...} + + # Poll (works after disconnect or for async mode) + curl "http://localhost:8088/invocations/" + # → {"invocation_id": "", "status": "completed", "output": {...}} + + # Steer (while turn 1 is still running) + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Actually, explain machine learning instead"}' +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncGenerator + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + +from .agent import claude_session, invocation_store + +logger = logging.getLogger(__name__) + +app = InvocationAgentServerHost() + + +async def _sse_from_run( + run: object, invocation_id: str, *, initial_status: str = "queued" +) -> AsyncGenerator[bytes, None]: + """Convert a TaskRun's stream into SSE-formatted bytes. + + Yields lifecycle events (``queued``, ``running``), then ``text_delta`` + chunks, then a terminal event (``done``, ``error``, ``superseded``). + + :param run: The TaskRun handle. 
@app.invoke_handler
async def handle_invoke(request: Request) -> Response:
    """Start or steer a Claude session.

    If ``Accept: text/event-stream`` is set, returns an SSE stream of
    text deltas. Otherwise returns ``202 Accepted`` for async polling.

    :param request: Incoming request; the body must be a JSON object with a
        non-empty string ``message``.
    :returns: ``400`` for an invalid body, an SSE ``StreamingResponse`` in
        streaming mode, or ``202`` with the invocation ID and status.
    """
    try:
        data = await request.json()
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse(
            {"error": "Request body must be valid JSON"}, status_code=400
        )

    # Validate before queuing: an empty or non-string message would only
    # fail later, inside the durable task, after the client got its 202.
    message = data.get("message") if isinstance(data, dict) else None
    if not isinstance(message, str) or not message.strip():
        return JSONResponse(
            {"error": "Provide a non-empty 'message' field"}, status_code=400
        )

    invocation_id: str = request.state.invocation_id
    session_id: str = request.state.session_id
    task_id = f"session-{session_id}"

    task_input = {
        "session_id": session_id,
        "message": message,
        "invocation_id": invocation_id,
    }

    invocation_store.save(invocation_id, {"status": "queued"})

    run = await claude_session.start(task_id=task_id, input=task_input)

    # SSE streaming mode — return live text deltas
    wants_stream = "text/event-stream" in request.headers.get("accept", "")
    if wants_stream:
        return StreamingResponse(
            _sse_from_run(run, invocation_id),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # Async mode — return 202 and let client poll
    # The task may already have advanced the stored status past "queued".
    stored = invocation_store.load(invocation_id) or {}
    status = stored.get("status", "queued")

    return JSONResponse(
        {"invocation_id": invocation_id, "status": status},
        status_code=202,
    )
class FileStore:
    """Minimal file-backed key→JSON store.

    Each entry is a single UTF-8 JSON file named ``<key>.json`` inside the
    base directory. Writes are atomic (temp + rename).

    Keys must be plain file-name components. A key containing a path
    separator is rejected with ``ValueError`` so that caller-supplied
    identifiers cannot address files outside the base directory.
    """

    def __init__(self, base_dir: Path) -> None:
        """Create the store, making *base_dir* (and parents) if needed."""
        self._base = base_dir
        self._base.mkdir(parents=True, exist_ok=True)

    def _path_for(self, key: str) -> Path:
        """Return the entry path for *key*, rejecting keys with separators."""
        if Path(key).name != key:
            raise ValueError(f"Invalid store key: {key!r}")
        return self._base / f"{key}.json"

    def save(self, key: str, data: dict[str, Any]) -> None:
        """Atomically write *data* as JSON — temp file + rename.

        :raises ValueError: If *key* contains a path separator.
        """
        # Validate first: the key is also used in the temp-file prefix.
        target = self._path_for(key)
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self._base), suffix=".tmp", prefix=f"{key}_"
        )
        try:
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            Path(tmp_path).replace(target)
        except BaseException:
            # Never leave a stray temp file behind, even on cancellation.
            Path(tmp_path).unlink(missing_ok=True)
            raise

    def load(self, key: str) -> dict[str, Any] | None:
        """Return the stored dict, or ``None`` if the key does not exist.

        :raises ValueError: If *key* contains a path separator.
        """
        try:
            # Read as UTF-8 to match save(); EAFP avoids an exists()/read race
            # with a concurrent delete().
            text = self._path_for(key).read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def delete(self, key: str) -> bool:
        """Remove the entry for *key*. Returns ``True`` if it existed.

        :raises ValueError: If *key* contains a path separator.
        """
        try:
            self._path_for(key).unlink()
        except FileNotFoundError:
            return False
        return True
@durable_task(name="copilot_session", steerable=True)
async def copilot_session(ctx: TaskContext[dict]) -> dict[str, Any]:
    """Run one Copilot conversation turn with steering support.

    Input schema: ``{"session_id": str, "message": str, "invocation_id": str}``

    Every exit path records a per-invocation status in ``invocation_store``
    (``cancelled``, ``superseded`` or ``completed``) and then suspends the
    task via ``ctx.suspend``.

    :param ctx: Task context supplying the input dict, the cancel event,
        the stream sink and the suspend call used below.
    :returns: Whatever ``ctx.suspend`` returns.
    """
    from copilot import CopilotClient  # pylint: disable=import-outside-toplevel
    from copilot.generated.session_events import (  # pylint: disable=import-outside-toplevel
        AssistantMessageData,
        IdleData,
    )
    from copilot.session import PermissionHandler  # pylint: disable=import-outside-toplevel

    session_id: str = ctx.input["session_id"]
    message: str = ctx.input["message"]
    invocation_id: str = ctx.input["invocation_id"]

    invocation_store.save(invocation_id, {"status": "running"})
    await ctx.stream({"type": "lifecycle", "status": "running"})

    logger.info(
        "Copilot session %s gen=%d invocation=%s entry=%s",
        session_id,
        ctx.generation,
        invocation_id,
        ctx.entry_mode,
    )

    # ── Phase 1: Pre-entry cancel (rapid-fire steering) ─────────────
    # Cancel is pre-set when more inputs are already queued. We still
    # send the message so the SDK records it, then abort immediately.
    if ctx.cancel.is_set():
        logger.info("Skipping gen=%d — cancel pre-set", ctx.generation)
        async with CopilotClient() as client:
            session = await client.resume_session(
                session_id,
                on_permission_request=PermissionHandler.approve_all,
            )
            await session.send(message)
            await session.abort()
        invocation_store.save(
            invocation_id,
            {
                "status": "cancelled",
                "reason": "steered",
                "message_preserved": True,
            },
        )
        return await ctx.suspend(reason="steered")

    # ── Phase 2: Stream the Copilot turn, checking cancel ───────────
    was_aborted = False
    reply_parts: list[str] = []
    # Strong references to in-flight stream/persist tasks. Without them a
    # task may be garbage-collected mid-flight, and a late "streaming"
    # snapshot could overwrite the final status written in Phase 3.
    persist_tasks: set[asyncio.Task[None]] = set()
    loop = asyncio.get_running_loop()

    async with CopilotClient() as client:
        if ctx.entry_mode != "fresh":
            session = await client.resume_session(
                session_id,
                on_permission_request=PermissionHandler.approve_all,
            )
        else:
            session = await client.create_session(
                session_id=session_id,
                on_permission_request=PermissionHandler.approve_all,
            )

        # Event-based send: collect reply via events, abort on cancel
        idle_event = asyncio.Event()

        def on_event(event: Any) -> None:
            # NOTE(review): assumes the SDK invokes this callback on the
            # event-loop thread (as the original code did) — confirm.
            if isinstance(event.data, AssistantMessageData):
                content = event.data.content or ""
                reply_parts.append(content)
                # Schedule streaming — push delta to SSE subscriber and
                # persist snapshot for GET polling
                task = loop.create_task(
                    _stream_and_persist(ctx, invocation_id, content, reply_parts)
                )
                persist_tasks.add(task)
                task.add_done_callback(persist_tasks.discard)
            elif isinstance(event.data, IdleData):
                idle_event.set()

        session.on(on_event)
        await session.send(message)

        # Wait for idle (turn complete) or cancel, whichever first
        cancel_task = asyncio.create_task(_wait_for_cancel(ctx.cancel))
        idle_task = asyncio.create_task(idle_event.wait())
        try:
            done, _ = await asyncio.wait(
                {cancel_task, idle_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if cancel_task in done and idle_task not in done:
                was_aborted = True
                logger.info("session.abort() — new input queued")
                await session.abort()
        finally:
            for waiter in (cancel_task, idle_task):
                if not waiter.done():
                    waiter.cancel()

        # Drain pending snapshots BEFORE the final status is written, so
        # none of them can land afterwards and clobber it.
        if persist_tasks:
            outcomes = await asyncio.gather(
                *tuple(persist_tasks), return_exceptions=True
            )
            for outcome in outcomes:
                if isinstance(outcome, Exception):
                    logger.warning("Streaming/persist step failed: %s", outcome)

    reply = "".join(reply_parts)

    # ── Phase 3: Save result ────────────────────────────────────────
    output = {
        "invocation_id": invocation_id,
        "reply": reply,
        "partial": was_aborted,
    }

    if was_aborted:
        invocation_store.save(
            invocation_id,
            {
                "status": "superseded",
                "reason": "steered_mid_stream",
                "output": output,
            },
        )
        return await ctx.suspend(reason="steered")

    # Reply finished, but cancel was set before the result was recorded.
    if ctx.cancel.is_set():
        invocation_store.save(
            invocation_id,
            {
                "status": "superseded",
                "reason": "steered_post_completion",
                "output": output,
            },
        )
        return await ctx.suspend(reason="steered")

    invocation_store.save(invocation_id, {"status": "completed", "output": output})
    return await ctx.suspend(reason="awaiting_user_input", output=output)
async def _sse_from_run(
    run: object, invocation_id: str, *, initial_status: str = "queued"
) -> AsyncGenerator[bytes, None]:
    """Yield Server-Sent-Event frames for a task run.

    Frames, in order: one ``lifecycle`` data frame carrying *initial_status*,
    one data frame per chunk produced by iterating *run*, then exactly one
    terminal frame — ``done`` (with ``output`` when the result has one),
    ``superseded`` (task cancelled or terminated) or ``error``.

    :param run: Object that supports ``async for`` and ``await run.result()``.
    :param invocation_id: Identifier echoed in the lifecycle/terminal frames.
    :param initial_status: Status reported in the first frame.
    """
    from azure.ai.agentserver.core.durable import (  # pylint: disable=import-outside-toplevel
        TaskCancelled,
        TaskFailed,
        TaskTerminated,
    )

    def _frame(payload: object, event: str | None = None) -> bytes:
        """Serialize *payload* as one SSE frame, optionally with an event name."""
        head = "" if event is None else f"event: {event}\n"
        return f"{head}data: {json.dumps(payload)}\n\n".encode()

    def _error_payload(exc: BaseException) -> dict:
        return {
            "type": "error",
            "invocation_id": invocation_id,
            "error": str(exc),
        }

    yield _frame(
        {
            "type": "lifecycle",
            "status": initial_status,
            "invocation_id": invocation_id,
        }
    )

    try:
        async for chunk in run:  # type: ignore[union-attr]
            yield _frame(chunk)

        try:
            result = await run.result()  # type: ignore[union-attr]
            done_data = {"type": "done", "invocation_id": invocation_id}
            output = getattr(result, "output", None)
            if output is not None:
                done_data["output"] = output
            yield _frame(done_data, "done")
        except (TaskCancelled, TaskTerminated):
            yield _frame(
                {"type": "superseded", "invocation_id": invocation_id},
                "superseded",
            )
        except TaskFailed as exc:
            yield _frame(_error_payload(exc), "error")
    except Exception as exc:  # pylint: disable=broad-except
        # Anything else (stream failure, unserializable payload, ...) is
        # reported to the client as a terminal error frame.
        yield _frame(_error_payload(exc), "error")
+ """ + data = await request.json() + invocation_id: str = request.state.invocation_id + session_id: str = request.state.session_id + message: str = data.get("message", "") + task_id = f"session-{session_id}" + + task_input = { + "session_id": session_id, + "message": message, + "invocation_id": invocation_id, + } + + invocation_store.save(invocation_id, {"status": "queued"}) + + run = await copilot_session.start(task_id=task_id, input=task_input) + + # SSE streaming mode + wants_stream = "text/event-stream" in request.headers.get("accept", "") + if wants_stream: + return StreamingResponse( + _sse_from_run(run, invocation_id), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + # Async mode + stored = invocation_store.load(invocation_id) + status = stored["status"] if stored else "queued" + + return JSONResponse( + {"invocation_id": invocation_id, "status": status}, + status_code=202, + ) + + +@app.get_invocation_handler +async def poll_invocation(request: Request) -> Response: + """Poll a specific invocation's result. + + Returns the current snapshot — during streaming this includes the + full text generated so far. This is the recovery path after a + streaming disconnect. 
class FileStore:
    """Minimal file-backed key→JSON store.

    Each entry is a single ``<key>.json`` file directly under *base_dir*.
    Writes are atomic (temp file + rename), so a reader never sees a
    partially written entry.

    NOTE: *key* is used verbatim as a file name. Callers must pass keys that
    contain no path separators.
    """

    def __init__(self, base_dir: Path) -> None:
        """Create the store, making *base_dir* (and parents) if needed."""
        self._base = base_dir
        self._base.mkdir(parents=True, exist_ok=True)

    def _path(self, key: str) -> Path:
        """Return the file that holds the entry for *key*."""
        return self._base / f"{key}.json"

    def save(self, key: str, data: dict[str, Any]) -> None:
        """Atomically write *data* as JSON — temp file + rename.

        The temp file lives in the same directory as the target so the final
        rename stays on one filesystem.  On any failure the temp file is
        removed and the exception is re-raised.
        """
        target = self._path(key)
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self._base), suffix=".tmp", prefix=f"{key}_"
        )
        try:
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            Path(tmp_path).replace(target)
        except BaseException:
            Path(tmp_path).unlink(missing_ok=True)
            raise

    def load(self, key: str) -> dict[str, Any] | None:
        """Return the stored dict, or ``None`` if the key does not exist.

        Reads as UTF-8 to match :meth:`save`, and treats a missing file as
        "no entry" without a separate existence check, so a concurrent
        :meth:`delete` cannot turn into a ``FileNotFoundError``.
        """
        try:
            text = self._path(key).read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def delete(self, key: str) -> bool:
        """Remove the entry for *key*. Returns ``True`` if it existed."""
        try:
            self._path(key).unlink()
        except FileNotFoundError:
            # Already gone (never saved, or removed concurrently).
            return False
        return True
class ConversationState(TypedDict):
    """Graph state for a multi-turn conversation.

    Uses LangGraph's built-in ``add_messages`` reducer for message
    accumulation across turns.
    """

    # Conversation transcript; annotated with ``add_messages`` so node
    # outputs are merged into the list by that reducer.
    messages: typing.Annotated[list, add_messages]
    # Completion flag; ``_should_continue`` routes to the end of the graph
    # when it is truthy.
    is_complete: bool
def analyze_input(state: ConversationState) -> dict[str, Any]:
    """Stand-in for an analysis step such as intent detection.

    Blocks for ``_STEP_DELAY`` seconds to simulate work and returns an empty
    update, i.e. the graph state is left untouched.
    """
    from time import sleep  # pylint: disable=import-outside-toplevel

    del state  # unused here; a real implementation would inspect the messages
    sleep(_STEP_DELAY)
    no_update: dict[str, Any] = {}
    return no_update
+ ) + + return {"messages": [AIMessage(content=reply)]} + + +def refine_response(state: ConversationState) -> dict[str, Any]: + """Simulate post-processing (e.g., safety checks, formatting).""" + import time # pylint: disable=import-outside-toplevel + + _ = state # Would inspect the generated reply in a real implementation + time.sleep(_STEP_DELAY // 2 or 1) + return {} # No state change — refinement is an internal step + + +def wait_for_user(state: ConversationState) -> dict[str, Any]: + """Pause the graph and wait for the next human message.""" + messages = state["messages"] + user_count = len([m for m in messages if isinstance(m, HumanMessage)]) + + user_input: str = interrupt( + { + "prompt": "Please provide your next message (or say 'done' to finish):", + "current_turn": user_count, + } + ) + + if user_input.strip().lower() == "done": + return {"is_complete": True} + + return { + "messages": [HumanMessage(content=user_input)], + "is_complete": False, + } + + +def _should_continue(state: ConversationState) -> str: + """Route: loop back to process_input or end the conversation.""" + if state.get("is_complete", False): + return "end" + return "continue" + + +# --------------------------------------------------------------------------- +# Persistent graph checkpointer (survives restarts) +# --------------------------------------------------------------------------- + +_DATA_DIR.mkdir(parents=True, exist_ok=True) +_DB_PATH = _DATA_DIR / "langgraph_checkpoints.db" + +_conn = sqlite3.connect(str(_DB_PATH), check_same_thread=False) +_checkpointer = SqliteSaver(_conn) +_checkpointer.setup() + +logger.info("LangGraph checkpoints stored at: %s", _DB_PATH) + + +# --------------------------------------------------------------------------- +# Build and compile the graph +# --------------------------------------------------------------------------- + + +def _build_graph() -> Any: + """Construct the LangGraph StateGraph for multi-turn conversation. 
+ + Processing is split across three nodes (``analyze_input`` → + ``generate_response`` → ``refine_response``) so that stream-based + cancellation can bail out between any two steps (~2 s granularity). + """ + builder = StateGraph(ConversationState) + + builder.add_node("analyze_input", analyze_input) + builder.add_node("generate_response", generate_response) + builder.add_node("refine_response", refine_response) + builder.add_node("wait_for_user", wait_for_user) + + builder.add_edge(START, "analyze_input") + builder.add_edge("analyze_input", "generate_response") + builder.add_edge("generate_response", "refine_response") + builder.add_edge("refine_response", "wait_for_user") + + builder.add_conditional_edges( + "wait_for_user", + _should_continue, + { + "continue": "analyze_input", + "end": END, + }, + ) + + return builder.compile(checkpointer=_checkpointer) + + +_graph = _build_graph() + + +# --------------------------------------------------------------------------- +# Steering — cancellable graph invocation and state forking +# --------------------------------------------------------------------------- + + +def _invoke_cancellable( + graph: Any, + graph_input: Any, + config: dict[str, Any], + cancel_event: asyncio.Event, + on_node: Any = None, +) -> bool: + """Run the graph using ``stream()`` with inter-node cancellation. + + Instead of ``graph.invoke()`` which blocks until the full graph + completes, this streams node-by-node and checks ``cancel_event`` + between nodes. If cancellation is detected, execution stops before + the next node runs. + + Returns ``True`` if the graph ran to completion (or interrupt), + ``False`` if cancelled mid-graph. 
+ """ + for chunk in graph.stream(graph_input, config): + if on_node is not None: + on_node(chunk) + if cancel_event.is_set(): + return False + return True + + +def _fork_from_checkpoint( + graph: Any, + config: dict[str, Any], + target_checkpoint_id: str, + new_message: str, +) -> bool: + """Fork the graph from a previous checkpoint with a new message. + + Uses LangGraph's native state forking: ``update_state`` called with + an old checkpoint's config creates a new branch. The graph's head + pointer moves to the fork, discarding any state that was added after + the target checkpoint. + + After forking the graph is positioned after ``wait_for_user`` with + the new message injected, so the next step is ``process_input``. + + Returns ``True`` if the fork was created. + """ + # Load the target checkpoint to get its full config (includes checkpoint_ns) + target_config = { + "configurable": { + **config["configurable"], + "checkpoint_id": target_checkpoint_id, + } + } + target = graph.get_state(target_config) + if not target or not target.config: + return False + + # Fork: update_state at the old checkpoint creates a new branch + graph.update_state( + target.config, + values={"messages": [HumanMessage(content=new_message)]}, + as_node="wait_for_user", + ) + return True + + +def _build_turn_output(state: Any) -> dict[str, Any]: + """Extract turn output from graph state at an interrupt.""" + messages = state.values.get("messages", []) + ai_messages = [m for m in messages if isinstance(m, AIMessage)] + user_messages = [m for m in messages if isinstance(m, HumanMessage)] + last_reply = ai_messages[-1].content if ai_messages else "" + return {"reply": last_reply, "turn": len(user_messages)} + + +def _build_session_output(state: Any) -> dict[str, Any]: + """Build final output when the graph conversation is complete.""" + messages = state.values.get("messages", []) + user_count = len([m for m in messages if isinstance(m, HumanMessage)]) + return { + "finished": True, + 
"turn_count": user_count, + "total_messages": len(messages), + "summary": f"Session complete after {user_count} turns.", + } + + +async def _finalize_invocation( + ctx: TaskContext[dict], + thread_config: dict[str, Any], + invocation_id: str, +) -> dict[str, Any] | Any: + """Save results and suspend/return after a graph invoke completes.""" + state = await asyncio.to_thread(_graph.get_state, thread_config) + + new_cp_id = state.config["configurable"]["checkpoint_id"] + ctx.metadata.set("stable_checkpoint_id", new_cp_id) + ctx.metadata.set("last_applied_invocation_id", invocation_id) + + if state.next: + output = _build_turn_output(state) + invocation_store.save(invocation_id, {"status": "completed", "output": output}) + return await ctx.suspend(reason="awaiting_user_input", output=output) + + result = _build_session_output(state) + invocation_store.save(invocation_id, {"status": "completed", "output": result}) + return result + + +# --------------------------------------------------------------------------- +# Durable task — bridges LangGraph with HTTP lifecycle +# --------------------------------------------------------------------------- + + +@durable_task(name="langgraph_session", steerable=True) +async def langgraph_session(ctx: TaskContext[dict]) -> dict[str, Any]: + """Run one LangGraph conversation turn with steering support. 
+ + Input schema: ``{"session_id": str, "message": str, "invocation_id": str}`` + """ + session_id: str = ctx.input["session_id"] + message: str = ctx.input["message"] + invocation_id: str = ctx.input["invocation_id"] + + invocation_store.save(invocation_id, {"status": "running"}) + await ctx.stream({"type": "lifecycle", "status": "running"}) + + thread_config: dict[str, Any] = {"configurable": {"thread_id": session_id}} + + if ctx.entry_mode == "recovered": + logger.warning("Recovered stale task for session %s", session_id) + + # ── Fork-on-steer: rollback to stable checkpoint ──────────────── + # If the previous invocation was cancelled mid-flight, the graph may + # have drifted past the stable checkpoint. Fork from the stable + # checkpoint with the new message so the graph processes it cleanly. + stable_cp = ctx.metadata.get("stable_checkpoint_id") + if stable_cp: + state = await asyncio.to_thread(_graph.get_state, thread_config) + if state and state.values.get("messages"): + current_cp = state.config["configurable"].get("checkpoint_id") + if current_cp and current_cp != stable_cp: + forked = await asyncio.to_thread( + _fork_from_checkpoint, + _graph, + thread_config, + stable_cp, + message, + ) + if forked: + logger.info( + "Forked session %s from stable checkpoint %s", + session_id, + stable_cp, + ) + completed = await asyncio.to_thread( + _invoke_cancellable, + _graph, + None, + thread_config, + ctx.cancel, + ) + + if not completed or ctx.cancel.is_set(): + invocation_store.save( + invocation_id, + {"status": "cancelled", "reason": "steered"}, + ) + return await ctx.suspend(reason="steered") + + return await _finalize_invocation(ctx, thread_config, invocation_id) + + # ── Phase 1: Pre-entry cancel ─────────────────────────────────── + if ctx.cancel.is_set(): + invocation_store.save( + invocation_id, {"status": "cancelled", "reason": "steered"} + ) + return await ctx.suspend(reason="steered") + + # ── Phase 2: Invoke graph with inter-node cancellation 
────────── + state = await asyncio.to_thread(_graph.get_state, thread_config) + + if state.next: + graph_input = Command(resume=message) + else: + graph_input = { + "messages": [HumanMessage(content=message)], + "is_complete": False, + } + + loop = asyncio.get_event_loop() + + def _on_node(chunk: dict) -> None: + """Stream node progress events from the sync graph thread.""" + node_names = list(chunk.keys()) + for name in node_names: + if ctx._stream_handler is not None: # pylint: disable=protected-access + asyncio.run_coroutine_threadsafe( + ctx.stream({"type": "node_progress", "node": name}), + loop, + ) + invocation_store.save( + invocation_id, + { + "status": "streaming", + "last_node": node_names[-1] if node_names else None, + }, + ) + + completed = await asyncio.to_thread( + _invoke_cancellable, + _graph, + graph_input, + thread_config, + ctx.cancel, + _on_node, + ) + + # ── Phase 3: Post-completion cancel check ─────────────────────── + if not completed or ctx.cancel.is_set(): + invocation_store.save( + invocation_id, {"status": "cancelled", "reason": "steered"} + ) + return await ctx.suspend(reason="steered") + + # Normal completion + return await _finalize_invocation(ctx, thread_config, invocation_id) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py new file mode 100644 index 000000000000..517de7c8f2c9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_langgraph/app.py @@ -0,0 +1,184 @@ +"""HTTP host for the LangGraph durable agent with streaming and steering. + +Wires the LangGraph durable task (``agent.py``) to the invocations framework. +Per-invocation results are written by the durable task itself (inside the +crash-resilient execution boundary), not by a background collector. 
+ +Streaming +~~~~~~~~~ + +Pass ``Accept: text/event-stream`` on POST to receive an SSE stream of node +progress events (``node_progress``) plus lifecycle events (``queued``, +``running``). Without the header you get the standard 202 JSON response for +async polling via GET. + +Steering is handled by the framework: the durable task is declared with +``steerable=True``, so calling ``start()`` on an in-progress task **queues** +the new input instead of raising ``TaskConflictError``. The running function +sees ``ctx.cancel`` set and short-circuits. The framework then drains the +queue and re-enters the function with the next input. + +Usage:: + + pip install -r requirements.txt + + python -m durable_langgraph.app + # — or — + python app.py + + # Turn 1 — async + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "I need help planning a trip to Tokyo"}' + # → 202 (x-agent-invocation-id: ) + + # Turn 1 — streaming + curl -N -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -H "Accept: text/event-stream" \\ + -d '{"message": "I need help planning a trip to Tokyo"}' + # → SSE stream: lifecycle:queued → lifecycle:running → node_progress → done + + # Poll that invocation (snapshot — always available) + curl "http://localhost:8088/invocations/" + # → {"invocation_id": "", "status": "completed", "output": {...}} + + # Steer — send a new invocation while a turn is still running. 
+ curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "Actually, let us go to Paris instead"}' + + # End session + curl -X POST "http://localhost:8088/invocations?agent_session_id=demo-001" \\ + -H "Content-Type: application/json" \\ + -d '{"message": "done"}' +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncGenerator + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + +from .agent import invocation_store, langgraph_session + +logger = logging.getLogger(__name__) + +app = InvocationAgentServerHost() + + +async def _sse_from_run( + run: object, invocation_id: str, *, initial_status: str = "queued" +) -> AsyncGenerator[bytes, None]: + """Convert a TaskRun's stream into SSE-formatted bytes.""" + from azure.ai.agentserver.core.durable import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + TaskFailed, + TaskTerminated, + ) + + yield ( + f"data: {json.dumps({'type': 'lifecycle', 'status': initial_status, 'invocation_id': invocation_id})}\n\n" + ).encode() + + try: + async for chunk in run: # type: ignore[union-attr] + yield f"data: {json.dumps(chunk)}\n\n".encode() + + try: + result = await run.result() # type: ignore[union-attr] + done_data = {"type": "done", "invocation_id": invocation_id} + if ( + result is not None + and hasattr(result, "output") + and result.output is not None + ): + done_data["output"] = result.output + yield f"event: done\ndata: {json.dumps(done_data)}\n\n".encode() + except (TaskCancelled, TaskTerminated): + yield ( + f"event: superseded\n" + f"data: {json.dumps({'type': 'superseded', 'invocation_id': invocation_id})}\n\n" + ).encode() + except TaskFailed as exc: + error_data = { + "type": "error", + "invocation_id": invocation_id, + "error": str(exc), 
@app.invoke_handler
async def handle_invoke(request: Request) -> Response:
    """Start a LangGraph session turn, or steer one already in progress.

    The JSON body's ``message`` (default ``""``) is passed to the durable
    task together with the session and invocation ids taken from
    ``request.state``.  The invocation is recorded as ``queued`` before the
    task is started.

    With ``text/event-stream`` in the ``Accept`` header the response is an
    SSE stream built from the run; otherwise it is ``202 Accepted`` carrying
    the status currently held in the invocation store.
    """
    payload = await request.json()
    invocation_id: str = request.state.invocation_id
    session_id: str = request.state.session_id
    message: str = payload.get("message", "")

    invocation_store.save(invocation_id, {"status": "queued"})

    run = await langgraph_session.start(
        task_id=f"session-{session_id}",
        input={
            "session_id": session_id,
            "message": message,
            "invocation_id": invocation_id,
        },
    )

    accept_header = request.headers.get("accept", "")
    if "text/event-stream" not in accept_header:
        # Standard async mode — report whatever status the store holds now
        # (the task may already have moved past "queued").
        stored = invocation_store.load(invocation_id)
        current = stored["status"] if stored else "queued"
        return JSONResponse(
            {"invocation_id": invocation_id, "status": current},
            status_code=202,
        )

    # SSE streaming mode — return live node progress
    return StreamingResponse(
        _sse_from_run(run, invocation_id),
        media_type="text/event-stream",
        headers={"X-Agent-Invocation-Id": invocation_id},
    )
class FileStore:
    """Minimal file-backed key→JSON store.

    Each entry is a single ``<key>.json`` file directly under *base_dir*.
    Writes are atomic (temp file + rename), so a reader never sees a
    partially written entry.

    NOTE: *key* is used verbatim as a file name. Callers must pass keys that
    contain no path separators.
    """

    def __init__(self, base_dir: Path) -> None:
        """Create the store, making *base_dir* (and parents) if needed."""
        self._base = base_dir
        self._base.mkdir(parents=True, exist_ok=True)

    def _path(self, key: str) -> Path:
        """Return the file that holds the entry for *key*."""
        return self._base / f"{key}.json"

    def save(self, key: str, data: dict[str, Any]) -> None:
        """Atomically write *data* as JSON — temp file + rename.

        The temp file lives in the same directory as the target so the final
        rename stays on one filesystem.  On any failure the temp file is
        removed and the exception is re-raised.
        """
        target = self._path(key)
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self._base), suffix=".tmp", prefix=f"{key}_"
        )
        try:
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            Path(tmp_path).replace(target)
        except BaseException:
            Path(tmp_path).unlink(missing_ok=True)
            raise

    def load(self, key: str) -> dict[str, Any] | None:
        """Return the stored dict, or ``None`` if the key does not exist.

        Reads as UTF-8 to match :meth:`save`, and treats a missing file as
        "no entry" without a separate existence check, so a concurrent
        :meth:`delete` cannot turn into a ``FileNotFoundError``.
        """
        try:
            text = self._path(key).read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def delete(self, key: str) -> bool:
        """Remove the entry for *key*. Returns ``True`` if it existed."""
        try:
            self._path(key).unlink()
        except FileNotFoundError:
            # Already gone (never saved, or removed concurrently).
            return False
        return True
+
+The agent keeps its own conversation state in a ``FileStore`` checkpoint
+and writes per-invocation results to the invocation store — both inside
+the durable execution boundary so they survive crashes.
+"""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any
+
+from azure.ai.agentserver.core.durable import (
+    TaskContext,
+    durable_task,
+)
+
+from .store import FileStore
+
+logger = logging.getLogger(__name__)
+
+_DATA_DIR = Path.home() / ".durable-sessions"
+
+# Session checkpoint store — conversation state across turns
+checkpoint_store = FileStore(_DATA_DIR / "checkpoints")
+
+# Invocation result store — written inside the durable task so it survives crashes
+invocation_store = FileStore(_DATA_DIR / "invocations")
+
+
+def _generate_reply(state: dict[str, Any]) -> str:
+    """Placeholder for an LLM call. Replace with your model of choice.
+
+    :param state: Session state holding ``turn_count`` (number of user turns
+        processed so far, including the current one) and ``history`` (list of
+        ``{"role", "content"}`` messages, newest last).
+    :return: A canned reply that echoes the most recent message.
+    """
+    turn = state["turn_count"]
+    # An empty history yields an empty echo instead of an IndexError.
+    last_msg = state["history"][-1]["content"] if state["history"] else ""
+    if turn == 1:
+        return (
+            f"Thanks for reaching out! You said: '{last_msg}'. "
+            "Could you share more details so I can help?"
+        )
+    if turn == 2:
+        return (
+            f"Great, noted: '{last_msg}'. Based on our conversation "
+            "so far, here are some initial thoughts. What else?"
+        )
+    return (
+        f"Turn {turn}: incorporating '{last_msg}' — "
+        f"I now have context from {turn} turns of conversation."
+    )
+
+
+@durable_task(name="session_workflow")
+async def session_workflow(ctx: TaskContext[dict]) -> dict[str, Any]:
+    """Single durable function for the entire session.
+
+    Each invocation runs this function from the top.
+    ``ctx.entry_mode`` tells us why we were entered.
+
+    The invocation result is written to ``invocation_store`` **inside** the
+    durable boundary — if the process crashes, the task recovers and the
+    write happens on re-execution.
+
+    Re-execution is made idempotent per ``invocation_id``: a turn that was
+    already checkpointed or already published is never applied twice, so a
+    crash between the individual store writes cannot duplicate history
+    entries, inflate ``turn_count``, or overwrite a published result.
+
+    :param ctx: Task context. ``ctx.input`` must carry ``session_id``,
+        ``message`` and ``invocation_id``.
+    :return: The turn output ``{"reply", "turn"}``; the session-end output
+        additionally carries ``"finished": True``.
+    """
+    session_id: str = ctx.input["session_id"]
+    message: str = ctx.input["message"]
+    invocation_id: str = ctx.input["invocation_id"]
+
+    if ctx.entry_mode == "recovered":
+        logger.warning("Recovered stale task for session %s", session_id)
+
+    # Replay guard 1: the result for this invocation was already published
+    # before the crash. Hand back that exact output instead of recomputing it
+    # from state that may since have been advanced or deleted.
+    published = invocation_store.load(invocation_id) or {}
+    if published.get("status") == "completed":
+        output = published["output"]
+        if output.get("finished"):
+            # Finish the interrupted session end; no-op if already removed.
+            checkpoint_store.delete(session_id)
+            return output
+        return await ctx.suspend(reason="awaiting_user_input", output=output)
+
+    # Mark invocation as running — inside the durable boundary so it
+    # only exists if the task is actually executing.
+    invocation_store.save(invocation_id, {"status": "running"})
+
+    state = checkpoint_store.load(session_id) or {"history": [], "turn_count": 0}
+
+    # Handle explicit session end
+    if message.strip().lower() == "done":
+        summary = (
+            f"Session complete after {state['turn_count']} turns. "
+            f"Total messages exchanged: {len(state['history'])}."
+        )
+        result = {"reply": summary, "turn": state["turn_count"], "finished": True}
+        # Publish first, delete second: a crash in between replays through
+        # guard 1 above and still reports the real turn count.
+        invocation_store.save(invocation_id, {"status": "completed", "output": result})
+        checkpoint_store.delete(session_id)
+        return result
+
+    if state.get("last_invocation_id") == invocation_id:
+        # Replay guard 2: the checkpoint already contains this turn (crash hit
+        # after the checkpoint write but before the result was published).
+        output = state["last_output"]
+    else:
+        # Process this turn
+        state["history"].append({"role": "user", "content": message})
+        state["turn_count"] += 1
+
+        reply = _generate_reply(state)
+        state["history"].append({"role": "assistant", "content": reply})
+
+        output = {"reply": reply, "turn": state["turn_count"]}
+        # Record which invocation produced this state so a replay can tell
+        # "already applied" from "not yet applied".
+        state["last_invocation_id"] = invocation_id
+        state["last_output"] = output
+        checkpoint_store.save(session_id, state)
+
+    # Persist invocation result BEFORE suspending (inside durable boundary)
+    invocation_store.save(invocation_id, {"status": "completed", "output": output})
+
+    # Suspend — the client will resume with the next turn
+    return await ctx.suspend(reason="awaiting_user_input", output=output)
diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py
new file mode 100644
index 000000000000..91e7daec9240
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/app.py
@@ -0,0 +1,100 @@
+"""HTTP host for the durable multi-turn agent.
+
+Wires the durable task (``agent.py``) to the invocations framework.
+Per-invocation results are written by the durable task itself (inside the
+crash-resilient execution boundary), not by a background collector.
+
+Usage::
+
+    pip install azure-ai-agentserver-invocations
+
+    # Run as a module from the ``samples`` directory — this file uses
+    # package-relative imports, so it cannot be started as ``python app.py``.
+    python -m durable_multiturn.app
+
+    # Turn 1
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "I want to plan a vacation to Japan"}'
+    # → 202 (x-agent-invocation-id: {invocation_id})
+
+    # Poll that invocation
+    curl "http://localhost:8088/invocations/{invocation_id}"
+    # → {"invocation_id": "{invocation_id}", "status": "completed", "output": {...}}
+
+    # Turn 2
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "Budget is $5000, 2 weeks"}'
+
+    # End session
+    curl -X POST "http://localhost:8088/invocations?agent_session_id=trip-001" \\
+        -H "Content-Type: application/json" \\
+        -d '{"message": "done"}'
+"""
+
+from __future__ import annotations
+
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+
+from azure.ai.agentserver.core.durable import TaskConflictError
+from azure.ai.agentserver.invocations import InvocationAgentServerHost
+
+from .agent import invocation_store, session_workflow
+
+app = InvocationAgentServerHost()
+
+
+@app.invoke_handler
+async def handle_invoke(request: Request) -> Response:
+    """Start or resume a durable session task.
+
+    Each POST is one invocation. The durable task is an internal detail
+    — the caller only sees ``invocation_id`` (from platform headers).
+
+    The task itself writes the invocation result to the store inside the
+    durable execution boundary — no background collector needed.
+
+    :param request: Incoming request; the body must be a JSON object with an
+        optional string ``message``. ``request.state`` supplies
+        ``invocation_id`` and ``session_id``.
+    :return: ``202`` once the task was handed to the durable runtime,
+        ``400`` for a body that is not a JSON object (or a non-string
+        ``message``), ``409`` when the runtime reports a task conflict.
+    """
+    # Validate the payload up front: a bad body should be a client error here,
+    # not an unhandled exception or a failure later inside the durable task.
+    try:
+        data = await request.json()
+    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
+        return JSONResponse(
+            {"error": "Request body must be valid JSON"}, status_code=400
+        )
+    if not isinstance(data, dict):
+        return JSONResponse(
+            {"error": "Request body must be a JSON object"}, status_code=400
+        )
+    message = data.get("message", "")
+    if not isinstance(message, str):
+        return JSONResponse({"error": "'message' must be a string"}, status_code=400)
+
+    invocation_id: str = request.state.invocation_id
+    session_id: str = request.state.session_id
+    # One durable task per session: every turn targets the same task id.
+    task_id = f"session-{session_id}"
+
+    try:
+        await session_workflow.start(
+            task_id=task_id,
+            input={
+                "session_id": session_id,
+                "message": message,
+                "invocation_id": invocation_id,
+            },
+        )
+    except TaskConflictError as e:
+        return JSONResponse({"error": str(e)}, status_code=409)
+
+    return JSONResponse(
+        {"invocation_id": invocation_id, "status": "running"},
+        status_code=202,
+    )
+
+
+@app.get_invocation_handler
+async def poll_invocation(request: Request) -> Response:
+    """Poll a specific invocation's result.
+
+    Reads from the file-based invocation store — works after restarts.
+    Returns the output of **this invocation only** — not the whole session.
+
+    :param request: Incoming request; ``request.state.invocation_id``
+        identifies the invocation to look up.
+    :return: The stored record merged with ``invocation_id``, or ``404`` when
+        no record has been written for that id.
+    """
+    invocation_id: str = request.state.invocation_id
+
+    result = invocation_store.load(invocation_id)
+    if result is None:
+        return JSONResponse({"error": "Invocation not found"}, status_code=404)
+
+    return JSONResponse({"invocation_id": invocation_id, **result})
+
+
+if __name__ == "__main__":
+    app.run()
diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt
new file mode 100644
index 000000000000..bc5cf4644e14
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/requirements.txt
@@ -0,0 +1 @@
+azure-ai-agentserver-invocations
diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py
new file mode 100644
index 000000000000..003049988a81
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/durable_multiturn/store.py
@@ -0,0 +1,57 @@
+"""File-based key→JSON store for powering the invocation API.
+
+This module provides a minimal persistence layer that the HTTP host uses to
+store per-invocation results. It is **not** part of the durable task
+framework — it is the developer's own persistence for powering the API
+contract (``GET /invocations/{invocation_id}``).
+
+.. warning::
+
+    For demonstration only. In production, use a database (Redis, Cosmos DB,
+    PostgreSQL, etc.).
+"""
+
+from __future__ import annotations
+
+import json
+import tempfile
+from pathlib import Path
+from typing import Any
+
+
+class FileStore:
+    """Minimal file-backed key→JSON store.
+
+    Each entry is a single UTF-8 JSON file named ``{key}.json`` inside the
+    base directory. Writes are atomic (temp + rename), so a reader sees
+    either the previous or the new content, never a partial file.
+    """
+
+    def __init__(self, base_dir: Path) -> None:
+        """Create the store, making *base_dir* (and parents) if needed."""
+        self._base = base_dir
+        self._base.mkdir(parents=True, exist_ok=True)
+
+    def _path_for(self, key: str) -> Path:
+        """Return the file backing *key*.
+
+        Keys come from request data (session and invocation ids), so anything
+        that could address a file outside the base directory is refused.
+
+        :raises ValueError: If *key* contains a path separator or NUL.
+        """
+        if any(ch in key for ch in ("/", "\\", "\0")):
+            raise ValueError(f"Invalid store key: {key!r}")
+        return self._base / f"{key}.json"
+
+    def save(self, key: str, data: dict[str, Any]) -> None:
+        """Atomically write *data* as JSON — temp file + rename.
+
+        :raises ValueError: If *key* is not a valid store key.
+        """
+        target = self._path_for(key)
+        # The temp file lives in the same directory so the final rename stays
+        # on one filesystem and is therefore atomic.
+        fd, tmp_path = tempfile.mkstemp(
+            dir=str(self._base), suffix=".tmp", prefix=f"{key}_"
+        )
+        try:
+            with open(fd, "w", encoding="utf-8") as f:
+                json.dump(data, f, indent=2)
+            Path(tmp_path).replace(target)
+        except BaseException:
+            # Never leave a stray temp file behind, even on cancellation.
+            Path(tmp_path).unlink(missing_ok=True)
+            raise
+
+    def load(self, key: str) -> dict[str, Any] | None:
+        """Return the stored dict, or ``None`` if the key does not exist.
+
+        :raises ValueError: If *key* is not a valid store key.
+        """
+        # Read first and treat "missing" as the exceptional case: an
+        # exists()-then-read sequence can race with a concurrent delete().
+        try:
+            text = self._path_for(key).read_text(encoding="utf-8")
+        except FileNotFoundError:
+            return None
+        return json.loads(text)
+
+    def delete(self, key: str) -> None:
+        """Remove the entry for *key* (no-op if missing).
+
+        :raises ValueError: If *key* is not a valid store key.
+        """
+        self._path_for(key).unlink(missing_ok=True)
diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py
index 96fa857bf02c..ddef29b2864d 100644
--- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py
+++ 
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/multiturn_invoke_agent/multiturn_invoke_agent.py @@ -32,12 +32,12 @@ -d '{"message": "Budget is $5000, prefer direct flights"}' # -> {"reply": "Here is a suggested itinerary ...", ...} """ + from starlette.requests import Request from starlette.responses import JSONResponse, Response from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() # In-memory session store — keyed by session ID. @@ -91,11 +91,13 @@ async def handle_invoke(request: Request) -> Response: reply = _build_reply(history) history.append({"role": "assistant", "content": reply}) - return JSONResponse({ - "reply": reply, - "session_id": session_id, - "turn": len([m for m in history if m["role"] == "user"]), - }) + return JSONResponse( + { + "reply": reply, + "session_id": session_id, + "turn": len([m for m in history if m["role"] == "user"]), + } + ) if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py index a2e7fdb32d3b..adb537cf5dce 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/simple_invoke_agent.py @@ -11,12 +11,12 @@ curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' # -> {"greeting": "Hello, Alice!"} """ + from starlette.requests import Request from starlette.responses import JSONResponse, Response from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py 
b/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py index a207a93cca0d..c5caf7b5a920 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/streaming_invoke_agent/streaming_invoke_agent.py @@ -18,6 +18,7 @@ # -> event: done # -> data: {"invocation_id": "..."} """ + import asyncio import json from collections.abc import AsyncGenerator # pylint: disable=import-error @@ -27,14 +28,32 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - app = InvocationAgentServerHost() # Simulated tokens — in production these would come from a model. _SIMULATED_TOKENS = [ - "class", " Calculator", ":", "\n", - " ", "def", " add", "(", "self", ",", " a", ",", " b", ")", ":", "\n", - " ", "return", " a", " +", " b", "\n", + "class", + " Calculator", + ":", + "\n", + " ", + "def", + " add", + "(", + "self", + ",", + " a", + ",", + " b", + ")", + ":", + "\n", + " ", + "return", + " a", + " +", + " b", + "\n", ] diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py index 73307f2ba110..4bb7d141570a 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_decorator_pattern.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # invoke_handler stores function # --------------------------------------------------------------------------- + def test_invoke_handler_stores_function(): """@app.invoke_handler stores the function on the protocol object.""" app = InvocationAgentServerHost() @@ -30,6 +30,7 @@ async def handle(request: Request) -> Response: # 
invoke_handler returns original function # --------------------------------------------------------------------------- + def test_invoke_handler_returns_original_function(): """@app.invoke_handler returns the original function.""" app = InvocationAgentServerHost() @@ -45,6 +46,7 @@ async def handle(request: Request) -> Response: # get_invocation_handler stores function # --------------------------------------------------------------------------- + def test_get_invocation_handler_stores_function(): """@app.get_invocation_handler stores the function.""" app = InvocationAgentServerHost() @@ -60,6 +62,7 @@ async def get_handler(request: Request) -> Response: # cancel_invocation_handler stores function # --------------------------------------------------------------------------- + def test_cancel_invocation_handler_stores_function(): """@app.cancel_invocation_handler stores the function.""" app = InvocationAgentServerHost() @@ -75,6 +78,7 @@ async def cancel_handler(request: Request) -> Response: # shutdown_handler stores function # --------------------------------------------------------------------------- + def test_shutdown_handler_stores_function(): """@server.shutdown_handler stores the function on the server.""" app = InvocationAgentServerHost() @@ -90,6 +94,7 @@ async def on_shutdown(): # Full request flow # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_full_request_flow(): """Full lifecycle: invoke → get → cancel → get (404).""" @@ -107,7 +112,9 @@ async def get_handler(request: Request) -> Response: inv_id = request.path_params["invocation_id"] if inv_id in store: return Response(content=store[inv_id]) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) @app.cancel_invocation_handler async def cancel_handler(request: Request) -> Response: @@ -115,7 
+122,9 @@ async def cancel_handler(request: Request) -> Response: if inv_id in store: del store[inv_id] return JSONResponse({"status": "cancelled"}) - return JSONResponse({"error": {"code": "not_found", "message": "Not found"}}, status_code=404) + return JSONResponse( + {"error": {"code": "not_found", "message": "Not found"}}, status_code=404 + ) transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as client: @@ -142,6 +151,7 @@ async def cancel_handler(request: Request) -> Response: # Missing optional handlers return 404 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_missing_invoke_handler_returns_501(): """POST /invocations without registered handler returns 501.""" @@ -186,6 +196,7 @@ async def handle(request: Request) -> Response: # Optional handler defaults and overrides # --------------------------------------------------------------------------- + def test_optional_handlers_default_none(): """Get and cancel handlers default to None.""" app = InvocationAgentServerHost() @@ -208,6 +219,7 @@ async def get_handler(request: Request) -> Response: # Shutdown handler called during lifespan # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_called_during_lifespan(): """Shutdown handler is called when the app lifespan ends.""" @@ -235,6 +247,7 @@ async def on_shutdown(): # Config passthrough # --------------------------------------------------------------------------- + def test_graceful_shutdown_timeout_passthrough(): """graceful_shutdown_timeout is passed through to the base class.""" server = InvocationAgentServerHost(graceful_shutdown_timeout=15) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py index 351418db7461..999f46310e07 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_edge_cases.py @@ -64,6 +64,7 @@ async def handle(request: Request) -> Response: # Method not allowed tests # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocations_returns_405(): """GET /invocations returns 405 Method Not Allowed.""" @@ -128,6 +129,7 @@ async def handle(request: Request) -> Response: # Response header tests # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_custom_invocation_id_overwritten(): """Handler-set x-agent-invocation-id is overwritten by the server.""" @@ -176,6 +178,7 @@ async def test_invocation_id_generated_when_empty(echo_client): # Payload edge cases # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_large_payload(): """Large payload (1MB) is handled correctly.""" @@ -210,6 +213,7 @@ async def test_binary_payload(echo_client): # Streaming edge cases # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_empty_streaming(): """Empty streaming response doesn't crash.""" @@ -243,6 +247,7 @@ async def generate(): # Invocation lifecycle # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_multiple_gets(async_storage_client): """Multiple GETs for the same invocation return the same result.""" @@ -283,6 +288,7 @@ async def test_invoke_cancel_get(async_storage_client): # Concurrency # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_concurrent_invocations_get_unique_ids(): """10 concurrent POSTs each get unique invocation IDs.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py 
b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py index 23c133fe3b9b..5cf42f8fe4c6 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_get_cancel.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # GET after invoke # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_after_invoke_returns_stored_result(async_storage_client): """GET /invocations/{id} after invoke returns the stored result.""" @@ -31,6 +31,7 @@ async def test_get_after_invoke_returns_stored_result(async_storage_client): # GET unknown ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_unknown_id_returns_404(async_storage_client): """GET /invocations/{unknown} returns 404.""" @@ -42,6 +43,7 @@ async def test_get_unknown_id_returns_404(async_storage_client): # Cancel after invoke # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_after_invoke_returns_cancelled(async_storage_client): """POST /invocations/{id}/cancel after invoke returns cancelled status.""" @@ -57,6 +59,7 @@ async def test_cancel_after_invoke_returns_cancelled(async_storage_client): # Cancel unknown ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_unknown_id_returns_404(async_storage_client): """POST /invocations/{unknown}/cancel returns 404.""" @@ -68,6 +71,7 @@ async def test_cancel_unknown_id_returns_404(async_storage_client): # GET after cancel # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def 
test_get_after_cancel_returns_404(async_storage_client): """GET after cancel returns 404 (data has been removed).""" @@ -83,6 +87,7 @@ async def test_get_after_cancel_returns_404(async_storage_client): # GET error returns 500 (inline InvocationAgentServerHost) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_error_returns_500(): """GET handler raising an exception returns 500.""" @@ -107,6 +112,7 @@ async def get_handler(request: Request) -> Response: # Cancel error returns 500 (inline InvocationAgentServerHost) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_error_returns_500(): """Cancel handler raising an exception returns 500.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py index db35beceda0f..0cb5ed48cf0e 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_graceful_shutdown.py @@ -13,11 +13,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_server_with_shutdown(**kwargs) -> tuple[InvocationAgentServerHost, list]: """Create InvocationAgentServerHost with a tracked shutdown handler.""" server = InvocationAgentServerHost(**kwargs) @@ -38,6 +38,7 @@ async def on_shutdown(): # Shutdown handler registration # --------------------------------------------------------------------------- + def test_shutdown_handler_registered(): """Shutdown handler is stored on the server.""" server, _ = _make_server_with_shutdown() @@ -59,6 +60,7 @@ async def handle(request: Request) 
-> Response: # ASGI lifespan helper # --------------------------------------------------------------------------- + async def _drive_lifespan(app): """Drive a full ASGI lifespan startup+shutdown cycle.""" scope = {"type": "lifespan"} @@ -84,6 +86,7 @@ async def send(message): # Shutdown handler called during lifespan # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_called_on_lifespan_exit(): """Shutdown handler runs when the ASGI lifespan exits.""" @@ -99,6 +102,7 @@ async def test_shutdown_handler_called_on_lifespan_exit(): # Shutdown handler timeout # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_timeout(caplog): """Shutdown handler that exceeds timeout is warned about.""" @@ -120,13 +124,17 @@ async def on_shutdown(): # Shutdown should have been interrupted assert "completed" not in calls # Logger should have warned about timeout - assert any("did not complete" in r.message.lower() or "timeout" in r.message.lower() for r in caplog.records) + assert any( + "did not complete" in r.message.lower() or "timeout" in r.message.lower() + for r in caplog.records + ) # --------------------------------------------------------------------------- # Shutdown handler exception # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_shutdown_handler_exception(caplog): """Shutdown handler that raises is caught and logged.""" @@ -144,13 +152,17 @@ async def on_shutdown(): await _drive_lifespan(app) # Should have logged the exception - assert any("on_shutdown" in r.message.lower() or "error" in r.message.lower() for r in caplog.records) + assert any( + "on_shutdown" in r.message.lower() or "error" in r.message.lower() + for r in caplog.records + ) # --------------------------------------------------------------------------- # Graceful shutdown timeout 
config # --------------------------------------------------------------------------- + def test_default_graceful_shutdown_timeout(): """Default graceful shutdown timeout is 30 seconds.""" app = InvocationAgentServerHost() @@ -173,6 +185,7 @@ def test_zero_graceful_shutdown_timeout(): # Health endpoint accessible during normal operation # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_health_endpoint_during_operation(): """GET /readiness returns 200 during normal operation.""" @@ -188,6 +201,7 @@ async def test_health_endpoint_during_operation(): # No shutdown handler is no-op # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_no_shutdown_handler_is_noop(): """Without a shutdown handler, lifespan exit succeeds silently.""" @@ -208,6 +222,7 @@ async def handle(request: Request) -> Response: # Multiple requests before shutdown # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_multiple_requests_before_shutdown(): """Multiple requests can be served, then shutdown handler runs.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py index 5de15efd63cc..198cbcd76711 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_invoke.py @@ -12,6 +12,7 @@ # Echo body # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_echo_body(echo_client): """POST /invocations echoes the request body.""" @@ -24,6 +25,7 @@ async def test_invoke_echo_body(echo_client): # Headers # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def 
test_invoke_returns_invocation_id_header(echo_client): """Response includes x-agent-invocation-id header.""" @@ -68,6 +70,7 @@ async def test_invoke_accepts_custom_invocation_id(echo_client): # Streaming # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_streaming_returns_chunks(streaming_client): """Streaming handler returns 3 JSON chunks.""" @@ -91,6 +94,7 @@ async def test_streaming_has_invocation_id_header(streaming_client): # Empty body # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_empty_body(echo_client): """Empty body doesn't crash the server.""" @@ -103,6 +107,7 @@ async def test_invoke_empty_body(echo_client): # Error handling # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_error_returns_500(failing_client): """Handler exception returns 500 with generic message.""" @@ -124,6 +129,7 @@ async def test_invoke_error_has_invocation_id(failing_client): # Error handling # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_error_hides_details_by_default(failing_client): """Exception message is hidden in error responses.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py index 818eb20c491e..ee866da198fb 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_multimodal_protocol.py @@ -12,11 +12,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # Helper: content-type echo agent # 
--------------------------------------------------------------------------- + def _make_content_type_echo_agent() -> InvocationAgentServerHost: """Agent that echoes body and returns the content-type it received.""" app = InvocationAgentServerHost() @@ -66,6 +66,7 @@ async def generate(): # Various content types # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_png_content_type(): """PNG content type is accepted and echoed.""" @@ -166,6 +167,7 @@ async def test_text_plain_content_type(): # Custom HTTP status codes # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_custom_status_200(): """Handler returning 200.""" @@ -200,6 +202,7 @@ async def test_custom_status_202(): # Query strings # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_query_string_passed_to_handler(): """Query string params are accessible in the handler.""" @@ -221,6 +224,7 @@ async def handle(request: Request) -> Response: # SSE streaming # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_sse_streaming(): """SSE-formatted streaming response works.""" @@ -238,6 +242,7 @@ async def test_sse_streaming(): # Large binary payloads # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_large_binary_payload(): """Large binary payload (512KB) is handled correctly.""" @@ -258,6 +263,7 @@ async def test_large_binary_payload(): # Health endpoint (updated from /healthy to /readiness) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_health_endpoint_returns_200(): """GET /readiness returns 200 with healthy status.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py 
b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py index 934433bd0333..4b087e1e958b 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_id.py @@ -19,6 +19,7 @@ # Header presence — success responses # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_invoke_returns_request_id_header(echo_client): """POST /invocations success response includes x-request-id.""" @@ -61,6 +62,7 @@ async def test_readiness_returns_request_id(echo_client): # Error responses — header present, but NO body enrichment # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_error_response_has_request_id_header(failing_client): """500 error response includes x-request-id header.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py index 24d71ed51e8f..95b24d827638 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_request_limits.py @@ -10,11 +10,11 @@ from azure.ai.agentserver.invocations import InvocationAgentServerHost - # --------------------------------------------------------------------------- # InvocationAgentServerHost no longer accepts request_timeout # --------------------------------------------------------------------------- + def test_no_request_timeout_parameter(): """InvocationAgentServerHost no longer accepts request_timeout.""" with pytest.raises(TypeError): @@ -25,6 +25,7 @@ def test_no_request_timeout_parameter(): # Slow invoke completes without timeout # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_slow_invoke_completes(): 
"""Without timeout, handler runs to completion.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py index 8bafb6fb9608..80d560c5b965 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_server_routes.py @@ -18,6 +18,7 @@ # POST /invocations returns 200 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_returns_200(echo_client): """POST /invocations returns 200 OK.""" @@ -29,6 +30,7 @@ async def test_post_invocations_returns_200(echo_client): # POST /invocations returns invocation-id header (UUID) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_returns_uuid_invocation_id(echo_client): """POST /invocations returns a valid UUID in x-agent-invocation-id.""" @@ -42,6 +44,7 @@ async def test_post_invocations_returns_uuid_invocation_id(echo_client): # GET openapi spec returns 404 when not set # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_openapi_spec_returns_404_when_not_set(no_spec_client): """GET /invocations/docs/openapi.json returns 404 when no spec registered.""" @@ -53,6 +56,7 @@ async def test_get_openapi_spec_returns_404_when_not_set(no_spec_client): # GET openapi spec returns spec when registered # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_openapi_spec_returns_spec_when_registered(): """GET /invocations/docs/openapi.json returns the spec when registered.""" @@ -73,6 +77,7 @@ async def handle(request: Request) -> Response: # GET /invocations/{id} returns 404 default # 
--------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_returns_404_default(echo_client): """GET /invocations/{id} returns 404 when no get handler registered.""" @@ -84,6 +89,7 @@ async def test_get_invocation_returns_404_default(echo_client): # POST /invocations/{id}/cancel returns 404 default # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_returns_404_default(echo_client): """POST /invocations/{id}/cancel returns 404 when no cancel handler.""" @@ -95,6 +101,7 @@ async def test_cancel_invocation_returns_404_default(echo_client): # Unknown route returns 404 # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_unknown_route_returns_404(echo_client): """Unknown route returns 404.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py index 6398f2f8d327..7a8f54751859 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_session_id.py @@ -20,6 +20,7 @@ # Constants # --------------------------------------------------------------------------- + def test_session_id_header_constant(): """SESSION_ID_HEADER constant is correct.""" assert InvocationConstants.SESSION_ID_HEADER == "x-agent-session-id" @@ -29,6 +30,7 @@ def test_session_id_header_constant(): # POST /invocations response has x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_has_session_id_header(echo_client): """POST /invocations response includes x-agent-session-id header.""" @@ -42,6 +44,7 @@ async def test_post_invocations_has_session_id_header(echo_client): # POST /invocations 
with query param uses that value # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_with_query_param(): """POST /invocations with agent_session_id query param uses that value.""" @@ -64,6 +67,7 @@ async def handle(request: Request) -> Response: # POST /invocations with env var # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_post_invocations_uses_env_var(): """POST /invocations uses FOUNDRY_AGENT_SESSION_ID env var when no query param.""" @@ -75,7 +79,9 @@ async def handle(request: Request) -> Response: return Response(content=b"ok") transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: resp = await client.post("/invocations", content=b"test") assert resp.headers["x-agent-session-id"] == "env-session" @@ -84,6 +90,7 @@ async def handle(request: Request) -> Response: # GET /invocations/{id} does NOT have x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_invocation_no_session_id_header(async_storage_client): """GET /invocations/{id} does NOT include x-agent-session-id.""" @@ -99,6 +106,7 @@ async def test_get_invocation_no_session_id_header(async_storage_client): # POST /invocations/{id}/cancel does NOT have x-agent-session-id header # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_cancel_invocation_no_session_id_header(async_storage_client): """POST /invocations/{id}/cancel does NOT include x-agent-session-id.""" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py index 
11f0b0f9f9b2..584884d92eb2 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_tracing_e2e.py @@ -36,7 +36,9 @@ def _flush_provider(): provider.force_flush() -def _poll_appinsights(logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT): +def _poll_appinsights( + logs_client, resource_id, query, *, timeout=_APPINSIGHTS_POLL_TIMEOUT +): """Poll Application Insights until the KQL query returns >= 1 row or timeout.""" from azure.core.exceptions import ServiceRequestError @@ -111,6 +113,7 @@ def _warmup_appinsights(): # E2E test # --------------------------------------------------------------------------- + class TestInvocationTracingE2E: """Verify that user-created spans inside InvocationAgentServerHost handlers land in App Insights.""" @@ -143,7 +146,9 @@ async def handle(request: Request) -> Response: return Response(content=body, media_type="application/octet-stream") transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: resp = await client.post("/invocations", content=b"hello e2e") assert resp.status_code == 200 diff --git a/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml index 2e51d7728bfd..9091ab8b4724 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml @@ -69,3 +69,5 @@ azure-sdk-tools = { path = "../../../eng/tools/azure-sdk-tools" } [tool.azure-sdk-build] verifytypes = false latestdependency = false +# azure-ai-agentserver-core>=2.0.0b4 is not yet on PyPI +mindependency = false