diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 1d608365986..4079e1a05ad 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -9,7 +9,7 @@ """ import importlib.metadata -from typing import Final +from typing import TYPE_CHECKING, Any, Final try: _version = importlib.metadata.version(__name__) @@ -264,6 +264,7 @@ ) from ._workflows._agent_utils import resolve_agent_id from ._workflows._checkpoint import ( + CheckpointID, CheckpointStorage, FileCheckpointStorage, InMemoryCheckpointStorage, @@ -307,7 +308,6 @@ workflow, ) from ._workflows._request_info_mixin import response_handler -from ._workflows._runner import Runner from ._workflows._runner_context import ( InProcRunnerContext, RunnerContext, @@ -405,6 +405,7 @@ "ChatResponse", "ChatResponseUpdate", "CheckResult", + "CheckpointID", "CheckpointStorage", "ClassSkill", "CompactionProvider", @@ -618,3 +619,20 @@ "validate_workflow_graph", "workflow", ] + +if TYPE_CHECKING: + from ._workflows._runner import Runner + + +def __getattr__(name: str) -> Any: + """Lazily resolve deprecated public names, emitting a ``DeprecationWarning``. + + ``Runner`` remains importable from ``agent_framework`` for backward + compatibility but is deprecated and slated for removal from the public API. + """ + if name == "Runner": + from ._workflows._runner import Runner, warn_runner_deprecated + + warn_runner_deprecated() + return Runner + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 51a3312e2ba..039684ca48c 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -3,6 +3,7 @@ import asyncio import contextlib import logging +import warnings from collections import defaultdict from collections.abc import AsyncGenerator, Sequence from typing import Any @@ -10,7 +11,6 @@ from ..exceptions import ( WorkflowCheckpointException, WorkflowConvergenceException, - WorkflowRunnerException, ) from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint from ._const import EXECUTOR_STATE_KEY @@ -27,6 +27,21 @@ logger = logging.getLogger(__name__) +def warn_runner_deprecated() -> None: + """Emit a deprecation warning when ``Runner`` is accessed from the public API. + + ``Runner`` remains importable from ``agent_framework`` for backward + compatibility, but it is intended for internal use only and will be removed + from the public API in a future version. + """ + warnings.warn( + "`Runner` is deprecated and will be removed from the public API in a future version. " + "It is intended for internal use only.", + DeprecationWarning, + stacklevel=3, + ) + + class Runner: """A class to run a workflow in Pregel supersteps.""" @@ -63,25 +78,34 @@ def __init__( self._iteration = 0 self._max_iterations = max_iterations self._state = state - self._running = False - self._resumed_from_checkpoint = False # Track whether we resumed + + # Checkpointing related attributes + self._resumed_from_checkpoint = False + self._previous_checkpoint_id: CheckpointID | None = None @property def context(self) -> RunnerContext: - """Get the workflow context.""" + """Get the runner context for message, event, and checkpoint handling.""" return self._ctx + @property + def state(self) -> State: + """Get the shared state for the workflow.""" + return self._state + def reset_iteration_count(self) -> None: - """Reset the iteration count to zero.""" + """Reset the iteration count to zero. + + This is useful when the workflow resumes from a new set of messages. + + Note: + When a workflow is resumed from a response (for a request_info_event) + or a checkpoint, the iteration count is normally NOT reset. + """ self._iteration = 0 async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" - if self._running: - raise WorkflowRunnerException("Runner is already running.") - - self._running = True - previous_checkpoint_id: CheckpointID | None = None try: # Emit any events already produced prior to entering loop if await self._ctx.has_events(): @@ -89,12 +113,12 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: for event in await self._ctx.drain_events(): yield event - # Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration, - # we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the - # states after which the start executor has run. Note that we execute the start executor outside of the - # main iteration loop. - if await self._ctx.has_messages() and not self._resumed_from_checkpoint: - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) + # Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the + # end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0" + # which captures the states after which the start executor has run. Note that we execute the start + # executor outside of the main iteration loop. + if await self._ctx.has_messages() and self._iteration == 0 and not self._resumed_from_checkpoint: + await self.create_checkpoint_if_enabled() while self._iteration < self._max_iterations: logger.info(f"Starting superstep {self._iteration + 1}") @@ -141,7 +165,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: self._state.commit() # Create checkpoint after each superstep iteration - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) + await self.create_checkpoint_if_enabled() yield WorkflowEvent.superstep_completed(iteration=self._iteration) @@ -149,13 +173,15 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: if not await self._ctx.has_messages(): break + logger.info(f"Workflow completed after {self._iteration} supersteps") + if self._iteration >= self._max_iterations and await self._ctx.has_messages(): raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") - - logger.info(f"Workflow completed after {self._iteration} supersteps") - self._resumed_from_checkpoint = False # Reset resume flag for next run finally: - self._running = False + # Reset the resume flag so stale resume state never leaks into the next run on this + # instance - even if convergence raised before completing (e.g. an executor failure + # during a resumed run). + self._resumed_from_checkpoint = False async def _run_iteration(self) -> None: """Run a single iteration of the workflow. @@ -209,40 +235,55 @@ async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None: ] await asyncio.gather(*tasks) - async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None: + async def _prepare_checkpoint_state(self) -> None: + """Persist executor snapshots into committed shared state. + + This is used by checkpoint capture paths that need a complete, restorable + state payload without necessarily writing to a checkpoint storage backend. + """ + await self._save_executor_states() + self._state.commit() + + async def create_checkpoint_if_enabled(self) -> None: """Create a checkpoint if checkpointing is enabled and attach a label and metadata.""" if not self._ctx.has_checkpointing(): - return None + return try: - # Save executor states into the shared state before creating the checkpoint, - # so that they are included in the checkpoint payload. - await self._save_executor_states() - # `on_checkpoint_save()` writes via State.set(), which stages values in the - # pending buffer. Checkpoints serialize committed state only, so commit here - # to ensure executor snapshots are captured in this checkpoint. - self._state.commit() + # Save executor states into committed state before creating the checkpoint. + await self._prepare_checkpoint_state() checkpoint_id = await self._ctx.create_checkpoint( self._workflow_name, self._graph_signature_hash, self._state, - previous_checkpoint_id, + self._previous_checkpoint_id, self._iteration, ) - logger.info(f"Created checkpoint: {checkpoint_id}") - return checkpoint_id + logger.info( + "Created checkpoint: %s with parent checkpoint at iteration %d: %s", + checkpoint_id, + self._iteration, + self._previous_checkpoint_id, + ) + self._previous_checkpoint_id = checkpoint_id except Exception as e: - logger.warning(f"Failed to create checkpoint: {e}") - return None + logger.warning( + "Failed to create checkpoint at iteration %d: %s. " + "Note that this does not fail the workflow run. " + "The next successfully-created checkpoint will be parented to the last successful checkpoint: %s", + self._iteration, + e, + self._previous_checkpoint_id, + ) async def restore_from_checkpoint( self, checkpoint_id: CheckpointID, checkpoint_storage: CheckpointStorage | None = None, ) -> None: - """Restore workflow state from a checkpoint. + """Restore the runner from a checkpoint. Args: checkpoint_id: The ID of the checkpoint to restore from @@ -290,7 +331,7 @@ async def restore_from_checkpoint( # Apply the checkpoint to the context await self._ctx.apply_checkpoint(checkpoint) # Mark the runner as resumed - self._mark_resumed(checkpoint.iteration_count) + self._mark_resumed(checkpoint) logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}") except WorkflowCheckpointException: @@ -356,13 +397,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[ return parsed - def _mark_resumed(self, iteration: int) -> None: + def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None: """Mark the runner as having resumed from a checkpoint. Optionally set the current iteration and max iterations. """ self._resumed_from_checkpoint = True - self._iteration = iteration + self._iteration = checkpoint.iteration_count + self._previous_checkpoint_id = checkpoint.checkpoint_id async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: """Store executor state in state under a reserved key. diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index c4840bb0455..fca1b06b6da 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -11,12 +11,14 @@ import types import uuid import warnings +import weakref from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, overload from .._sessions import ContextProvider from .._types import ResponseStream +from ..exceptions import WorkflowException from ..observability import OtelAttr, capture_exception, create_workflow_span from ._checkpoint import CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY @@ -346,25 +348,29 @@ def __init__( # Store non-serializable runtime objects as private attributes self._runner_context = runner_context self._runner_context.set_yield_output_classifier(self._output_designation.classify) - self._state = State() self._runner: Runner = Runner( self.edge_groups, self.executors, - self._state, + State(), runner_context, self.name, self.graph_signature_hash, max_iterations=max_iterations, ) - # Flag to prevent concurrent workflow executions - self._is_running = False - # Current run-level status of this workflow instance. Updated in lockstep with # the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE # for a freshly built workflow that has not yet been run. self._status: WorkflowRunState = WorkflowRunState.IDLE + # Weak reference to the in-flight run's ``ResponseStream``. Used as the single + # concurrency lock: if the previous stream is still alive, ``run()`` rejects a + # new run synchronously (before any await). When the stream is fully consumed + # ``_run_core``'s finally clears this; if the caller drops the stream without + # ever iterating, the weakref dereferences to ``None`` once Python collects it, + # so a subsequent ``run()`` is allowed. + self._active_run: weakref.ref[ResponseStream[WorkflowEvent, WorkflowRunResult]] | None = None + @property def status(self) -> WorkflowRunState: """Return the current run-level status of this workflow instance. @@ -376,16 +382,6 @@ def status(self) -> WorkflowRunState: """ return self._status - def _ensure_not_running(self) -> None: - """Ensure the workflow is not already running.""" - if self._is_running: - raise RuntimeError("Workflow is already running. Concurrent executions are not allowed.") - self._is_running = True - - def _reset_running_flag(self) -> None: - """Reset the running flag.""" - self._is_running = False - def to_dict(self) -> dict[str, Any]: """Serialize the workflow definition into a JSON-ready dictionary.""" data: dict[str, Any] = { @@ -535,13 +531,12 @@ async def _run_workflow_with_tracing( yield in_progress # noqa: RUF070 # Per-run reset for fresh-message runs only. We deliberately - # do NOT clear shared workflow state (`_state.clear()`) or the - # runner context's in-flight messages (`reset_for_new_run()`) - # here - state and pending work persist across `run()` calls - # so that a `WorkflowAgent` can deliver multi-turn input on - # the same instance and have prior turns' context survive. - # Iteration counting and per-run kwargs ARE per-run though, - # so they're reset here. + # do NOT clear shared workflow state or the runner context's + # in-flight messages here - state and pending work persist + # across `run()` calls so that a `WorkflowAgent` can deliver + # multi-turn input on the same instance and have prior turns' + # context survive. Iteration counting and per-run kwargs ARE + # per-run though, so they're reset here. if not is_continuation: self._runner.reset_iteration_count() @@ -564,14 +559,13 @@ async def _run_workflow_with_tracing( combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs( client_kwargs, "client_kwargs" ) - self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) + self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) elif not is_continuation: - self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) - self._state.commit() # Commit immediately so kwargs are available + self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, {}) + self._runner.state.commit() # Commit immediately so kwargs are available - # Set streaming mode (always set explicitly per run since - # reset_for_new_run() no longer runs to clear it). - self._runner_context.set_streaming(streaming) + # Explicitly set streaming mode per run + self._runner.context.set_streaming(streaming) # Execute initial setup if provided if initial_executor_fn: @@ -665,7 +659,7 @@ async def _execute_with_message_or_checkpoint( await executor.execute( message, [self.__class__.__name__], - self._state, + self._runner.state, self._runner.context, trace_contexts=None, source_span_ids=None, @@ -745,9 +739,28 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - # Validate parameters and set running flag eagerly (before any async work) + # Validate parameters first so misuse fails before we touch any run state. self._validate_run_params(message, responses, checkpoint_id) - self._ensure_not_running() + + # Concurrency check: reject a second run synchronously - before constructing + # the ResponseStream or yielding control to the event loop - so a concurrent + # ``run`` call can't slip past the guard while the first call is suspended + # inside its async generator. The ``ResponseStream`` returned below is the + # lock: as long as the caller holds a reference to it, ``self._active_run()`` + # resolves to a live object and a new ``run`` is rejected. When the stream is + # fully consumed, ``_run_core``'s finally clears the attribute. When the + # caller drops the stream without iterating, garbage collection invalidates + # the weakref, so a subsequent ``run`` is permitted. + if self._is_run_active(): + raise WorkflowException( + "Workflow is already running; concurrent runs are not allowed on the same instance." + ) + + # No run is active, so any runtime checkpoint storage override still set on the + # context is stale - left over from a prior run whose stream was dropped before + # its async-generator finalizer ran. Clear it so this run starts clean and does + # not silently inherit the prior run's runtime checkpoint storage. + self._runner.context.clear_runtime_checkpoint_storage() response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( @@ -760,10 +773,8 @@ def run( client_kwargs=client_kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), - cleanup_hooks=[ - functools.partial(self._run_cleanup, checkpoint_storage), - ], ) + self._active_run = weakref.ref(response_stream) if stream: return response_stream @@ -785,55 +796,79 @@ async def _run_core( Yields: WorkflowEvent: The events generated during the workflow execution. """ - # Enable runtime checkpointing if storage provided + # Capture the weakref instance ``run()`` installed for *this* run. We + # compare by object identity in the finally so a stale finalizer (e.g. + # the caller dropped this stream after partial iteration, then started + # a new run before async-gen finalization throws ``GeneratorExit`` into + # us) does not clobber a successor run's freshly installed weakref. + # ``run()`` runs synchronously and assigns ``self._active_run`` before + # this generator's body is first iterated, so by the time we read it + # here it already points at our own ``ResponseStream``. + my_active_run = self._active_run + + # Enable runtime checkpointing if storage provided. if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) - # Async validation: a fresh-message run is only allowed when the - # runner context has fully drained from any prior run. If it still - # has in-flight executor messages, the prior run didn't complete - - # the caller must either resume from a checkpoint or wait for the - # prior run to drain. (Pending request_info events are intentionally - # NOT blocked here: a follow-up run with message=... is the normal - # way to deliver a response to those pending requests, e.g. via - # WorkflowAgent._process_pending_requests.) - # NOTE: _validate_run_params already enforces that ``message`` is - # mutually exclusive with both ``checkpoint_id`` and ``responses``, - # so we don't need to re-check those here. - if message is not None and await self._runner.context.has_messages(): - raise RuntimeError( - "Cannot start a new run with 'message' while in-flight executor " - "messages remain from a prior run. Resume from a checkpoint " - "(checkpoint_id=...) or wait for the prior run to complete. " - "Workflows that need to recover from a mid-run failure must use " - "checkpointing; there is no in-process recovery path." - ) + try: + # Async validation: a fresh-message run is only allowed when the + # runner context has fully drained from any prior run. If it still + # has in-flight executor messages, the prior run didn't complete - + # the caller must either resume from a checkpoint or wait for the + # prior run to drain. (Pending request_info events are intentionally + # NOT blocked here: a follow-up run with message=... is the normal + # way to deliver a response to those pending requests, e.g. via + # WorkflowAgent._process_pending_requests.) + # NOTE: _validate_run_params already enforces that ``message`` is + # mutually exclusive with both ``checkpoint_id`` and ``responses``, + # so we don't need to re-check those here. + if message is not None and await self._runner.context.has_messages(): + raise RuntimeError( + "Cannot start a new run with 'message' while in-flight executor " + "messages remain from a prior run. Resume from a checkpoint " + "(checkpoint_id=...) or wait for the prior run to complete. " + "Workflows that need to recover from a mid-run failure must use " + "checkpointing; there is no in-process recovery path." + ) - initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) - - async for event in self._run_workflow_with_tracing( - initial_executor_fn=initial_executor_fn, - is_continuation=(message is None), - streaming=streaming, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - ): - if event.type == "request_info" and event.request_id in (responses or {}): - # Don't yield request_info events for which we have responses to send - - # these are considered "handled". This prevents the caller from seeing - # events for requests they are already responding to. - # This usually happens when responses are provided with a checkpoint - # (restore then send), because the request_info events are stored in the - # checkpoint and would be emitted on restoration by the runner regardless - # of if a response is provided or not. - continue - yield event + initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) - async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: - """Cleanup hook called after stream consumption.""" - if checkpoint_storage is not None: - self._runner.context.clear_runtime_checkpoint_storage() - self._reset_running_flag() + async for event in self._run_workflow_with_tracing( + initial_executor_fn=initial_executor_fn, + is_continuation=(message is None), + streaming=streaming, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): + if event.type == "request_info" and event.request_id in (responses or {}): + # Don't yield request_info events for which we have responses to send - + # these are considered "handled". This prevents the caller from seeing + # events for requests they are already responding to. + # This usually happens when responses are provided with a checkpoint + # (restore then send), because the request_info events are stored in the + # checkpoint and would be emitted on restoration by the runner regardless + # of if a response is provided or not. + continue + yield event + finally: + # Whether this run is still the active one (no successor ``run()`` has + # installed a new weakref since we started). Captured once because the + # active-run clear below mutates ``self._active_run``. Used to scope both + # the run-lock release and the runtime-storage clear so a dropped run's + # deferred finalizer cannot clobber a successor run's state. + owns_run = self._active_run is my_active_run + if owns_run: + # Clear the active-run weakref so a subsequent ``run()`` is allowed. + # If the caller dropped this stream after partial iteration and a new + # ``run()`` already installed its own weakref before our async-gen + # finalizer ran, ``self._active_run`` points at the successor and we + # leave it untouched to preserve the successor's concurrency guard. + self._active_run = None + # Same ownership scoping applies to the runtime checkpoint storage: + # only clear it when this run still owns it, so a dropped run's + # deferred finalizer can't clear a successor's storage. + if checkpoint_storage is not None: + self._runner.context.clear_runtime_checkpoint_storage() @staticmethod def _finalize_events( @@ -935,7 +970,7 @@ async def _restore_and_send_responses( async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: """Internal method to validate and send responses to the executors.""" - pending_requests = await self._runner_context.get_pending_request_info_events() + pending_requests = await self._runner.context.get_pending_request_info_events() if not pending_requests: raise RuntimeError("No pending requests found in workflow context.") @@ -955,7 +990,7 @@ async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: coerced_responses[request_id] = response await asyncio.gather(*[ - self._runner_context.send_request_info_response(request_id, response) + self._runner.context.send_request_info_response(request_id, response) for request_id, response in coerced_responses.items() ]) @@ -1151,3 +1186,12 @@ def as_agent( context_providers=context_providers, **kwargs, ) + + def _is_run_active(self) -> bool: + """Check if a workflow run is currently active. + + Returns: + True if a run is active, False otherwise. + """ + existing_stream = self._active_run() if self._active_run is not None else None + return existing_stream is not None diff --git a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index 29fbc1554b6..e7aa4f0b3a4 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -336,6 +336,97 @@ async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: ) +async def test_workflow_checkpoint_ancestry_preserved_after_resume(): + """Resuming from a checkpoint must preserve ancestry: future checkpoints chain back to the resumed one.""" + from typing_extensions import Never + + from agent_framework import WorkflowBuilder, WorkflowContext, handler + from agent_framework._workflows._executor import Executor + + class StartExecutor(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message, target_id="middle") + + class MiddleExecutor(Executor): + @handler + async def process(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message + "-processed", target_id="finish") + + class FinishExecutor(Executor): + @handler + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] + await ctx.yield_output(message + "-done") + + storage = InMemoryCheckpointStorage() + + def _build_workflow() -> Any: + start = StartExecutor(id="start") + middle = MiddleExecutor(id="middle") + finish = FinishExecutor(id="finish") + return ( + WorkflowBuilder( + name="resume-ancestry-test", + max_iterations=10, + start_executor=start, + checkpoint_storage=storage, + ) + .add_edge(start, middle) + .add_edge(middle, finish) + .build() + ) + + # First run: produce an initial chain of checkpoints + workflow = _build_workflow() + workflow_name = workflow.name + _ = [event async for event in workflow.run("hello", stream=True)] + + initial_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp) + assert len(initial_checkpoints) >= 3, ( + f"Need at least 3 initial checkpoints to pick a middle one, got {len(initial_checkpoints)}" + ) + initial_ids = {cp.checkpoint_id for cp in initial_checkpoints} + + # Pick an intermediate checkpoint to resume from (not the first, not the last) + resume_from = initial_checkpoints[len(initial_checkpoints) // 2] + + # Resume on a fresh workflow instance (same graph signature) and run to completion + resumed_workflow = _build_workflow() + assert resumed_workflow.name == workflow_name + _ = [event async for event in resumed_workflow.run(checkpoint_id=resume_from.checkpoint_id, stream=True)] + + # Inspect new checkpoints created after resuming + all_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp) + new_checkpoints = [cp for cp in all_checkpoints if cp.checkpoint_id not in initial_ids] + assert new_checkpoints, "Resuming from an intermediate checkpoint should produce new checkpoints" + + # The very first checkpoint created after resuming must chain back to the resumed checkpoint + assert new_checkpoints[0].previous_checkpoint_id == resume_from.checkpoint_id, ( + "First post-resume checkpoint must chain to the checkpoint that was resumed from; " + f"got previous_checkpoint_id={new_checkpoints[0].previous_checkpoint_id!r}, " + f"expected {resume_from.checkpoint_id!r}" + ) + + # Subsequent post-resume checkpoints must continue chaining + for i in range(1, len(new_checkpoints)): + assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id, ( + f"Post-resume checkpoint {i} should chain to checkpoint {i - 1}" + ) + + # Walking the chain backwards from the most recent checkpoint must reach the original root + # without breaks (i.e. the full ancestry across the resume boundary is intact). + checkpoints_by_id = {cp.checkpoint_id: cp for cp in all_checkpoints} + chain: list[str] = [] + cursor: str | None = new_checkpoints[-1].checkpoint_id + while cursor is not None: + chain.append(cursor) + cursor = checkpoints_by_id[cursor].previous_checkpoint_id + # Chain must include the resumed-from checkpoint and terminate at the original root + assert resume_from.checkpoint_id in chain + assert chain[-1] == initial_checkpoints[0].checkpoint_id + assert checkpoints_by_id[chain[-1]].previous_checkpoint_id is None + + async def test_memory_checkpoint_storage_roundtrip_json_native_types(): """Test that JSON-native types (str, int, float, bool, None) roundtrip correctly.""" storage = InMemoryCheckpointStorage() diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 4b10f153db4..9c88febf8c4 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -17,7 +17,6 @@ WorkflowContext, WorkflowConvergenceException, WorkflowEvent, - WorkflowRunnerException, WorkflowRunState, handler, ) @@ -305,40 +304,62 @@ async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, i assert probe_target.call_count == 1 -async def test_runner_already_running(): - """Test that running the runner while it is already running raises an error.""" +async def test_runner_run_until_convergence_runs_sequentially(): + """run_until_convergence can be invoked back-to-back on the same Runner. + + The Runner itself does not enforce concurrency; that responsibility lives on + :class:`Workflow`. This test simply confirms the Runner is reusable across + sequential runs. + """ + runner = _make_runner() + async for _ in runner.run_until_convergence(): + pass + async for _ in runner.run_until_convergence(): + pass + + +def _make_runner() -> Runner: + """Build a minimal runner for runner-level tests.""" + return Runner( + [], + {}, + State(), + InProcRunnerContext(), + "test_name", + graph_signature_hash="test_hash", + ) + + +async def test_runner_accepts_new_run_after_previous_failure(): + """A failed run must not leave the Runner unable to start a new run. + + After the first run raises, ``run_until_convergence()`` must be callable + again and not surface any lifecycle-related rejection. + """ executor_a = MockExecutor(id="executor_a") executor_b = MockExecutor(id="executor_b") - - # Create a loop edges = [ SingleEdgeGroup(executor_a.id, executor_b.id), SingleEdgeGroup(executor_b.id, executor_a.id), ] - - executors: dict[str, Executor] = { - executor_a.id: executor_a, - executor_b.id: executor_b, - } + executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} state = State() ctx = InProcRunnerContext() + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash", max_iterations=2) - runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") + await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - await executor_a.execute( - MockMessage(data=0), - ["START"], # source_executor_ids - state, # state - ctx, # runner_context - ) - - with pytest.raises(WorkflowRunnerException, match="Runner is already running."): - - async def _run(): - async for _ in runner.run_until_convergence(): - pass + with pytest.raises(WorkflowConvergenceException): + async for _ in runner.run_until_convergence(): + pass - await asyncio.gather(_run(), _run()) + # A second run on the same Runner must not be blocked by stale lifecycle + # state from the failed run. + try: + async for _ in runner.run_until_convergence(): + pass + except Exception as exc: + assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run" async def test_runner_emits_runner_completion_for_agent_response_without_targets(): @@ -862,7 +883,13 @@ async def test_runner_checkpoint_with_resumed_flag(): state = State() runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") - runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage] + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="resumed-cp", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=5, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] # Add a message to trigger the checkpoint creation path await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START")) @@ -882,6 +909,86 @@ async def test_runner_checkpoint_with_resumed_flag(): assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] +async def test_runner_mark_resumed_sets_previous_checkpoint_id(): + """_mark_resumed must populate _previous_checkpoint_id so future checkpoints chain back to the resume point.""" + runner = Runner( + [], + {}, + State(), + InProcRunnerContext(), + "test_name", + graph_signature_hash="test_hash", + ) + + # Pre-condition: nothing to chain back to + assert runner._previous_checkpoint_id is None # pyright: ignore[reportPrivateUsage] + + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="resumed-cp-id", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=3, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + + assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] + assert runner._iteration == 3 # pyright: ignore[reportPrivateUsage] + assert runner._previous_checkpoint_id == "resumed-cp-id" # pyright: ignore[reportPrivateUsage] + + +async def test_runner_post_resume_checkpoint_chains_to_resumed_checkpoint(): + """After resuming, the next checkpoint created must reference the resumed checkpoint as its parent.""" + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + executor_a = MockExecutor(id="executor_a") + executor_b = MockExecutor(id="executor_b") + + edges = [ + SingleEdgeGroup(executor_a.id, executor_b.id), + SingleEdgeGroup(executor_b.id, executor_a.id), + ] + + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } + state = State() + + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") + + # Simulate having resumed from a prior checkpoint + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="parent-checkpoint-id", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=1, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + + # Seed a message so the runner has work to do (and creates checkpoints at superstep boundaries) + await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=executor_a.id)) + + async for _ in runner.run_until_convergence(): + pass + + # Find the first checkpoint created after the resume point (across all workflows tracked by storage) + new_checkpoints = sorted( + await storage.list_checkpoints(workflow_name="test_name"), + key=lambda c: c.timestamp, + ) + assert new_checkpoints, "Resuming and running should produce at least one new checkpoint" + + # The first new checkpoint must chain to the resumed-from checkpoint, not to None + assert new_checkpoints[0].previous_checkpoint_id == "parent-checkpoint-id", ( + "First post-resume checkpoint must chain to the resumed checkpoint id; " + f"got {new_checkpoints[0].previous_checkpoint_id!r}" + ) + + # Subsequent post-resume checkpoints continue the chain + for i in range(1, len(new_checkpoints)): + assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id + + class ExecutorThatFailsWithEvents(Executor): """An executor that emits events and then raises an exception after receiving messages.""" @@ -951,6 +1058,172 @@ async def test_runner_drains_events_on_iteration_exception(): assert len(output_events) >= 1 +async def test_runner_resumed_flag_reset_after_failed_resumed_run(): + """A failed *resumed* run must not leak the resume flag into the next run. + + The resume flag suppresses the initial "superstep 0" (entry) checkpoint when resuming from an + iteration-0 checkpoint (which already exists and must not be recreated). It used to be cleared + only on the success path, so an executor failure during a resumed run left it ``True`` and the + next fresh run wrongly skipped its entry checkpoint. The flag is now cleared in a ``finally`` so + this holds even when convergence raises. + + This also verifies checkpoint creation on the re-run: the resumed (failed) run creates no entry + checkpoint, while the subsequent fresh run does. + """ + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + executor_a = PassthroughExecutor(id="executor_a") + executor_b = ExecutorThatFailsWithEvents(id="executor_b", runner_ctx=ctx, fail_on_iteration=1) + + edges = [SingleEdgeGroup(executor_a.id, executor_b.id)] + executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} + state = State() + + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") + + # Simulate a resumed run; this marks the runner as resumed so the next run skips + # the superstep-0 checkpoint. + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="resumed-cp", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=0, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] + + # Run the resumed turn; executor_b fails mid-iteration before any superstep + # checkpoint is created. + await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) + with pytest.raises(RuntimeError, match="Executor failed with pending events"): + async for _ in runner.run_until_convergence(): + pass + + # The fix: the resume flag is cleared even though the run raised. + assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] + # The resumed (failed) run created no superstep-0 checkpoint (it was skipped). + assert await storage.list_checkpoints(workflow_name="test_name") == [] + + # Re-run as a fresh turn: with the flag correctly reset, the runner now creates + # the initial superstep-0 checkpoint (iteration_count == 0) before failing again. + runner.reset_iteration_count() + await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) + with pytest.raises(RuntimeError, match="Executor failed with pending events"): + async for _ in runner.run_until_convergence(): + pass + + checkpoints = await storage.list_checkpoints(workflow_name="test_name") + assert any(cp.iteration_count == 0 for cp in checkpoints), ( + "Fresh run after a failed resumed run must create the superstep-0 checkpoint; " + "a leaked resume flag would have skipped it" + ) + + +async def test_runner_creates_entry_checkpoint_at_iteration_zero(): + """A fresh run creates the entry (superstep-0) checkpoint at iteration 0 with no parent. + + This is the baseline the lineage-consistency guard must preserve: when starting from iteration 0 + with messages queued and not resumed, the entry checkpoint is created and begins a new lineage + (``previous_checkpoint_id is None``). + """ + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + # Terminal executor with no outgoing edges: the runner runs one superstep and converges. + source = MockExecutor(id="source") + state = State() + runner = Runner([], {source.id: source}, state, ctx, "test_name", graph_signature_hash="test_hash") + + assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage] + assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] + + await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=source.id)) + + async for _ in runner.run_until_convergence(): + pass + + checkpoints = await storage.list_checkpoints(workflow_name="test_name") + entry_checkpoints = [cp for cp in checkpoints if cp.iteration_count == 0] + assert len(entry_checkpoints) == 1, "A fresh run must create exactly one entry checkpoint at iteration 0" + assert entry_checkpoints[0].previous_checkpoint_id is None, ( + "The entry checkpoint of a fresh run must begin a new lineage with no parent" + ) + + +async def test_runner_skips_entry_checkpoint_when_iteration_nonzero(): + """The entry (superstep-0) checkpoint must only be created at iteration 0 to keep lineage consistent. + + A re-run that did not reset the iteration count (and is not marked as resumed) must not write an + entry checkpoint carrying a non-zero ``iteration_count`` - doing so would place two checkpoints at + the same iteration in the lineage. The ``_iteration == 0`` guard suppresses the entry checkpoint in + this case while still allowing the normal per-superstep checkpoints to be created. + """ + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + # Terminal executor with no outgoing edges: the runner runs one superstep and converges. + source = MockExecutor(id="source") + state = State() + runner = Runner([], {source.id: source}, state, ctx, "test_name", graph_signature_hash="test_hash") + + # Simulate a re-run that kept its iteration count and is not marked as resumed. + runner._iteration = 5 # pyright: ignore[reportPrivateUsage] + assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] + + await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=source.id)) + + async for _ in runner.run_until_convergence(): + pass + + checkpoints = await storage.list_checkpoints(workflow_name="test_name") + # No entry checkpoint at the pre-existing iteration count may be created. + assert all(cp.iteration_count != 5 for cp in checkpoints), ( + "Entry checkpoint must not be created at a non-zero iteration; lineage would have a duplicate iteration" + ) + # The normal post-superstep checkpoint is still created (iteration advanced to 6). + assert any(cp.iteration_count == 6 for cp in checkpoints) + + +async def test_runner_resumed_from_iteration_zero_skips_entry_checkpoint(): + """Resuming from an iteration-0 checkpoint must not recreate the entry checkpoint. + + Here ``_iteration == 0`` is true, so the iteration guard alone would not suppress the entry + checkpoint; the resume flag is what prevents recreating the checkpoint that already exists at + iteration 0. + """ + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + source = MockExecutor(id="source") + state = State() + runner = Runner([], {source.id: source}, state, ctx, "test_name", graph_signature_hash="test_hash") + + # Resume from an iteration-0 checkpoint: iteration stays 0 but the run is marked as resumed. + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="entry-cp", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=0, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage] + assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] + + await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=source.id)) + + async for _ in runner.run_until_convergence(): + pass + + # The pre-loop entry checkpoint is skipped; only the post-superstep checkpoint (iteration 1) is created, + # and it chains back to the resumed entry checkpoint. + checkpoints = sorted( + await storage.list_checkpoints(workflow_name="test_name"), + key=lambda c: c.timestamp, + ) + assert all(cp.checkpoint_id != "entry-cp" for cp in checkpoints), "Resumed entry checkpoint must not be recreated" + assert checkpoints, "The resumed run must still create its post-superstep checkpoint" + assert checkpoints[0].previous_checkpoint_id == "entry-cp", ( + "The first post-resume checkpoint must chain back to the resumed entry checkpoint" + ) + + class SlowEventEmittingExecutor(Executor): """An executor that emits events with delays to test straggler event draining.""" diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index f9791350a12..260613c74a1 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import gc import tempfile from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass, field @@ -19,6 +20,7 @@ Content, Executor, FileCheckpointStorage, + InProcRunnerContext, Message, ResponseStream, WorkflowBuilder, @@ -26,6 +28,7 @@ WorkflowContext, WorkflowConvergenceException, WorkflowEvent, + WorkflowException, WorkflowMessage, WorkflowRunState, handler, @@ -759,8 +762,7 @@ async def run_workflow(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -795,8 +797,7 @@ async def consume_stream_slowly(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -828,14 +829,12 @@ async def consume_stream(): # Try different execution methods - all should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): async for _ in workflow.run(NumberMessage(data=0), stream=True): break @@ -848,6 +847,238 @@ async def consume_stream(): assert result.get_final_state() == WorkflowRunState.IDLE +async def test_workflow_sequential_runs_after_completion() -> None: + """A completed run must release the runner so the next ``run`` succeeds. + + This is the happy-path counterpart to the concurrent-run guard tests: + those tests verify that a *concurrent* run is rejected, but they do not + verify that the lock is actually released afterwards. This test + exercises that release path explicitly across the three call shapes + (non-streaming, streaming-iterated, streaming-via-get_final_response) + and across multiple consecutive turns to catch lock leaks. + """ + executor = IncrementExecutor(id="seq_executor", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Non-streaming -> non-streaming + r1 = await workflow.run(NumberMessage(data=0)) + assert r1.get_final_state() == WorkflowRunState.IDLE + + r2 = await workflow.run(NumberMessage(data=0)) + assert r2.get_final_state() == WorkflowRunState.IDLE + + # Non-streaming -> streaming-iterated + stream_events: list[WorkflowEvent] = [] + async for event in workflow.run(NumberMessage(data=0), stream=True): + stream_events.append(event) + assert any(e.type == "status" and e.state == WorkflowRunState.IDLE for e in stream_events) + + # Streaming -> streaming via get_final_response (no manual iteration) + r3 = await workflow.run(NumberMessage(data=0), stream=True).get_final_response() + assert r3.get_final_state() == WorkflowRunState.IDLE + + # Streaming -> non-streaming (back to the start) + r4 = await workflow.run(NumberMessage(data=0)) + assert r4.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unconsumed_stream_releases_run_lock() -> None: + """An unconsumed stream must not leak the run lock. + + ``Workflow.run`` reserves the runner *synchronously* so that concurrent + callers are rejected immediately. The reservation is normally released + by ``_run_core``'s ``finally`` once the stream is iterated. If the + caller never iterates the stream, a GC-time finalizer must release the + reservation instead - otherwise every subsequent ``Workflow.run`` call + on this instance would fail with the concurrent-run error. + """ + executor = IncrementExecutor(id="unconsumed_stream_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Build a stream and immediately drop it without iterating. + stream = workflow.run(NumberMessage(data=0), stream=True) + assert stream is not None # silence unused-variable warnings; stream is GC'd below + del stream + gc.collect() + # Yield to the event loop so any scheduled finalizer work can run. + await asyncio.sleep(0) + + # The runner should be back to IDLE; a fresh run must succeed. + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unawaited_run_coroutine_releases_run_lock() -> None: + """An un-awaited non-streaming ``run()`` coroutine must also not leak the lock. + + ``Workflow.run`` (non-streaming) returns a coroutine produced by + ``ResponseStream.get_final_response``. The underlying ResponseStream is + held alive by that coroutine, so dropping the coroutine without + awaiting it must still release the reservation via the same GC-time + fallback used for unconsumed streams. + """ + executor = IncrementExecutor(id="unawaited_run_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + coro = workflow.run(NumberMessage(data=0)) + # Closing suppresses the "coroutine was never awaited" warning. We cast to + # ``Any`` because the typed return is ``Awaitable[...]``; in practice it is + # a coroutine that exposes ``close``. + cast(Any, coro).close() + del coro + gc.collect() + await asyncio.sleep(0) + + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_partial_stream_does_not_clobber_successor_active_run() -> None: + """A stale ``_run_core`` finalizer must not clear a successor's run lock. + + Repro for the GC-finalizer race the user reported: + + 1. Start stream A and consume one event so its body is suspended at a + ``yield``. Its ``finally`` is now armed and will run when the + generator is closed. + 2. Drop stream A and ``gc.collect``. The ``_active_run`` weakref's + referent is gone, so a subsequent ``run()`` will pass the + concurrency guard - but stream A's async-gen finalizer hasn't + actually executed yet (``aclose`` is scheduled on the loop). + 3. Synchronously start stream B; ``run()`` installs a fresh weakref + in ``_active_run``. + 4. Yield to the loop so stream A's stale ``finally`` runs. Without + the identity check it unconditionally writes + ``self._active_run = None``, silently disabling the concurrency + guard for stream B. + """ + executor = IncrementExecutor(id="stale_finalizer_exec", limit=100, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Step 1: drive stream A's body until it's suspended at its first yield. + stream_a = workflow.run(NumberMessage(data=0), stream=True) + aiter_a = stream_a.__aiter__() + await aiter_a.__anext__() + + # Step 2: drop stream A; GC invalidates the weakref and schedules + # async-gen close, but does not run the close inline. + del stream_a + del aiter_a + gc.collect() + + # Step 3: synchronously start stream B *before* yielding to the loop, + # so the stale ``aclose`` for stream A hasn't fired yet. + stream_b = workflow.run(NumberMessage(data=0), stream=True) + ref_b = workflow._active_run # type: ignore[attr-defined] + assert ref_b is not None and ref_b() is stream_b + + # Step 4: yield enough times for stream A's scheduled aclose to drive + # its body through ``GeneratorExit`` and into its ``finally``. + for _ in range(5): + await asyncio.sleep(0) + + # With the fix, stream B's reservation is still in place. Without it, + # ``_active_run`` was clobbered to ``None`` and a concurrent run would + # be (incorrectly) accepted. + assert workflow._active_run is ref_b # type: ignore[attr-defined] + with pytest.raises( + WorkflowException, + match="Workflow is already running; concurrent runs are not allowed on the same instance.", + ): + await workflow.run(NumberMessage(data=0)) + + # Tear down stream B without iterating it (its body never started, so + # closing it is a no-op for workflow state). + del stream_b + del ref_b + gc.collect() + await asyncio.sleep(0) + + +async def test_workflow_stale_runtime_checkpoint_storage_not_inherited() -> None: + """A new run must not inherit a prior run's leftover runtime checkpoint storage. + + If a run that set a runtime ``checkpoint_storage`` override is dropped before + its async-generator finalizer clears it, the override can linger on the + ``RunnerContext`` while ``_is_run_active()`` already reports False. ``run()`` + defensively clears that stale override so a subsequent run that does not pass + its own ``checkpoint_storage`` does not silently checkpoint into it. + """ + with tempfile.TemporaryDirectory() as temp_dir: + leftover_storage = FileCheckpointStorage(temp_dir) + executor = IncrementExecutor(id="stale_storage_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + assert isinstance(workflow._runner.context, InProcRunnerContext) # pyright: ignore[reportPrivateUsage] + + # Simulate a leftover runtime override from a dropped prior run. + workflow._runner.context.set_runtime_checkpoint_storage(leftover_storage) # pyright: ignore[reportPrivateUsage] + + # A fresh run without its own checkpoint_storage must not use the leftover. + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + checkpoints = await leftover_storage.list_checkpoints(workflow_name=workflow.name) + assert checkpoints == [], "Stale runtime checkpoint storage must not be inherited by a new run" + assert workflow._runner.context._runtime_checkpoint_storage is None # pyright: ignore[reportPrivateUsage] + + +async def test_workflow_partial_stream_does_not_clobber_successor_runtime_storage() -> None: + """A stale ``_run_core`` finalizer must not clear a successor's runtime storage. + + Same GC-finalizer race as + ``test_workflow_partial_stream_does_not_clobber_successor_active_run`` but for the + runtime checkpoint storage override: the dropped run's deferred ``finally`` must + only clear the override if it still owns it, otherwise it wipes the successor + run's storage. + """ + with ( + tempfile.TemporaryDirectory() as temp_dir_a, + tempfile.TemporaryDirectory() as temp_dir_b, + ): + storage_a = FileCheckpointStorage(temp_dir_a) + storage_b = FileCheckpointStorage(temp_dir_b) + executor = IncrementExecutor(id="storage_finalizer_exec", limit=100, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + context = workflow._runner.context # pyright: ignore[reportPrivateUsage] + + assert isinstance(context, InProcRunnerContext) + + # Step 1: drive stream A's body to its first yield so it set storage_a. + stream_a = workflow.run(NumberMessage(data=0), checkpoint_storage=storage_a, stream=True) + aiter_a = stream_a.__aiter__() + await aiter_a.__anext__() + assert context._runtime_checkpoint_storage is storage_a # pyright: ignore[reportPrivateUsage] + + # Step 2: drop stream A; the weakref dies and async-gen close is scheduled + # but not run inline. + del stream_a + del aiter_a + gc.collect() + + # Step 3: synchronously start stream B with its own storage and drive it to + # its first yield so it set storage_b and took ownership of the override. + stream_b = workflow.run(NumberMessage(data=0), checkpoint_storage=storage_b, stream=True) + aiter_b = stream_b.__aiter__() + await aiter_b.__anext__() + assert context._runtime_checkpoint_storage is storage_b # pyright: ignore[reportPrivateUsage] + + # Step 4: yield enough for stream A's scheduled aclose to drive its body + # through ``GeneratorExit`` and into its ``finally``. + for _ in range(5): + await asyncio.sleep(0) + + # With the ownership guard, stream B's override survives. Without it, A's + # stale finalizer would have cleared it. + assert context._runtime_checkpoint_storage is storage_b # pyright: ignore[reportPrivateUsage] + + # Tear down stream B. + del stream_b + del aiter_b + gc.collect() + await asyncio.sleep(0) + + class _StreamingTestAgent(BaseAgent): """Test agent that supports both streaming and non-streaming modes.""" diff --git a/python/packages/declarative/tests/test_http_request_executor.py b/python/packages/declarative/tests/test_http_request_executor.py index 4030cf42946..f07f19d27a3 100644 --- a/python/packages/declarative/tests/test_http_request_executor.py +++ b/python/packages/declarative/tests/test_http_request_executor.py @@ -90,7 +90,7 @@ async def _run(yaml_def: dict[str, Any], handler: HttpRequestHandler) -> Any: def _state(workflow: Any, events: Any) -> dict[str, Any]: """Read declarative state out of the workflow after run completes.""" - return workflow._state.get(DECLARATIVE_STATE_KEY) or {} + return workflow._runner.state.get(DECLARATIVE_STATE_KEY) or {} # Helper used by parametrised path tests @@ -151,7 +151,7 @@ async def test_get_parses_json_object(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(method="GET", response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == {"key": "value", "number": 42} assert handler.last_info is not None assert handler.last_info.method == "GET" @@ -164,7 +164,7 @@ async def test_get_parses_plain_string(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == "not-json content" @pytest.mark.asyncio @@ -174,7 +174,7 @@ async def test_get_empty_body_yields_none(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] is None @pytest.mark.asyncio @@ -184,7 +184,7 @@ async def test_response_object_form_path(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response={"path": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == {"x": 1} @pytest.mark.asyncio @@ -517,7 +517,7 @@ async def test_response_headers_folded_with_commas(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) h = decl["Local"]["H"] assert h["Content-Type"] == "application/json" assert h["Set-Cookie"] == "a=1,b=2" @@ -528,7 +528,7 @@ async def test_response_headers_empty_assigned_none(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["H"] is None @pytest.mark.asyncio @@ -538,7 +538,7 @@ async def test_non_2xx_still_publishes_headers(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) with pytest.raises(DeclarativeActionError): await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["H"] == {"X-Trace": "abc"} @@ -559,7 +559,7 @@ async def test_conversation_id_appends_message(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) conv = decl["System"]["conversations"].get("conv-test-1") assert conv is not None assert len(conv["messages"]) == 1 @@ -570,7 +570,7 @@ async def test_empty_conversation_id_does_not_append(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result", conversation_id=""))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) # Auto-init creates an entry for the System.ConversationId conversation, # but it should NOT have HTTP-appended messages from us. for _cid, conv in decl["System"]["conversations"].items(): @@ -582,7 +582,7 @@ async def test_empty_body_skips_conversation_append(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(conversation_id="conv-test-1"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) # No conversation entry should have been created either. assert "conv-test-1" not in decl["System"]["conversations"] diff --git a/python/packages/declarative/tests/test_http_request_yaml_integration.py b/python/packages/declarative/tests/test_http_request_yaml_integration.py index 7d5ac96d476..68fb22cb050 100644 --- a/python/packages/declarative/tests/test_http_request_yaml_integration.py +++ b/python/packages/declarative/tests/test_http_request_yaml_integration.py @@ -73,7 +73,7 @@ async def test_http_request_yaml_roundtrip() -> None: workflow = factory.create_workflow_from_yaml_path(FIXTURE_PATH) await workflow.run({}) - decl: dict[str, Any] = workflow._state.get(DECLARATIVE_STATE_KEY) or {} + decl: dict[str, Any] = workflow._runner.state.get(DECLARATIVE_STATE_KEY) or {} local: dict[str, Any] = decl.get("Local") or {} assert local.get("RepoOwner") == "dotnet" diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index 549cdd30a70..4a6bde1bbbd 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -244,7 +244,7 @@ async def test_output_result_parses_json_text(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == [{"k": "v", "n": 1}] @pytest.mark.asyncio @@ -253,7 +253,7 @@ async def test_output_result_falls_back_to_raw_text(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["plain text not json"] @pytest.mark.asyncio @@ -262,7 +262,7 @@ async def test_output_messages_writes_single_tool_role_message(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"messages": "Local.Messages"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) msg = decl["Local"]["Messages"] # Single Tool-role message containing both contents (parity with .NET). assert isinstance(msg, Message) @@ -276,7 +276,7 @@ async def test_uri_content_serialised_as_uri_string(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["https://example.com/file.txt"] @pytest.mark.asyncio @@ -285,7 +285,7 @@ async def test_output_path_object_form(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": {"path": "Local.Result"}}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["ok"] @@ -306,7 +306,7 @@ async def test_conversation_id_appends_assistant_message(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) conv = decl["System"]["conversations"]["conv-42"] msgs = conv["messages"] if isinstance(conv, dict) else conv.messages assert len(msgs) == 1 @@ -328,7 +328,7 @@ async def test_empty_conversation_id_does_not_append(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) # Empty conversation id must not produce a `""` entry under System.conversations. conversations = decl.get("System", {}).get("conversations", {}) assert "" not in conversations @@ -529,7 +529,7 @@ async def test_handler_returns_error_result_assigns_error_string(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == "Error: server down" @pytest.mark.asyncio @@ -538,7 +538,7 @@ async def test_tool_execution_exception_becomes_error_result(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == "Error: invalid arguments" @pytest.mark.asyncio @@ -547,7 +547,7 @@ async def test_httpx_error_becomes_error_result(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) result = decl["Local"]["Result"] assert isinstance(result, str) assert result.startswith("Error:") diff --git a/python/packages/declarative/tests/test_workflow_factory.py b/python/packages/declarative/tests/test_workflow_factory.py index ba54e50e8ea..acf677f6c84 100644 --- a/python/packages/declarative/tests/test_workflow_factory.py +++ b/python/packages/declarative/tests/test_workflow_factory.py @@ -291,11 +291,11 @@ async def test_as_agent_continuation_preserves_prior_state(self): # Stamp a marker into the declarative state between turns. The # continuation branch must preserve it; a state-clearing run would # wipe ``DECLARATIVE_STATE_KEY`` and force re-initialization. - state_data = workflow._state.get(DECLARATIVE_STATE_KEY) + state_data = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert isinstance(state_data, dict), "Expected declarative state to be initialized after turn 1" state_data["Local"] = {"persisted_marker": "kept-from-turn-1"} - workflow._state.set(DECLARATIVE_STATE_KEY, state_data) - workflow._state.commit() + workflow._runner.state.set(DECLARATIVE_STATE_KEY, state_data) + workflow._runner.state.commit() second = await agent.run("turn-2-msg") assert second.text == "turn-2-msg", ( @@ -305,7 +305,7 @@ async def test_as_agent_continuation_preserves_prior_state(self): # The continuation branch in ``_ensure_state_initialized`` must: # 1. preserve the cross-turn marker we stamped above # 2. refresh Inputs.input and System.LastMessage* to the new turn - post_state = workflow._state.get(DECLARATIVE_STATE_KEY) + post_state = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert isinstance(post_state, dict), "declarative state vanished between turns" local = post_state.get("Local", {}) assert local.get("persisted_marker") == "kept-from-turn-1", (