Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self._state = state
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed
self._resume_parent_checkpoint_id: CheckpointID | None = None

@property
def context(self) -> RunnerContext:
Expand All @@ -81,7 +82,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
raise WorkflowRunnerException("Runner is already running.")

self._running = True
previous_checkpoint_id: CheckpointID | None = None
previous_checkpoint_id: CheckpointID | None = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition is logically equivalent to just self._resume_parent_checkpoint_id because the two fields are always set and cleared together — _resume_parent_checkpoint_id is non-None iff _resumed_from_checkpoint is True. The guard adds noise without adding safety. Simplifying to the single expression also makes it obvious that the ID is the only thing that matters here.

Suggested change
previous_checkpoint_id: CheckpointID | None = (
previous_checkpoint_id: CheckpointID | None = self._resume_parent_checkpoint_id

self._resume_parent_checkpoint_id if self._resumed_from_checkpoint else None
)
try:
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
Expand Down Expand Up @@ -154,6 +157,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:

logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both _resumed_from_checkpoint and _resume_parent_checkpoint_id are only reset on the success path. If run_until_convergence raises (e.g., WorkflowConvergenceException), these values remain stale, potentially causing incorrect checkpoint chaining on a subsequent reuse of the runner. Consider moving these resets into the finally block.

self._resume_parent_checkpoint_id = None
Comment on lines 157 to +160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both _resumed_from_checkpoint and _resume_parent_checkpoint_id are reset here inside the try block. If an earlier iteration raises, these values remain set, and a subsequent run_until_convergence call on the same runner would incorrectly behave as if it were resuming. Moving these resets into the finally block would be more robust.

Suggested change
logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
self._resume_parent_checkpoint_id = None
logger.info(f"Workflow completed after {self._iteration} supersteps")
finally:
self._resumed_from_checkpoint = False # Reset resume flag for next run
self._resume_parent_checkpoint_id = None
self._running = False

finally:
self._running = False

Expand Down Expand Up @@ -285,7 +289,7 @@ async def restore_from_checkpoint(
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
# Mark the runner as resumed
self._mark_resumed(checkpoint.iteration_count)
self._mark_resumed(checkpoint.iteration_count, checkpoint.checkpoint_id)

logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
except WorkflowCheckpointException:
Expand Down Expand Up @@ -351,13 +355,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[

return parsed

def _mark_resumed(self, iteration: int) -> None:
def _mark_resumed(self, iteration: int, checkpoint_id: CheckpointID) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring is now slightly inaccurate—checkpoint_id is required, not optional. Consider updating to reflect the new parameter.

Suggested change
def _mark_resumed(self, iteration: int, checkpoint_id: CheckpointID) -> None:
def _mark_resumed(self, iteration: int, checkpoint_id: CheckpointID) -> None:
"""Mark the runner as having resumed from a checkpoint.
Sets the current iteration and records the checkpoint ID for chaining.
"""

"""Mark the runner as having resumed from a checkpoint.

Sets the current iteration and records the checkpoint ID for chaining.
"""
self._resumed_from_checkpoint = True
self._iteration = iteration
self._resume_parent_checkpoint_id = checkpoint_id

async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in state under a reserved key.
Expand Down
88 changes: 88 additions & 0 deletions python/packages/core/tests/workflow/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,94 @@ async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
)


async def test_resumed_workflow_keeps_previous_checkpoint_id_chain():
    """New checkpoints created after resume should chain to the restored checkpoint.

    Also verifies that the resume-parent state is reset after the resumed run
    completes: a subsequent fresh run on the same workflow must start a new
    chain (its first checkpoint has ``previous_checkpoint_id is None``).
    """
    from typing_extensions import Never

    from agent_framework import WorkflowBuilder, WorkflowContext, handler
    from agent_framework._workflows._executor import Executor

    class StartExecutor(Executor):
        @handler
        async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
            await ctx.send_message(message, target_id="middle")

    class MiddleExecutor(Executor):
        @handler
        async def process(self, message: str, ctx: WorkflowContext[str]) -> None:
            await ctx.send_message(message + "-processed", target_id="finish")

    class FinishExecutor(Executor):
        @handler
        async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
            await ctx.yield_output(message + "-done")

    storage = InMemoryCheckpointStorage()

    start = StartExecutor(id="start")
    middle = MiddleExecutor(id="middle")
    finish = FinishExecutor(id="finish")

    workflow_name = "resume-chain-workflow"

    workflow = (
        WorkflowBuilder(
            max_iterations=10,
            name=workflow_name,
            start_executor=start,
            checkpoint_storage=storage,
        )
        .add_edge(start, middle)
        .add_edge(middle, finish)
        .build()
    )

    # Initial run produces the baseline checkpoint chain.
    _ = [event async for event in workflow.run("hello", stream=True)]

    initial_checkpoints = sorted(
        await storage.list_checkpoints(workflow_name=workflow.name),
        key=lambda c: c.timestamp,
    )
    assert len(initial_checkpoints) >= 2, (
        f"Expected at least 2 checkpoints before resume, got {len(initial_checkpoints)}"
    )

    # Resume from a mid-run checkpoint (not the first), so chaining is observable.
    restore_checkpoint = initial_checkpoints[1]
    initial_ids = {checkpoint.checkpoint_id for checkpoint in initial_checkpoints}

    # Build an identically-shaped workflow so the checkpoint topology matches.
    resumed_start = StartExecutor(id="start")
    resumed_middle = MiddleExecutor(id="middle")
    resumed_finish = FinishExecutor(id="finish")

    resumed_workflow = (
        WorkflowBuilder(
            max_iterations=10,
            name=workflow_name,
            start_executor=resumed_start,
            checkpoint_storage=storage,
        )
        .add_edge(resumed_start, resumed_middle)
        .add_edge(resumed_middle, resumed_finish)
        .build()
    )

    _ = [event async for event in resumed_workflow.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True)]

    all_checkpoints = sorted(
        await storage.list_checkpoints(workflow_name=workflow.name),
        key=lambda c: c.timestamp,
    )
    resumed_checkpoints = [
        checkpoint for checkpoint in all_checkpoints if checkpoint.checkpoint_id not in initial_ids
    ]

    # The first post-resume checkpoint must chain to the restored checkpoint...
    assert resumed_checkpoints, "Expected at least one new checkpoint after resume"
    assert resumed_checkpoints[0].previous_checkpoint_id == restore_checkpoint.checkpoint_id

    # ...and subsequent ones must chain to each other in order.
    for i in range(1, len(resumed_checkpoints)):
        assert resumed_checkpoints[i].previous_checkpoint_id == resumed_checkpoints[i - 1].checkpoint_id

    # Reset behavior: after the resumed run completes, a fresh (non-resume) run
    # must not inherit the stale resume-parent ID — its first checkpoint starts
    # a brand-new chain.
    known_ids = {checkpoint.checkpoint_id for checkpoint in all_checkpoints}
    _ = [event async for event in resumed_workflow.run("hello", stream=True)]

    final_checkpoints = sorted(
        await storage.list_checkpoints(workflow_name=workflow.name),
        key=lambda c: c.timestamp,
    )
    fresh_checkpoints = [
        checkpoint for checkpoint in final_checkpoints if checkpoint.checkpoint_id not in known_ids
    ]
    assert fresh_checkpoints, "Expected new checkpoints from the fresh run after resume"
    assert fresh_checkpoints[0].previous_checkpoint_id is None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider extending this test (or adding a new one) to also verify the reset behavior: after the resumed run completes, do a fresh workflow.run("hello", stream=True) and assert the first new checkpoint has previous_checkpoint_id is None. This covers the reset of _resume_parent_checkpoint_id on line 160 of _runner.py.


async def test_memory_checkpoint_storage_roundtrip_json_native_types():
"""Test that JSON-native types (str, int, float, bool, None) roundtrip correctly."""
storage = InMemoryCheckpointStorage()
Expand Down
Loading