Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions python/packages/core/agent_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import importlib.metadata
from typing import Final
from typing import TYPE_CHECKING, Any, Final

try:
_version = importlib.metadata.version(__name__)
Expand Down Expand Up @@ -264,6 +264,7 @@
)
from ._workflows._agent_utils import resolve_agent_id
from ._workflows._checkpoint import (
CheckpointID,
CheckpointStorage,
FileCheckpointStorage,
InMemoryCheckpointStorage,
Expand Down Expand Up @@ -307,7 +308,6 @@
workflow,
)
from ._workflows._request_info_mixin import response_handler
from ._workflows._runner import Runner
from ._workflows._runner_context import (
InProcRunnerContext,
RunnerContext,
Expand Down Expand Up @@ -405,6 +405,7 @@
"ChatResponse",
"ChatResponseUpdate",
"CheckResult",
"CheckpointID",
"CheckpointStorage",
"ClassSkill",
"CompactionProvider",
Expand Down Expand Up @@ -618,3 +619,20 @@
"validate_workflow_graph",
"workflow",
]

if TYPE_CHECKING:
from ._workflows._runner import Runner


def __getattr__(name: str) -> Any:
    """Resolve deprecated public names on attribute access.

    Only ``Runner`` is handled. It stays reachable from this package for
    backward compatibility, but each access emits a ``DeprecationWarning``
    because the class is slated for removal from the public API.

    Args:
        name: The attribute that normal module lookup failed to find.

    Returns:
        The ``Runner`` class when ``name`` is ``"Runner"``.

    Raises:
        AttributeError: For every other name.
    """
    if name != "Runner":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Function-scope import: ``Runner`` is never bound in the module globals,
    # so later accesses keep coming through this hook (a module-level
    # ``__getattr__`` is consulted only when normal lookup fails).
    from ._workflows._runner import Runner, warn_runner_deprecated

    warn_runner_deprecated()
    return Runner
120 changes: 81 additions & 39 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import asyncio
import contextlib
import logging
import warnings
from collections import defaultdict
from collections.abc import AsyncGenerator, Sequence
from typing import Any

from ..exceptions import (
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowRunnerException,
)
from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint
from ._const import EXECUTOR_STATE_KEY
Expand All @@ -27,6 +27,21 @@
logger = logging.getLogger(__name__)


def warn_runner_deprecated(*, stacklevel: int = 3) -> None:
    """Emit a ``DeprecationWarning`` for access to ``Runner`` via the public API.

    ``Runner`` remains importable from ``agent_framework`` for backward
    compatibility, but it is intended for internal use only and will be removed
    from the public API in a future version.

    Args:
        stacklevel: Forwarded to ``warnings.warn``. ``1`` attributes the warning
            to this helper, ``2`` to its caller, ``3`` to the caller's caller.
            The default of ``3`` fits a call made from a module-level
            ``__getattr__`` hook, so the warning points at the code that
            accessed ``Runner`` rather than at the hook. Pass a different value
            when calling from another depth.
    """
    warnings.warn(
        "`Runner` is deprecated and will be removed from the public API in a future version. "
        "It is intended for internal use only.",
        DeprecationWarning,
        stacklevel=stacklevel,
    )


class Runner:
"""A class to run a workflow in Pregel supersteps."""

Expand Down Expand Up @@ -63,38 +78,47 @@ def __init__(
self._iteration = 0
self._max_iterations = max_iterations
self._state = state
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed

# Checkpointing related attributes
self._resumed_from_checkpoint = False
self._previous_checkpoint_id: CheckpointID | None = None

@property
def context(self) -> RunnerContext:
"""Get the workflow context."""
"""Get the runner context for message, event, and checkpoint handling."""
return self._ctx

@property
def state(self) -> State:
"""Get the shared state for the workflow."""
return self._state

def reset_iteration_count(self) -> None:
"""Reset the iteration count to zero."""
"""Reset the iteration count to zero.

This is useful when the workflow resumes from a new set of messages.

Note:
When a workflow is resumed from a response (for a request_info_event)
or a checkpoint, the iteration count is normally NOT reset.
"""
self._iteration = 0
Comment thread
TaoChenOSU marked this conversation as resolved.

async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
if self._running:
raise WorkflowRunnerException("Runner is already running.")

self._running = True
previous_checkpoint_id: CheckpointID | None = None
try:
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
logger.info("Yielding pre-loop events")
for event in await self._ctx.drain_events():
yield event

# Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration,
# we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the
# states after which the start executor has run. Note that we execute the start executor outside of the
# main iteration loop.
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
# Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the
# end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0"
# which captures the states after which the start executor has run. Note that we execute the start
# executor outside of the main iteration loop.
if await self._ctx.has_messages() and self._iteration == 0 and not self._resumed_from_checkpoint:
await self.create_checkpoint_if_enabled()

while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
Expand Down Expand Up @@ -141,21 +165,23 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
self._state.commit()

# Create checkpoint after each superstep iteration
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
await self.create_checkpoint_if_enabled()

yield WorkflowEvent.superstep_completed(iteration=self._iteration)

# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break

logger.info(f"Workflow completed after {self._iteration} supersteps")

if self._iteration >= self._max_iterations and await self._ctx.has_messages():
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")

logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
finally:
self._running = False
# Reset the resume flag so stale resume state never leaks into the next run on this
# instance - even if convergence raised before completing (e.g. an executor failure
# during a resumed run).
self._resumed_from_checkpoint = False

async def _run_iteration(self) -> None:
"""Run a single iteration of the workflow.
Expand Down Expand Up @@ -209,40 +235,55 @@ async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None:
]
await asyncio.gather(*tasks)

async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None:
async def _prepare_checkpoint_state(self) -> None:
    """Fold the executors' snapshots into the committed shared state.

    Checkpoint-capture paths call this when they need a complete, restorable
    state payload, whether or not anything is then written to a checkpoint
    storage backend.
    """
    # Order matters: snapshot first, commit second. The snapshots are staged
    # as pending writes, and only committed state ends up in a checkpoint.
    await self._save_executor_states()
    shared_state = self._state
    shared_state.commit()

async def create_checkpoint_if_enabled(self) -> None:
    """Create a checkpoint when checkpointing is enabled on the runner context.

    The new checkpoint is created with the last successfully created
    checkpoint (``self._previous_checkpoint_id``) as its parent. On success
    that attribute is advanced to the new checkpoint's ID.

    Checkpoint creation is best-effort: any exception is logged as a warning
    and swallowed so that it does not fail the workflow run. In that case
    ``self._previous_checkpoint_id`` is left untouched, so the next successful
    checkpoint is parented to the last successful one.
    """
    if not self._ctx.has_checkpointing():
        return

    try:
        # Save executor states into committed state before creating the checkpoint.
        await self._prepare_checkpoint_state()

        checkpoint_id = await self._ctx.create_checkpoint(
            self._workflow_name,
            self._graph_signature_hash,
            self._state,
            self._previous_checkpoint_id,
            self._iteration,
        )

        logger.info(
            "Created checkpoint: %s with parent checkpoint at iteration %d: %s",
            checkpoint_id,
            self._iteration,
            self._previous_checkpoint_id,
        )
        # Advance the parent pointer only after the checkpoint exists.
        self._previous_checkpoint_id = checkpoint_id
    except Exception as e:
        # Deliberately broad: a checkpoint failure must not abort the run.
        logger.warning(
            "Failed to create checkpoint at iteration %d: %s. "
            "Note that this does not fail the workflow run. "
            "The next successfully-created checkpoint will be parented to the last successful checkpoint: %s",
            self._iteration,
            e,
            self._previous_checkpoint_id,
        )

async def restore_from_checkpoint(
self,
checkpoint_id: CheckpointID,
checkpoint_storage: CheckpointStorage | None = None,
) -> None:
"""Restore workflow state from a checkpoint.
"""Restore the runner from a checkpoint.

Args:
checkpoint_id: The ID of the checkpoint to restore from
Expand Down Expand Up @@ -290,7 +331,7 @@ async def restore_from_checkpoint(
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
# Mark the runner as resumed
self._mark_resumed(checkpoint.iteration_count)
self._mark_resumed(checkpoint)

logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
except WorkflowCheckpointException:
Expand Down Expand Up @@ -356,13 +397,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[

return parsed

def _mark_resumed(self, iteration: int) -> None:
def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None:
    """Record that the runner has been restored from ``checkpoint``.

    Sets the resumed flag, adopts the checkpoint's iteration count as the
    current iteration, and stores the checkpoint's ID as the previous
    checkpoint ID so that the next checkpoint created is parented to it.

    Args:
        checkpoint: The checkpoint the runner was restored from.
    """
    self._resumed_from_checkpoint = True
    self._iteration = checkpoint.iteration_count
    self._previous_checkpoint_id = checkpoint.checkpoint_id

async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in state under a reserved key.
Expand Down
Loading
Loading