-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Python: Preserve checkpoint ancestry when workflows resume #4591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -65,6 +65,7 @@ def __init__( | |||||||||||||||||||
| self._state = state | ||||||||||||||||||||
| self._running = False | ||||||||||||||||||||
| self._resumed_from_checkpoint = False # Track whether we resumed | ||||||||||||||||||||
| self._resume_parent_checkpoint_id: CheckpointID | None = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @property | ||||||||||||||||||||
| def context(self) -> RunnerContext: | ||||||||||||||||||||
|
|
@@ -81,7 +82,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: | |||||||||||||||||||
| raise WorkflowRunnerException("Runner is already running.") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| self._running = True | ||||||||||||||||||||
| previous_checkpoint_id: CheckpointID | None = None | ||||||||||||||||||||
| previous_checkpoint_id: CheckpointID | None = ( | ||||||||||||||||||||
| self._resume_parent_checkpoint_id if self._resumed_from_checkpoint else None | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| # Emit any events already produced prior to entering loop | ||||||||||||||||||||
| if await self._ctx.has_events(): | ||||||||||||||||||||
|
|
@@ -154,6 +157,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: | |||||||||||||||||||
|
|
||||||||||||||||||||
| logger.info(f"Workflow completed after {self._iteration} supersteps") | ||||||||||||||||||||
| self._resumed_from_checkpoint = False # Reset resume flag for next run | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Both |
||||||||||||||||||||
| self._resume_parent_checkpoint_id = None | ||||||||||||||||||||
|
Comment on lines 157 to +160
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Both
Suggested change
|
||||||||||||||||||||
| finally: | ||||||||||||||||||||
| self._running = False | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -285,7 +289,7 @@ async def restore_from_checkpoint( | |||||||||||||||||||
| # Apply the checkpoint to the context | ||||||||||||||||||||
| await self._ctx.apply_checkpoint(checkpoint) | ||||||||||||||||||||
| # Mark the runner as resumed | ||||||||||||||||||||
| self._mark_resumed(checkpoint.iteration_count) | ||||||||||||||||||||
| self._mark_resumed(checkpoint.iteration_count, checkpoint.checkpoint_id) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}") | ||||||||||||||||||||
| except WorkflowCheckpointException: | ||||||||||||||||||||
|
|
@@ -351,13 +355,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[ | |||||||||||||||||||
|
|
||||||||||||||||||||
| return parsed | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _mark_resumed(self, iteration: int) -> None: | ||||||||||||||||||||
| def _mark_resumed(self, iteration: int, checkpoint_id: CheckpointID) -> None: | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Docstring is now slightly inaccurate—
Suggested change
|
||||||||||||||||||||
| """Mark the runner as having resumed from a checkpoint. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Optionally set the current iteration and max iterations. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| self._resumed_from_checkpoint = True | ||||||||||||||||||||
| self._iteration = iteration | ||||||||||||||||||||
| self._resume_parent_checkpoint_id = checkpoint_id | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: | ||||||||||||||||||||
| """Store executor state in state under a reserved key. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -336,6 +336,94 @@ async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: | |
| ) | ||
|
|
||
|
|
||
| async def test_resumed_workflow_keeps_previous_checkpoint_id_chain(): | ||
| """New checkpoints created after resume should chain to the restored checkpoint.""" | ||
| from typing_extensions import Never | ||
|
|
||
| from agent_framework import WorkflowBuilder, WorkflowContext, handler | ||
| from agent_framework._workflows._executor import Executor | ||
|
|
||
| class StartExecutor(Executor): | ||
| @handler | ||
| async def run(self, message: str, ctx: WorkflowContext[str]) -> None: | ||
| await ctx.send_message(message, target_id="middle") | ||
|
|
||
| class MiddleExecutor(Executor): | ||
| @handler | ||
| async def process(self, message: str, ctx: WorkflowContext[str]) -> None: | ||
| await ctx.send_message(message + "-processed", target_id="finish") | ||
|
|
||
| class FinishExecutor(Executor): | ||
| @handler | ||
| async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: | ||
| await ctx.yield_output(message + "-done") | ||
|
|
||
| storage = InMemoryCheckpointStorage() | ||
|
|
||
| start = StartExecutor(id="start") | ||
| middle = MiddleExecutor(id="middle") | ||
| finish = FinishExecutor(id="finish") | ||
|
|
||
| workflow_name = "resume-chain-workflow" | ||
|
|
||
| workflow = ( | ||
| WorkflowBuilder( | ||
| max_iterations=10, | ||
| name=workflow_name, | ||
| start_executor=start, | ||
| checkpoint_storage=storage, | ||
| ) | ||
| .add_edge(start, middle) | ||
| .add_edge(middle, finish) | ||
| .build() | ||
| ) | ||
|
|
||
| _ = [event async for event in workflow.run("hello", stream=True)] | ||
|
|
||
| initial_checkpoints = sorted( | ||
| await storage.list_checkpoints(workflow_name=workflow.name), | ||
| key=lambda c: c.timestamp, | ||
| ) | ||
| assert len(initial_checkpoints) >= 2, ( | ||
| f"Expected at least 2 checkpoints before resume, got {len(initial_checkpoints)}" | ||
| ) | ||
|
|
||
| restore_checkpoint = initial_checkpoints[1] | ||
| initial_ids = {checkpoint.checkpoint_id for checkpoint in initial_checkpoints} | ||
|
|
||
| resumed_start = StartExecutor(id="start") | ||
| resumed_middle = MiddleExecutor(id="middle") | ||
| resumed_finish = FinishExecutor(id="finish") | ||
|
|
||
| resumed_workflow = ( | ||
| WorkflowBuilder( | ||
| max_iterations=10, | ||
| name=workflow_name, | ||
| start_executor=resumed_start, | ||
| checkpoint_storage=storage, | ||
| ) | ||
| .add_edge(resumed_start, resumed_middle) | ||
| .add_edge(resumed_middle, resumed_finish) | ||
| .build() | ||
| ) | ||
|
|
||
| _ = [event async for event in resumed_workflow.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True)] | ||
|
|
||
| all_checkpoints = sorted( | ||
| await storage.list_checkpoints(workflow_name=workflow.name), | ||
| key=lambda c: c.timestamp, | ||
| ) | ||
| resumed_checkpoints = [ | ||
| checkpoint for checkpoint in all_checkpoints if checkpoint.checkpoint_id not in initial_ids | ||
| ] | ||
|
|
||
| assert resumed_checkpoints, "Expected at least one new checkpoint after resume" | ||
| assert resumed_checkpoints[0].previous_checkpoint_id == restore_checkpoint.checkpoint_id | ||
|
|
||
| for i in range(1, len(resumed_checkpoints)): | ||
| assert resumed_checkpoints[i].previous_checkpoint_id == resumed_checkpoints[i - 1].checkpoint_id | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider extending this test (or adding a new one) to also verify the reset behavior: after the resumed run completes, do a fresh run and check that its new checkpoints do not chain back to the resumed checkpoint. |
||
|
|
||
| async def test_memory_checkpoint_storage_roundtrip_json_native_types(): | ||
| """Test that JSON-native types (str, int, float, bool, None) roundtrip correctly.""" | ||
| storage = InMemoryCheckpointStorage() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This condition is logically equivalent to just checking `self._resume_parent_checkpoint_id`, because the two fields are always set and cleared together — `_resume_parent_checkpoint_id` is non-None iff `_resumed_from_checkpoint` is True. The guard adds noise without adding safety. Simplifying to the single expression also makes it obvious that the ID is the only thing that matters here.