feat: Add CheckpointService for agent state persistence #4586
drahnreb wants to merge 5 commits into google:main from
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Summary of Changes
Hello @drahnreb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the robustness and manageability of long-running agent workflows by introducing a dedicated `CheckpointService`.
Highlights
Code Review
This is an impressive and comprehensive pull request that introduces the CheckpointService and a powerful GraphAgent with extensive features like conditional routing, parallel execution, interrupts, and checkpointing. The inclusion of detailed design documents, architectural patterns, and a wide array of examples is highly commendable and will be of great value to users. The code is well-structured and shows careful consideration for edge cases, security, and performance.
My review includes a few suggestions for improving maintainability by refactoring some complex methods and correcting minor inconsistencies in documentation and scripts. One notable change is the removal of Context from the public API in google.adk.__init__, which is a breaking change that should probably be highlighted in the PR description for visibility.
```shell
"01_basic"
"02_conditional_routing"
"03_cyclic_execution"
"03_enhanced_routing"
```
Fixed. The run_all_examples.sh numbering has been corrected to match the actual example directory names (01_basic through 15_enhanced_routing).
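A hedged alternative worth considering (this is a sketch, not the PR's actual `run_all_examples.sh`): deriving the example list from the directory layout instead of hard-coding names makes this class of numbering drift impossible. The `list_examples` function and the `NN_name` glob are assumptions for illustration.

```shell
#!/usr/bin/env bash
# Hypothetical sketch, not the PR's run_all_examples.sh: derive the example
# list from the directory layout so a hard-coded array can't drift out of
# sync with the folders (the numbering bug flagged above).
set -euo pipefail

list_examples() {
  local root="$1"
  # Examples follow an NN_name convention (01_basic ... 15_enhanced_routing).
  find "$root" -maxdepth 1 -mindepth 1 -type d -name '[0-9][0-9]_*' \
      -printf '%f\n' | sort
}

# Demo against a throwaway layout.
demo_dir=$(mktemp -d)
mkdir "$demo_dir/01_basic" "$demo_dir/02_conditional_routing"
list_examples "$demo_dir"
rm -rf "$demo_dir"
```

Sorting on the numeric prefix keeps execution order stable without maintaining a parallel list in the script.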
```python
"user_content": types.Content(
    role="user",
    parts=[types.Part(text=task_data)]
)
```
Acknowledged: types is imported via google.genai.types in all sample files.
```markdown
# GraphAgent Design Document

**Author**: ADK Team
**Date**: 2026-01-25
```
```markdown
# InterruptService Architecture

**Author**: ADK Team
**Date**: 2026-02-04
```
| 12_parallel_checkpointing | ✅ | - | ✅ | - | - | - | Parallel |
| 13_parallel_interrupts | ✅ | - | - | ✅ | - | - | Parallel |
| 14_parallel_rewind | ✅ | ✅ | - | - | - | - | Parallel |
| 15_enhanced_routing | - | - | - | - | - | - | Advanced |
Fixed: numbering corrected in run_all_examples.sh.
```python
from .graph import END
from .graph import GraphAgent
from .graph import GraphNode
from .graph import GraphState
from .graph import START
```
It's great to see the new GraphAgent components being exported. However, I noticed that Context is no longer exported from this package's __all__ list (and also from the top-level google.adk package). This is a significant breaking change for users who might be importing Context from here. It would be beneficial to mention this change in the pull request description to ensure users are aware of the API update.
Context is imported and exported in `__init__.py` at lines 16 and 36.
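Since this thread is about a public symbol silently disappearing from `__all__`, a small regression guard can catch it in CI. This is a hypothetical sketch, not ADK code: `missing_exports` and the stand-in module are assumptions, used here in place of the real `google.adk` packages.

```python
import types as _types

# Hypothetical sketch (not ADK code): a guard that reports public symbols
# such as Context that are missing from a package's __all__ or namespace.
def missing_exports(module, expected):
  """Return the expected public names absent from the module or its __all__."""
  exported = set(getattr(module, "__all__", ()))
  return {
      name
      for name in expected
      if name not in exported or not hasattr(module, name)
  }

# Stand-in for google.adk.agents; the real check would import the package.
fake_pkg = _types.ModuleType("fake_adk_agents")
fake_pkg.Context = object()
fake_pkg.GraphAgent = object()
fake_pkg.__all__ = ["GraphAgent"]  # Context accidentally omitted

print(missing_exports(fake_pkg, {"Context", "GraphAgent"}))  # {'Context'}
```

Wiring such a check into the test suite turns an accidental breaking change into a failing test instead of a reviewer catch.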
```python
async def _run_async_impl(
    self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
  """Core graph execution logic.

  Executes nodes in graph order, following conditional edges,
  supporting loops and human-in-the-loop interrupts.

  Args:
    ctx: Invocation context

  Yields:
    Events from graph execution

  Raises:
    ValueError: If start node not set or graph structure invalid
  """
  if not self.start_node:
    raise ValueError("Start node not set. Call set_start() first.")

  # Register session with InterruptService if enabled
  if self.interrupt_service:
    self.interrupt_service.register_session(ctx.session.id)

  # Get effective telemetry config for nested graph inheritance
  effective_config = self._get_effective_telemetry_config(ctx)

  with tracer.start_as_current_span(
      f"graph_agent_execution {self.name}"
  ) as span:
    span.set_attribute("graph_agent.name", self.name)
    span.set_attribute("graph_agent.start_node", self.start_node)
    span.set_attribute("graph_agent.max_iterations", self.max_iterations)
    try:
      # Load execution tracking state (BaseAgentState pattern)
      agent_state = (
          self._load_agent_state(ctx, GraphAgentState) or GraphAgentState()
      )

      # Store telemetry config for nested graph inheritance
      if effective_config:
        agent_state.telemetry_config_dict = effective_config.model_dump()

      # Initialize domain data from session state or user input.
      # Exclude graph-internal keys to prevent circular references
      # (state.data["graph_data"] → state.data) and keep domain data clean.
      domain_data = {
          k: v
          for k, v in ctx.session.state.items()
          if not k.startswith("_") and k not in _GRAPH_INTERNAL_KEYS
      }
      if domain_data:
        state = GraphState(data=domain_data)
      else:
        # Extract text from Content object
        user_text = ""
        if (
            hasattr(ctx, "user_content")
            and ctx.user_content
            and ctx.user_content.parts
        ):
          user_text = (
              ctx.user_content.parts[0].text
              if ctx.user_content.parts[0].text
              else ""
          )
        state = GraphState(data={"input": user_text})

      # Track which parallel groups have been executed
      executed_parallel_groups = set(agent_state.executed_parallel_groups)

      # ADK resumability: resume from saved node or start fresh.
      #
      # Design note: SequentialAgent ONLY emits state events when
      # ctx.is_resumable is True, because its state events serve only
      # resumability. GraphAgent's state events serve multiple consumers
      # (rewind, interrupts, telemetry) that are orthogonal to
      # resumability. Therefore:
      #   - Per-iteration state events: always emitted (multi-consumer)
      #   - Resume skip: first iteration skipped when resuming (already
      #     persisted before pause, avoids duplicate)
      #   - end_of_agent: guarded by is_resumable (purely a resumability
      #     lifecycle signal, has no other consumers)
      #   - Interrupt/cancellation state saves: always emitted (they
      #     serve interrupt functionality, not just resumability)
      current_node_name, iteration, resuming = self._get_resume_state(
          agent_state
      )
      pause_invocation = False

      while current_node_name and iteration < self.max_iterations:
        iteration += 1
        current_node = self.nodes[current_node_name]

        # Check for immediate cancellation (ESC-like interrupt).
        # Allows user to abort execution at any time, not just at
        # pause points.
        if self.interrupt_service and not self.interrupt_service.is_active(
            ctx.session.id
        ):
          logger.info(
              "GraphAgent execution cancelled (immediate interrupt) for"
              f" session {ctx.session.id}"
          )
          # Save partial state before cancelling (enables resume/restart)
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
          yield Event(
              author=self.name,
              content=types.Content(
                  parts=[types.Part(text="⚠️ Execution cancelled by user")]
              ),
              actions=EventActions(
                  escalate=False,
                  state_delta={
                      "graph_cancelled": True,
                      "graph_cancelled_at_node": current_node_name,
                      "graph_iteration": iteration,
                      "graph_data": state.data,
                      "graph_path": list(agent_state.path),
                      "graph_can_resume": True,
                  },
              ),
          )
          break  # Exit immediately but state is saved

        # Track execution path in agent_state
        agent_state.path.append(current_node_name)
        agent_state.iteration = iteration
        agent_state.current_node = current_node_name
        agent_state.node_invocations.setdefault(current_node_name, []).append(
            ctx.invocation_id
        )

        # ADK resumability: reset sub-agent states on cycle revisit
        # (mirrors LoopAgent pattern at loop_agent.py:114)
        if (
            agent_state.path.count(current_node_name) > 1
            and current_node.agent
        ):
          ctx.reset_sub_agent_states(current_node.agent.name)

        # Track agent path for nested graph support
        if self.name not in agent_state.agent_path:
          agent_state.agent_path.append(self.name)

        # Persist execution tracking via agent_state event.
        # These events are consumed by rewind, interrupts, and telemetry
        # (not just resumability), so they're always emitted.
        # Skip only on first iteration when resuming (already persisted).
        if not resuming:
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
        else:
          resuming = False  # Only skip first iteration after resume

        # Invoke before_node_callback (custom observability)
        if self.before_node_callback:
          from .callbacks import NodeCallbackContext

          callback_ctx = NodeCallbackContext(
              node=current_node,
              state=state,
              iteration=iteration,
              invocation_context=ctx,
              metadata={
                  "agent_path": list(agent_state.agent_path),
                  "path": list(agent_state.path),
              },
          )

          # Execute callback with telemetry
          callback_start_time = time.time()
          with graph_tracing.tracer.start_as_current_span(
              "graph_callback before_node"
          ) as cb_span:
            # Add attributes with additional_attributes support
            attrs = self._get_telemetry_attributes(
                {
                    graph_tracing.GRAPH_CALLBACK_TYPE: "before_node",
                    graph_tracing.GRAPH_AGENT_NAME: self.name,
                    graph_tracing.GRAPH_NODE_NAME: current_node_name,
                },
                effective_config=effective_config,
            )
            for key, value in attrs.items():
              cb_span.set_attribute(key, value)

            try:
              event = await self.before_node_callback(callback_ctx)
              if event:
                yield event

              # Record success (check sampling)
              cb_span.set_attribute("graph.callback.success", True)
              if self._should_sample(effective_config=effective_config):
                callback_latency_ms = (
                    time.time() - callback_start_time
                ) * 1000
                graph_tracing.record_callback_execution(
                    callback_type="before_node",
                    agent_name=self.name,
                    latency_ms=callback_latency_ms,
                    success=True,
                )

            except Exception as e:
              # Record failure (check sampling)
              cb_span.set_attribute("graph.callback.success", False)
              cb_span.set_attribute("graph.callback.error", str(e))
              if self._should_sample(effective_config=effective_config):
                callback_latency_ms = (
                    time.time() - callback_start_time
                ) * 1000
                graph_tracing.record_callback_execution(
                    callback_type="before_node",
                    agent_name=self.name,
                    latency_ms=callback_latency_ms,
                    success=False,
                )
              logger.error(
                  "before_node_callback failed for node"
                  f" '{current_node_name}': {e}",
                  exc_info=True,
              )
              # Continue execution despite callback error

        # Handle BEFORE-node interrupt (validation timing)
        if (
            self._should_interrupt_before(current_node_name)
            and self.interrupt_service
        ):
          _b_events, _b_ctrl = await self._handle_before_node_interrupt(
              current_node_name, current_node, state, ctx, agent_state
          )
          for _e in _b_events:
            yield _e
          # Persist agent_state after interrupt handler may have mutated it
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
          if _b_ctrl == "break":
            break
          elif _b_ctrl is not None:
            if isinstance(_b_ctrl, tuple):
              current_node_name = _b_ctrl[1]
            continue

        # Check if current node is part of a parallel group
        parallel_group_info = self._find_parallel_group(current_node_name)
        if parallel_group_info:
          group_id, parallel_group = parallel_group_info

          # Check if this group has already been executed
          if group_id in executed_parallel_groups:
            # Group already executed, skip this node
            logger.info(
                f"Skipping node '{current_node_name}' - already executed as"
                f" part of parallel group '{group_id}'"
            )
            # Route to next node from this node's edges
            next_node_name = self._get_next_node_with_telemetry(
                current_node, state
            )
            if next_node_name is None:
              if current_node_name in self.end_nodes:
                break
              else:
                raise ValueError(
                    f"Node {current_node_name} has no outgoing edges and is"
                    " not an end node"
                )
            current_node_name = next_node_name
            continue

          # Execute entire parallel group
          logger.info(
              f"Executing parallel group '{group_id}' with nodes:"
              f" {parallel_group.nodes}"
          )

          # Execute parallel group with telemetry
          parallel_start_time = time.time()
          with graph_tracing.tracer.start_as_current_span(
              f"parallel_group {group_id}"
          ) as pg_span:
            # Add attributes with additional_attributes support
            attrs = self._get_telemetry_attributes(
                {
                    graph_tracing.GRAPH_PARALLEL_NODE_COUNT: len(
                        parallel_group.nodes
                    ),
                    graph_tracing.GRAPH_PARALLEL_STRATEGY: (
                        parallel_group.join_strategy.value
                    ),
                    graph_tracing.GRAPH_PARALLEL_WAIT_N: (
                        parallel_group.wait_n
                    ),
                    graph_tracing.GRAPH_AGENT_NAME: self.name,
                },
                effective_config=effective_config,
            )
            for key, value in attrs.items():
              pg_span.set_attribute(key, value)

            # Collect all events from parallel execution
            completed_count = 0
            async for event in execute_parallel_group(
                parallel_group,
                self.nodes,
                state,
                ctx,
                self._execute_node,
            ):
              yield event
              # Count completions (rough estimate based on events)
              if event.author != self.name:
                completed_count = min(
                    completed_count + 1, len(parallel_group.nodes)
                )

            # Record parallel group metrics (check sampling)
            pg_span.set_attribute(
                "graph.parallel.completed_count", completed_count
            )
            if self._should_sample(effective_config=effective_config):
              parallel_latency_ms = (time.time() - parallel_start_time) * 1000
              graph_tracing.record_parallel_group_execution(
                  agent_name=self.name,
                  node_count=len(parallel_group.nodes),
                  strategy=parallel_group.join_strategy.value,
                  latency_ms=parallel_latency_ms,
                  completed_count=completed_count,
              )

          # Mark group as executed
          executed_parallel_groups.add(group_id)
          agent_state.executed_parallel_groups = list(
              executed_parallel_groups
          )

          # After parallel group completes, determine next node.
          # Use the current node's edges to determine routing
          # (all nodes in group should have same outgoing edges).
          next_node_name = self._get_next_node_with_telemetry(
              current_node, state
          )

          if next_node_name is None:
            # No more nodes after parallel group
            if current_node_name in self.end_nodes:
              break
            else:
              raise ValueError(
                  f"Parallel group '{group_id}' has no outgoing edges and"
                  f" node '{current_node_name}' is not an end node"
              )

          current_node_name = next_node_name
          continue  # Skip individual node execution, go to next iteration

        # Execute node with immediate cancellation support.
        # Check cancellation while streaming events from node execution.
        output_holder: Dict[str, Any] = {"output": ""}
        try:
          async for event in self._execute_node(
              current_node,
              state,
              ctx,
              effective_config,
              output_holder=output_holder,
              iteration=iteration,
          ):
            # Check for immediate cancellation DURING node execution
            if (
                self.interrupt_service
                and not self.interrupt_service.is_active(ctx.session.id)
            ):
              logger.info(
                  "GraphAgent execution cancelled (immediate interrupt"
                  f" during node '{current_node_name}') for session"
                  f" {ctx.session.id}"
              )
              ctx.set_agent_state(self.name, agent_state=agent_state)
              yield self._create_agent_state_event(ctx)
              yield Event(
                  author=self.name,
                  content=types.Content(
                      parts=[
                          types.Part(
                              text=(
                                  "⚠️ Execution cancelled during node"
                                  f" '{current_node_name}'"
                              )
                          )
                      ]
                  ),
                  actions=EventActions(
                      escalate=False,
                      state_delta={
                          "graph_cancelled": True,
                          "graph_cancelled_at_node": current_node_name,
                          "graph_data": state.data,
                          "graph_partial_output": output_holder["output"],
                          "graph_can_resume": True,
                      },
                  ),
              )
              return
            yield event
        except asyncio.CancelledError:
          # Task cancelled externally (e.g., timeout, user abort)
          logger.info(
              f"GraphAgent task cancelled during node '{current_node_name}'"
              f" for session {ctx.session.id}"
          )
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
          yield Event(
              author=self.name,
              content=types.Content(
                  parts=[
                      types.Part(
                          text=(
                              "⚠️ Task cancelled during node"
                              f" '{current_node_name}'"
                          )
                      )
                  ]
              ),
              actions=EventActions(
                  escalate=False,
                  state_delta={
                      "graph_task_cancelled": True,
                      "graph_cancelled_at_node": current_node_name,
                      "graph_data": state.data,
                      "graph_partial_output": output_holder["output"],
                      "graph_can_resume": True,
                  },
              ),
          )
          raise

        # ADK resumability: check if node execution was paused
        if output_holder.get("pause"):
          pause_invocation = True
          return

        # Sync session state into GraphState.data FIRST so that
        # output_mapper receives the most up-to-date values and can
        # override them. Agents write routing signals via state_delta
        # (the ADK-standard pattern); this sync makes those values
        # visible to edge condition lambdas (which receive GraphState)
        # without requiring an explicit output_mapper.
        # Internal keys (prefix '_') are excluded.
        for _sk, _sv in ctx.session.state.items():
          if not _sk.startswith("_") and _sk not in _GRAPH_INTERNAL_KEYS:
            state.data[_sk] = _sv

        # Update state with node output (output_mapper runs AFTER
        # session sync, so it can override synced values when needed)
        output = output_holder["output"]
        if output:
          # Track state before reduction for telemetry
          had_previous_value = current_node.name in state.data
          reducer_start = time.time()

          # Apply output mapper with reducer
          prev_state = state
          state = current_node.output_mapper(output, state)
          if state is None:
            # Custom output_mapper mutated in-place but forgot to return
            state = prev_state

          # Record state reducer telemetry (check sampling)
          reducer_latency_ms = (time.time() - reducer_start) * 1000
          if self._should_sample(effective_config=effective_config):
            graph_tracing.record_state_reducer(
                node_name=current_node.name,
                reducer_type=current_node.reducer.name,
                state_key=current_node.name,
                agent_name=self.name,
                latency_ms=reducer_latency_ms,
                had_previous_value=had_previous_value,
            )

          # Record output mapper telemetry
          is_default_mapper = (
              current_node.output_mapper.__name__
              == "_default_output_mapper"
          )
          graph_tracing.record_mapper(
              node_name=current_node.name,
              mapper_type="output",
              agent_name=self.name,
              latency_ms=reducer_latency_ms,
              is_default=is_default_mapper,
          )

        # Invoke after_node_callback (custom observability)
        if self.after_node_callback:
          from .callbacks import NodeCallbackContext

          callback_ctx = NodeCallbackContext(
              node=current_node,
              state=state,
              iteration=iteration,
              invocation_context=ctx,
              metadata={
                  "output": output,
                  "agent_path": list(agent_state.agent_path),
                  "path": list(agent_state.path),
              },
          )

          # Execute callback with telemetry
          callback_start_time = time.time()
          with graph_tracing.tracer.start_as_current_span(
              "graph_callback after_node"
          ) as cb_span:
            # Add attributes with additional_attributes support
            attrs = self._get_telemetry_attributes(
                {
                    graph_tracing.GRAPH_CALLBACK_TYPE: "after_node",
                    graph_tracing.GRAPH_AGENT_NAME: self.name,
                    graph_tracing.GRAPH_NODE_NAME: current_node_name,
                },
                effective_config=effective_config,
            )
            for key, value in attrs.items():
              cb_span.set_attribute(key, value)

            try:
              event = await self.after_node_callback(callback_ctx)
              if event:
                yield event

              # Record success (check sampling)
              cb_span.set_attribute("graph.callback.success", True)
              if self._should_sample(effective_config=effective_config):
                callback_latency_ms = (
                    time.time() - callback_start_time
                ) * 1000
                graph_tracing.record_callback_execution(
                    callback_type="after_node",
                    agent_name=self.name,
                    latency_ms=callback_latency_ms,
                    success=True,
                )

            except Exception as e:
              # Record failure (check sampling)
              cb_span.set_attribute("graph.callback.success", False)
              cb_span.set_attribute("graph.callback.error", str(e))
              if self._should_sample(effective_config=effective_config):
                callback_latency_ms = (
                    time.time() - callback_start_time
                ) * 1000
                graph_tracing.record_callback_execution(
                    callback_type="after_node",
                    agent_name=self.name,
                    latency_ms=callback_latency_ms,
                    success=False,
                )
              logger.error(
                  "after_node_callback failed for node"
                  f" '{current_node_name}': {e}",
                  exc_info=True,
              )
              # Continue execution despite callback error

        # Emit graph metadata event for evaluation framework.
        # This will be captured in Invocation.intermediate_data by
        # EvaluationGenerator. Set partial=True so is_final_response()
        # returns False (making it an intermediate event).
        graph_metadata = {
            "graph_node": current_node_name,
            "graph_iteration": iteration,
            "graph_path": list(agent_state.path),
            "node_invocations": {
                name: len(invocs)
                for name, invocs in agent_state.node_invocations.items()
            },
            "graph_state": dict(state.data),
        }
        yield Event(
            author=f"{self.name}#metadata",
            content=types.Content(
                parts=[types.Part(text=f"[GraphMetadata] {graph_metadata}")]
            ),
            partial=True,  # Mark as intermediate event
        )

        # Handle AFTER-node interrupt (retrospective feedback timing).
        # This enables retrospective feedback: observe past, steer future.
        if (
            self._should_interrupt_after(current_node_name)
            and self.interrupt_service
        ):
          _a_events, _a_ctrl = await self._handle_after_node_interrupt(
              current_node_name, state, ctx, agent_state
          )
          for _e in _a_events:
            yield _e
          # Persist agent_state after interrupt handler may have mutated it
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
          if _a_ctrl == "break":
            break
          elif _a_ctrl is not None:
            if isinstance(_a_ctrl, tuple):
              current_node_name = _a_ctrl[1]
            continue

        # Checkpointing - yield event with state_delta to persist checkpoint.
        # Note: for full checkpoint/resume functionality, use CheckpointCallback.
        if self.checkpointing:
          ctx.set_agent_state(self.name, agent_state=agent_state)
          yield self._create_agent_state_event(ctx)
          yield Event(
              author=self.name,
              content=types.Content(
                  parts=[types.Part(text=f"Checkpoint: {current_node_name}")]
              ),
              actions=EventActions(
                  state_delta={
                      "graph_data": state.data,
                      "graph_checkpoint": {
                          "node": current_node_name,
                          "iteration": iteration,
                      },
                  }
              ),
          )

        # Inject transient execution data for edge conditions
        state.data["_graph_iteration"] = agent_state.iteration
        state.data["_graph_path"] = list(agent_state.path)
        state.data["_conditions"] = dict(agent_state.conditions)

        # Get next node via conditional routing
        next_node_name = self._get_next_node_with_telemetry(
            current_node, state
        )

        # Clean up transient keys
        for _tk in ("_graph_iteration", "_graph_path", "_conditions"):
          state.data.pop(_tk, None)
        if next_node_name is None:
          # No more edges - check if we're at an end node
          if current_node_name in self.end_nodes:
            break
          else:
            # Not at an end node and no edges - error
            raise ValueError(
                f"Node {current_node_name} has no outgoing edges and is not"
                " an end node"
            )

        current_node_name = next_node_name

        # Record iteration metrics (check sampling)
        if self._should_sample(effective_config=effective_config):
          graph_tracing.record_graph_iteration(
              agent_name=self.name,
              iteration=iteration,
              path_length=len(agent_state.path),
          )

      # ADK resumability: skip final response + end_of_agent when paused
      if not pause_invocation:
        # Final response - yield event with graph metadata.
        # Include last node's output ONLY if:
        #   1. explicit final_output is set, OR
        #   2. last node was a function (doesn't yield events, so we need
        #      to show output).
        # Don't include output for agent nodes (they already yielded
        # their output).
        final_output = state.data.get("final_output", "")

        # If no explicit final_output, check if last node was a function
        if not final_output and current_node_name:
          last_node = self.nodes.get(current_node_name)
          if last_node and last_node.function:
            # Function node - include its output
            final_output = state.data.get(current_node_name, "")

        response_text = f"{final_output}"

        yield Event(
            author=self.name,
            content=types.Content(parts=[types.Part(text=response_text)]),
            actions=EventActions(
                state_delta={
                    "graph_data": state.data,
                    "graph_iterations": iteration,
                    "graph_path": list(agent_state.path),
                }
            ),
        )
        # end_of_agent is guarded by is_resumable because it is purely a
        # resumability lifecycle signal (tells the runner "this agent is
        # done, don't re-run it on resume"). Unlike per-iteration state
        # events which serve rewind/interrupts/telemetry, end_of_agent
        # has no other consumers.
        if ctx.is_resumable:
          ctx.set_agent_state(self.name, end_of_agent=True)
          yield self._create_agent_state_event(ctx)

    finally:
      # Unregister session from InterruptService and finalize tracing
      if self.interrupt_service:
        self.interrupt_service.unregister_session(ctx.session.id)
      span.set_attribute("graph_agent.completed", True)

# Interrupt methods inherited from GraphInterruptMixin
```
The _run_async_impl method is quite long and complex, handling various aspects like state initialization, iteration, interrupts, callbacks, checkpointing, and parallel execution. For better readability and maintainability, consider refactoring this method by extracting some of its logic into smaller, more focused helper methods. For instance, the callback invocation logic, which is similar for before_node_callback and after_node_callback, could be a good candidate for a helper function.
Fixed. Extracted 4 helper methods from _run_async_impl: _execute_callback, _sync_state_and_reduce, _build_cancellation_events, and _execute_parallel_phase. Method reduced from ~541 to ~395 lines. The helpers encapsulate distinct concerns: callback execution, state synchronization, cancellation event building, and parallel phase orchestration.
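For reviewers skimming this thread, the shape of the extracted callback helper might look roughly like the following. This is a hypothetical sketch, not the PR's actual `_execute_callback`: the signature, the `record` hook, and the demo harness are all assumptions; it only illustrates the factoring of the duplicated before/after telemetry logic.

```python
import asyncio
import time

# Hypothetical sketch of a shared callback helper: one wrapper handles
# timing, success/failure recording, and error swallowing for both
# before_node and after_node callbacks. Not the PR's actual code.
async def execute_callback(callback, callback_ctx, *, callback_type, record):
  """Run a node callback, record telemetry, and never abort the graph."""
  start = time.time()
  try:
    event = await callback(callback_ctx)
    record(callback_type=callback_type,
           latency_ms=(time.time() - start) * 1000, success=True)
    return event
  except Exception:
    record(callback_type=callback_type,
           latency_ms=(time.time() - start) * 1000, success=False)
    return None  # Swallow: callback failures must not stop execution


async def _demo():
  calls = []
  record = lambda **kw: calls.append(kw)

  async def ok_cb(ctx):
    return "event"

  async def bad_cb(ctx):
    raise RuntimeError("boom")

  assert await execute_callback(
      ok_cb, None, callback_type="before_node", record=record) == "event"
  assert await execute_callback(
      bad_cb, None, callback_type="after_node", record=record) is None
  return calls

calls = asyncio.run(_demo())
print([(c["callback_type"], c["success"]) for c in calls])
# [('before_node', True), ('after_node', False)]
```

Collapsing both call sites onto one helper also guarantees the two paths can never diverge in how they record latency or failures.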
```python
async def _handle_before_node_interrupt(
    self,
    current_node_name: str,
    current_node: GraphNode,
    state: GraphState,
    ctx: InvocationContext,
    agent_state: GraphAgentState,
) -> Tuple[List[Event], str | Tuple[str, str] | None]:
  """Handle a BEFORE-node interrupt and return events + routing control.

  Args:
    current_node_name: Name of the node about to execute.
    current_node: The GraphNode about to execute (needed for "skip").
    state: Current graph state.
    ctx: Invocation context.
    agent_state: Execution tracking state.

  Returns:
    Tuple of (events_to_yield, control) where control is:
    - None: proceed to normal node execution.
    - "rerun": re-run current node (continue the loop).
    - "break": exit the main loop immediately.
    - ("go_back", target_node): jump to target_node.
    - ("skip", next_node | None): skip node, route to next_node.
  """
  assert self.interrupt_service is not None
  interrupt_message = await self._check_interrupt_with_telemetry(
      ctx.session.id, "before"
  )
  if not interrupt_message:
    return [], None

  action_result = await self._process_interrupt_message(
      interrupt_message, state, current_node_name, ctx, agent_state
  )

  should_escalate = (
      action_result == "pause"
      if isinstance(action_result, str)
      else (isinstance(action_result, tuple) and action_result[0] == "pause")
  )

  event = Event(
      author=self.name,
      content=types.Content(
          parts=[
              types.Part(
                  text=(
                      "\U0001f6d1 INTERRUPT (BEFORE):"
                      f" {interrupt_message.text}"
                  )
              )
          ]
      ),
      actions=EventActions(
          escalate=should_escalate,
          state_delta={
              "interrupt_message": interrupt_message.text,
              "interrupt_timing": "before",
              "interrupt_node": current_node_name,
          },
      ),
  )

  if isinstance(action_result, tuple):
    action, target_node = action_result
    if action == "go_back":
      return [event], ("go_back", target_node)
  elif action_result == "rerun":
    return [event], "rerun"
  elif action_result == "skip":
    next_node_name = self._get_next_node_with_telemetry(current_node, state)  # type: ignore[attr-defined]
    return (
        [event],
        ("skip", next_node_name) if next_node_name else "break",
    )
  elif action_result == "pause":
    try:
      resumed = await self.interrupt_service.wait_if_paused(ctx.session.id)
      if not resumed:
        return [event], "break"
    except TimeoutError:
      return [event], "break"

  return [event], None

async def _handle_after_node_interrupt(
    self,
    current_node_name: str,
    state: GraphState,
    ctx: InvocationContext,
    agent_state: GraphAgentState,
) -> Tuple[List[Event], str | Tuple[str, str] | None]:
  """Handle an AFTER-node interrupt and return events + routing control.

  Args:
    current_node_name: Name of the node that just executed.
    state: Current graph state (includes the node's output).
    ctx: Invocation context.
    agent_state: Execution tracking state.

  Returns:
    Tuple of (events_to_yield, control) where control is:
    - None: accept results and proceed to next node.
    - "rerun": re-run current node.
    - "break": exit the main loop.
    - ("go_back", target_node): jump to target_node.
  """
  assert self.interrupt_service is not None
  interrupt_message = await self._check_interrupt_with_telemetry(
      ctx.session.id, "after"
  )
  if not interrupt_message:
    return [], None

  action_result = await self._process_interrupt_message(
      interrupt_message, state, current_node_name, ctx, agent_state
  )

  should_escalate = (
      action_result == "pause"
      if isinstance(action_result, str)
      else (isinstance(action_result, tuple) and action_result[0] == "pause")
  )

  state_delta_dict: Dict[str, Any] = {
      "interrupt_message": interrupt_message.text,
      "interrupt_timing": "after",
      "interrupt_metadata": interrupt_message.metadata,
      "interrupt_action": interrupt_message.action,
      "interrupt_node": current_node_name,
  }

  event = Event(
      author=self.name,
      content=types.Content(
          parts=[
              types.Part(
                  text=(
                      "\U0001f6d1 INTERRUPT (AFTER):"
                      f" {interrupt_message.text}"
                  )
              )
          ]
      ),
      actions=EventActions(
          escalate=should_escalate, state_delta=state_delta_dict
      ),
  )

  if isinstance(action_result, tuple):
    action, target_node = action_result
    if action == "go_back":
      return [event], ("go_back", target_node)
  elif action_result == "rerun":
    return [event], "rerun"
  elif action_result == "pause":
    try:
      resumed = await self.interrupt_service.wait_if_paused(ctx.session.id)
      if not resumed:
        logger.info(
            "GraphAgent execution cancelled for session %s",
            ctx.session.id,
        )
        return [event], "break"
    except TimeoutError as e:
      logger.warning("Interrupt wait timeout: %s", e)
      return [event], "break"

  return [event], None
```
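To make the routing contract of these handlers concrete, here is a minimal, hypothetical sketch of how a caller might interpret the `(events, control)` return value. The `route` function and its decision strings are illustrative stand-ins, not the actual GraphAgent loop.

```python
# Hypothetical consumer of the interrupt-handler control value.
# Control mirrors the documented return shape: None, "rerun", "break",
# ("go_back", node), or ("skip", node | None).
from typing import Optional, Tuple, Union

Control = Union[None, str, Tuple[str, Optional[str]]]


def route(control: Control, current: str) -> Tuple[str, Optional[str]]:
    """Map an interrupt control value to a (decision, next_node) pair."""
    if control is None:
        return "execute", current      # proceed to normal node execution
    if control == "rerun":
        return "rerun", current        # re-run the same node
    if control == "break":
        return "stop", None            # exit the main loop
    action, target = control           # ("go_back", node) or ("skip", node)
    return "jump", target              # jump to the targeted node


print(route(None, "write"))               # ('execute', 'write')
print(route("break", "write"))            # ('stop', None)
print(route(("go_back", "plan"), "write"))  # ('jump', 'plan')
```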
The methods _handle_before_node_interrupt and _handle_after_node_interrupt share a significant amount of logic. They both check for an interrupt message, process it, create an event, and then handle the action result. To reduce code duplication and improve maintainability, you could refactor this into a single helper method that takes the interrupt mode ('before' or 'after') as a parameter and handles the common logic, while keeping the mode-specific differences (like the skip action) inside conditional blocks.
Already refactored: both methods delegate to _check_interrupt_with_telemetry, _process_interrupt_message, and _dispatch_interrupt_action. Remaining differences are intentional (BEFORE needs current_node for skip, AFTER includes metadata fields).
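For illustration, the shared-builder part of the reviewer's suggestion could look roughly like the sketch below: one helper parameterized by timing, with the AFTER-only fields added conditionally. The field names mirror the diff; the helper itself is hypothetical and not part of the PR.

```python
# Hypothetical shared helper: builds the state_delta for an interrupt
# event, keyed by timing ("before" or "after"). AFTER interrupts carry
# the extra metadata/action fields seen in _handle_after_node_interrupt.
from typing import Any, Dict, Optional


def build_interrupt_state_delta(
    timing: str,
    node: str,
    text: str,
    metadata: Optional[Dict[str, Any]] = None,
    action: Optional[str] = None,
) -> Dict[str, Any]:
    delta: Dict[str, Any] = {
        "interrupt_message": text,
        "interrupt_timing": timing,
        "interrupt_node": node,
    }
    if timing == "after":  # only AFTER interrupts include these fields
        delta["interrupt_metadata"] = metadata
        delta["interrupt_action"] = action
    return delta


before = build_interrupt_state_delta("before", "write", "pause please")
after = build_interrupt_state_delta("after", "write", "redo", {"k": 1}, "rerun")
```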
Force-pushed 3282d2e to 5c2bdb5.
Code Review
This pull request introduces a comprehensive suite of new examples and documentation for the GraphAgent framework. It showcases advanced features such as checkpointing, LLM-based interrupt reasoning, callback-based observability, flexible interrupt timings, immediate cancellation, dynamic task queues, agent-driven topology, Human-In-The-Loop (HITL) workflows, multi-agent coordination, and various parallel execution strategies (WAIT_ALL, WAIT_ANY, WAIT_N) with rewind integration. The changes include new READMEs, Python agent scripts, and YAML configurations for each example, demonstrating how to build complex, adaptive, and fault-tolerant agentic workflows. A review comment highlights a NameError in the scenario_5_all_interrupt_timings function within contributing/samples/graph_agent_advanced/agent.py, noting that paper_writer_agent and peer_reviewer_agent are not properly scoped and suggesting they be re-instantiated within the scenario function to resolve the issue.
```python
graph_before.add_node("write", agent=paper_writer_agent)
graph_before.add_node("peer_review", agent=peer_reviewer_agent)
```
The agents paper_writer_agent and peer_reviewer_agent are not defined within the scope of the scenario_5_all_interrupt_timings function. They are created inside build_research_workflow but are not accessible here, which will lead to a NameError when this scenario is executed.
To fix this, you should create new instances of these agents within the scenario, similar to how other scenarios are structured.
Suggested change:

```python
# Create fresh agents for this scenario to avoid parent conflicts.
(_, _, _, paper_writer_agent, peer_reviewer_agent) = _create_research_agents()
# Add simple nodes
graph_before.add_node("write", agent=paper_writer_agent)
graph_before.add_node("peer_review", agent=peer_reviewer_agent)
```
Fixed: scenario_5 now calls _create_research_agents() to get its own scoped references.
Force-pushed aa73841 to 7eee735.
Addressing review feedback

Force-pushed with the following fixes:

High:
Medium:
@gemini-code-assist please re-review
Code Review
This is an impressive pull request that introduces a powerful GraphAgent for complex workflow orchestration and a CheckpointService for state persistence. The implementation is robust, with thoughtful handling of concurrency, state management, and error conditions. The extensive set of new samples and documentation is particularly valuable for understanding and adopting these new features. My review includes a minor code cleanup suggestion and a documentation enhancement to make the new codebase even more discoverable.
- `src/google/adk/agents/graph/graph_agent.py` - Core orchestration
- `src/google/adk/agents/graph/interrupt_reasoner.py` - LLM reasoning
- `src/google/adk/agents/graph/callbacks.py` - Callback infrastructure
- `src/google/adk/agents/graph/interrupt_service.py` - Interrupt management
Fixed — added `src/google/adk/checkpoints/checkpoint_service.py` - Checkpoint management to the list.
```python
review = state.data.get("peer_review", {})
# LLM agents store output as JSON string; parse if needed
if isinstance(review, str):
  import json
```
Fixed — removed duplicate import json (already imported at line 16).
Add GraphAgent for building directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, telemetry with OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes samples and design documentation.
Add DynamicNode (runtime agent selection), NestedGraphNode (hierarchical workflow composition), and DynamicParallelGroup (variable-count concurrent execution). Extends CLI visualization with pattern-aware rendering (diamond, parallelogram, sub-cluster shapes). Includes pattern samples, node type reference, and design documentation.
Add stateless CheckpointService using SessionService and ArtifactService primitives with delta compression, concurrent session locking, and configurable retention. Includes CheckpointableMixin for any agent type, GraphCheckpointCallback for node-level checkpointing, and comprehensive error model with telemetry. Includes checkpoint samples.
Force-pushed 7eee735 to 1c03f63.
@gemini-code-assist please re-review
Code Review
This pull request introduces the CheckpointService and associated components to enable state persistence, delta compression, and human-in-the-loop interrupts for GraphAgent workflows. The implementation is comprehensive, including a stateless service, node-level callbacks, and extensive telemetry. My review identified a few areas for improvement, primarily concerning the use of internal asyncio.Queue attributes, performance overhead in state size validation, and potential edge cases in state delta calculation. Addressing these will enhance the robustness and performance of the persistence layer.
```python
    "data": state.data,
}
try:
  result = eval(code, {"__builtins__": safe_builtins}, namespace)  # noqa: S307
```
```python
# asyncio.Queue stores items in _queue (a collections.deque).
# This avoids drain/requeue and eliminates QueueFull risk.
queue = self._message_queues[session_id]
messages = list(queue._queue)  # type: ignore[attr-defined]
```
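The review summary flags this reliance on `asyncio.Queue`'s private `_queue` attribute. One hedged alternative, sketched below, is to keep a `collections.deque` as the canonical buffer and expose non-destructive snapshots from it; `PeekableQueue` is illustrative and not part of the PR.

```python
# Sketch: a peekable message buffer that stays on the public API
# surface instead of reading asyncio.Queue._queue.
import asyncio
from collections import deque


class PeekableQueue:
    def __init__(self) -> None:
        self._items: deque = deque()
        self._event = asyncio.Event()

    def put_nowait(self, item) -> None:
        self._items.append(item)
        self._event.set()

    def snapshot(self) -> list:
        """Non-destructive view of pending items, no drain/requeue."""
        return list(self._items)

    async def get(self):
        while not self._items:
            self._event.clear()
            await self._event.wait()
        return self._items.popleft()


async def demo() -> list:
    q = PeekableQueue()
    q.put_nowait("pause")
    q.put_nowait("resume")
    return q.snapshot()


print(asyncio.run(demo()))  # ['pause', 'resume']
```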
```python
state_size = len(
    json.dumps(state_snapshot, cls=PydanticJSONEncoder).encode("utf-8")
)
```
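The review summary flags the overhead of serializing the whole snapshot just to measure its size. A hedged sketch of one mitigation: serialize once and reuse both the bytes (for storage) and their length (for the limit check). The function name and 1 MiB limit below are illustrative assumptions, not the PR's actual values.

```python
# Sketch: single serialization pass shared between size validation
# and persistence, instead of a dumps() call used only for len().
import json

MAX_STATE_BYTES = 1 * 1024 * 1024  # illustrative 1 MiB cap


def serialize_with_limit(snapshot: dict, limit: int = MAX_STATE_BYTES) -> bytes:
    payload = json.dumps(snapshot).encode("utf-8")  # serialize exactly once
    if len(payload) > limit:
        raise ValueError(f"state too large: {len(payload)} > {limit} bytes")
    return payload  # caller persists these bytes directly


data = serialize_with_limit({"step": 3, "notes": "ok"})
```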
```python
delta = {}
for _k, _v in state.data.items():
  if (
      not _k.startswith("_")
      and _k not in _GRAPH_INTERNAL_KEYS
      and ctx.session.state.get(_k) != _v
  ):
    delta[_k] = _v
```
The delta calculation for state_delta events relies on a direct equality check (ctx.session.state.get(_k) != _v). This may be inefficient for large nested dictionaries or lists, and might not correctly detect changes in mutable objects if they are modified in-place. Consider using a more robust deep comparison or ensuring that state updates always use new object instances.
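The deep-copied-baseline approach this comment points toward can be sketched as follows: snapshot the state with `copy.deepcopy` before the node runs, then diff against that snapshot, so in-place mutation of nested objects is still detected. `compute_delta` is a hypothetical helper, not the PR's implementation.

```python
# Sketch: detect in-place mutations by diffing against a deep-copied
# baseline taken before the node executed.
import copy


def compute_delta(baseline: dict, current: dict,
                  internal_keys: frozenset = frozenset()) -> dict:
    delta = {}
    for k, v in current.items():
        if k.startswith("_") or k in internal_keys:
            continue  # skip private/internal keys, as the diff above does
        if k not in baseline or baseline[k] != v:
            delta[k] = v
    return delta


state = {"results": [1, 2], "_tmp": 0}
baseline = copy.deepcopy(state)   # snapshot before mutation
state["results"].append(3)        # in-place change to a nested list
print(compute_delta(baseline, state))  # {'results': [1, 2, 3]}
```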
Please ensure you have read the contribution guide before creating a pull request.
Link to Issue or Description of Change
1. Link to an existing issue (if applicable):
2. Or, if no issue exists, describe the change:
Problem:
Long-running agent workflows need state snapshots for recovery, debugging, and audit trails. There is no built-in checkpoint mechanism that composes existing ADK session/artifact services.
Solution:

Add stateless `CheckpointService` using `SessionService` and `ArtifactService` primitives with delta compression, concurrent session locking, and configurable retention. Includes `CheckpointableMixin` for any agent type, `GraphCheckpointCallback` for node-level checkpointing, and comprehensive error model (`CheckpointNotFoundError`, `CheckpointCorruptedError`, `DeltaChainBrokenError`) with telemetry.

What's included:

- `src/google/adk/checkpoints/` — `checkpoint_service.py`, `models.py`, `mixins.py`, `callback.py`, `utils.py`, `__init__.py`
- `src/google/adk/agents/graph/checkpoint_callback.py` — `GraphCheckpointCallback`
- `graph/__init__.py` with checkpoint exports
- `test_graph_agent.py` with final test additions
- `test_interrupt_integration.py` with checkpoint+interrupt integration tests
- `test_checkpoint_service.py`, `test_checkpoint_coverage.py`, `test_checkpoint_delta_chain.py`, `test_checkpoint_locks.py`, `test_checkpoint_mixin.py`, `test_checkpoint_utils.py`, `test_callback.py`
- Samples: `graph_agent_advanced`, `graph_agent_agent_driven_checkpoint`, `graph_agent_agent_driven_topology`, `graph_agent_dynamic_topology`, `graph_agent_hitl`, `graph_agent_parallel_features`, `graph_agent_todo_queue`, examples 04/12

Part 5 of 5 — see tracking issue #4581. Stacked on #4585.
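To illustrate the delta-compression idea described above (a full snapshot followed by deltas, with a broken-chain error when a link is missing), here is a minimal sketch. The class and method names are hypothetical simplifications, not the PR's actual `CheckpointService` API.

```python
# Sketch: delta-chain checkpointing. Restore replays deltas onto the
# base snapshot; a missing (pruned) delta breaks the chain, analogous
# to the PR's DeltaChainBrokenError.
from typing import List, Optional


class DeltaChainBroken(Exception):
    """Raised when a delta required for restore is missing."""


class DeltaChain:
    def __init__(self, base: dict) -> None:
        self.base = dict(base)
        self.deltas: List[Optional[dict]] = []

    def checkpoint(self, delta: dict) -> int:
        """Record a delta; returns a 1-based checkpoint id."""
        self.deltas.append(dict(delta))
        return len(self.deltas)

    def restore(self, upto: int) -> dict:
        """Rebuild state by applying deltas 1..upto onto the base."""
        state = dict(self.base)
        for i, d in enumerate(self.deltas[:upto], start=1):
            if d is None:  # a pruned delta breaks the chain
                raise DeltaChainBroken(f"delta {i} missing")
            state.update(d)
        return state


chain = DeltaChain({"step": 0})
chain.checkpoint({"step": 1})
cid = chain.checkpoint({"step": 2, "done": True})
print(chain.restore(cid))  # {'step': 2, 'done': True}
```

Retention then becomes a question of which deltas may be pruned without severing the chain below the oldest checkpoint callers still restore.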
Testing Plan
Unit Tests:
Manual End-to-End (E2E) Tests:
7 checkpoint sample agents import and instantiate successfully.
Checklist
Additional context
Part 5 of 5 (final). Depends on all prior PRs: #4582, #4583, #4584, #4585. Core `CheckpointService` is agent-agnostic; only `GraphCheckpointCallback` depends on GraphAgent.

Total across all 5 PRs: ~727 tests, 26 samples, 6 design docs.