feat: Add parallel node groups with join strategies#4584
drahnreb wants to merge 3 commits into google:main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Summary of Changes

Hello @drahnreb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the GraphAgent's capabilities by introducing parallel execution of nodes and formalizing several advanced workflow patterns. It addresses the limitation of purely sequential execution, allowing for more complex and efficient agentic designs such as map-reduce, competitive search, and hierarchical decision-making. The changes provide developers with robust tools for building highly flexible and scalable multi-agent systems.

Highlights
Activity
Code Review
This pull request introduces parallel node execution to the GraphAgent framework, supporting various join strategies (WAIT_ALL, WAIT_ANY, WAIT_N) and error policies. It also adds first-class pattern APIs like DynamicNode, NestedGraphNode, and DynamicParallelGroup to simplify common agentic workflows. The implementation includes comprehensive telemetry integration and a robust set of tests and samples. My feedback focuses on ensuring deterministic state merging in parallel groups and fixing a potential issue with parallel group re-execution in cyclic graphs.
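The three join strategies map naturally onto asyncio's waiting primitives. Below is a minimal sketch of that mapping, assuming nothing about ADK internals — `JoinStrategy`, `join`, and `branch` here are illustrative stand-ins, not the PR's API:

```python
import asyncio
import enum


class JoinStrategy(enum.Enum):
  WAIT_ALL = "wait_all"  # join when every branch finishes
  WAIT_ANY = "wait_any"  # join when the first branch finishes
  WAIT_N = "wait_n"      # join when n branches finish


async def join(coros, strategy, n=1):
  """Run branches concurrently and join per the chosen strategy."""
  tasks = [asyncio.ensure_future(c) for c in coros]
  if strategy is JoinStrategy.WAIT_ALL:
    needed = len(tasks)
  elif strategy is JoinStrategy.WAIT_ANY:
    needed = 1
  else:  # WAIT_N
    needed = n
  done, pending = set(), set(tasks)
  while len(done) < needed and pending:
    finished, pending = await asyncio.wait(
        pending, return_when=asyncio.FIRST_COMPLETED)
    done |= finished
  for task in pending:  # branches no longer needed under ANY / N
    task.cancel()
  return sorted(t.result() for t in done)


async def branch(name, delay):
  await asyncio.sleep(delay)
  return name


async def demo():
  return await join(
      [branch("a", 0.01), branch("b", 0.02), branch("c", 0.05)],
      JoinStrategy.WAIT_N, n=2)


assert asyncio.run(demo()) == ["a", "b"]
```

The error policies then decide what happens when a branch raises before its join condition is met.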
f"Executing parallel group '{group_id}' with nodes:"
f" {parallel_group.nodes}"
)
The executed_parallel_groups check prevents a parallel group from re-executing if the graph loops back to it. Since executed_parallel_groups is persisted in agent_state, it will remain populated across iterations. This check should be iteration-scoped to allow parallel groups to run again in cyclic workflows.
# Check if current node is part of a parallel group
parallel_group_info = self._find_parallel_group(current_node_name)
if parallel_group_info:
  group_id, parallel_group = parallel_group_info
  # Track group execution per iteration to support loops
  group_iteration_key = f"{group_id}_{iteration}"
  if group_iteration_key in executed_parallel_groups:
Fixed: executed_parallel_groups.clear() on cycle revisit.
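The fix can be sketched independently of ADK: keying the executed-group set by iteration de-duplicates within one pass while still letting a cyclic graph re-enter the group on the next pass (`executed_parallel_groups` and `should_run` here are illustrative stand-ins):

```python
# Hypothetical stand-in for the PR's executed_parallel_groups bookkeeping.
executed_parallel_groups = set()


def should_run(group_id, iteration):
  """Return True the first time a group is seen in a given iteration."""
  key = f"{group_id}_{iteration}"
  if key in executed_parallel_groups:
    return False
  executed_parallel_groups.add(key)
  return True


assert should_run("fanout", 0) is True
assert should_run("fanout", 0) is False  # duplicate within the iteration
assert should_run("fanout", 1) is True   # cycle revisits: runs again
```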
branch_state = result["state"]

# Merge data keys with conflict detection
for key, value in branch_state.data.items():
The current state merging logic iterates over results.items(), which is ordered by completion time. This makes the "last write wins" behavior non-deterministic if multiple parallel branches modify the same state key. To ensure deterministic behavior, state should be merged in the order defined in group.nodes.
for node_name in group.nodes:
  if node_name not in results:
    continue
  result = results[node_name]
  branch_state = result["state"]
Fixed: Merge order now uses group.nodes (definition order) instead of completion order.
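The difference is easy to demonstrate in isolation. A small sketch, assuming `results` maps node name to that branch's state delta (`merge_states` and the key names are illustrative, not ADK API):

```python
def merge_states(group_nodes, results):
  """Merge branch deltas in definition order so conflicts resolve
  deterministically: the last-defined node wins, not the last-finished."""
  merged = {}
  for node_name in group_nodes:  # definition order from group.nodes
    if node_name not in results:
      continue  # a branch may be absent under WAIT_ANY / WAIT_N
    merged.update(results[node_name])
  return merged


# Dict insertion order stands in for completion order: "b" finished first.
results = {"b": {"k": "from_b"}, "a": {"k": "from_a"}}
assert merge_states(["a", "b"], results) == {"k": "from_b"}
assert merge_states(["b", "a"], results) == {"k": "from_a"}
```

Either ordering is defensible; what matters is that reruns with different completion timings always produce the same merged state.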
Force-pushed 5f65d99 to b42592f
Addressing review feedback

Force-pushed with the following fixes:

Critical:
High:
Medium:
Force-pushed 0e56742 to 329f7e6
Code Review
This pull request introduces parallel node execution to the GraphAgent with support for various join strategies (WAIT_ALL, WAIT_ANY, WAIT_N) and error policies. The implementation follows the ParallelAgent pattern, ensuring state isolation via deep copies and providing deterministic state merging with conflict detection. The integration into the GraphAgent run loop correctly handles cyclic workflows by resetting parallel group tracking. I have identified a few critical issues regarding observability consistency and potential race conditions in shared session state that should be addressed to ensure the robustness of the parallel execution model.
# Create generator for each node
node = nodes[node_name]
node_generators[node_name] = execute_node_fn(node, branch_state, ctx)
While branch_state (the GraphState object) is isolated via deepcopy, the InvocationContext (ctx) is shared across all parallel branches. Since ADK agents frequently interact with ctx.session.state, concurrent writes to the underlying session state by multiple branches will lead to race conditions and non-deterministic behavior. The context or at least the session state should be isolated per branch.
Fixed: Added branch_ctx = ctx.model_copy() per parallel branch.
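A sketch of what that isolation buys, using a stand-in pydantic model in place of `InvocationContext` (`Ctx` and `session_state` are assumptions, not ADK types). One caveat worth noting: a plain `model_copy()` is shallow, so nested mutables such as a session-state dict may still be shared unless `deep=True` is used or the state is copied separately:

```python
from pydantic import BaseModel, Field


class Ctx(BaseModel):
  """Stand-in for InvocationContext; session_state is an assumption."""
  session_state: dict = Field(default_factory=dict)


ctx = Ctx(session_state={"shared": 1})
# deep=True detaches nested mutables; a shallow model_copy() would still
# share the session_state dict between branches.
branch_ctx = ctx.model_copy(deep=True)
branch_ctx.session_state["branch_only"] = 2

assert "branch_only" not in ctx.session_state  # original untouched
assert branch_ctx.session_state == {"shared": 1, "branch_only": 2}
```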
Fixed — implemented diff-based state merge. Now captures original_data = deepcopy(state.data) before merge and skips keys unchanged from original, preventing stale branch copies from overwriting other branches' modifications. Added test_parallel_unchanged_keys_not_overwritten test.
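The diff-based merge can be sketched with plain dicts (`merge_branch` and the key names are illustrative): each branch forked from `original`, and only keys a branch actually changed are written back, so a branch's stale copy of an untouched key never clobbers another branch's update:

```python
import copy


def merge_branch(shared, original, branch):
  """Write back only the keys this branch changed relative to the fork."""
  for key, value in branch.items():
    if key in original and original[key] == value:
      continue  # unchanged since the fork: skip the stale copy
    shared[key] = value


original = {"x": 1, "y": 1}
shared = copy.deepcopy(original)
branch_a = {"x": 99, "y": 1}   # changed x, left y alone
branch_b = {"x": 1, "y": 42}   # changed y, left x alone

merge_branch(shared, original, branch_a)
merge_branch(shared, original, branch_b)
assert shared == {"x": 99, "y": 42}  # neither branch clobbered the other
```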
# Invoke after_node_callback (custom observability)
if self.after_node_callback:
  event = await self._execute_callback(
      self.after_node_callback, "after_node", current_node,
      current_node_name, state, iteration, ctx, agent_state,
      effective_config, output=output)
  if event:
    yield event
Fixed: after_node_callback now fires before continue for parallel trigger nodes.
if self.before_node_callback:
  event = await self._execute_callback(
      self.before_node_callback, "before_node", current_node,
      current_node_name, state, iteration, ctx, agent_state,
      effective_config)
  if event:
    yield event
The before_node_callback is invoked before checking if a node should be skipped as part of an already executed parallel group. This results in the callback firing for nodes that aren't actually being executed individually, which is inconsistent with the after_node_callback behavior (which is skipped).
Fixed: before_node_callback moved after parallel group skip check.
# Handle collected errors
if errors and group.error_policy == ErrorPolicy.COLLECT:
  error_msg = f"Errors in parallel execution: {errors}"
Raising a generic Exception is discouraged as it makes it difficult for callers to catch and handle specific error conditions. Please use a more specific exception type like RuntimeError or a custom ParallelExecutionError.
Suggested change:

error_msg = f"Errors in parallel execution: {errors}"
raise RuntimeError(error_msg)
Fixed: Changed to raise RuntimeError(error_msg).
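A minimal sketch of the COLLECT policy with `asyncio.gather` (`run_group` is illustrative, not the PR's implementation): every branch runs to completion, and failures are surfaced once as a single `RuntimeError` at the end:

```python
import asyncio


async def run_group(branches):
  """COLLECT policy: let every branch finish, then raise once if any failed."""
  outcomes = await asyncio.gather(*branches, return_exceptions=True)
  errors = [o for o in outcomes if isinstance(o, Exception)]
  if errors:
    raise RuntimeError(f"Errors in parallel execution: {errors}")
  return outcomes


async def ok():
  return "ok"


async def boom():
  raise ValueError("boom")


try:
  asyncio.run(run_group([ok(), boom()]))
  raised = False
except RuntimeError:
  raised = True
assert raised
assert asyncio.run(run_group([ok(), ok()])) == ["ok", "ok"]
```

By contrast, FAIL_FAST would propagate the first exception immediately (plain `gather` without `return_exceptions`), and CONTINUE would drop failed branches and merge the rest.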
Force-pushed 053e9e5 to fd2ba4d
@gemini-code-assist please re-review
Code Review
This pull request introduces a significant and well-designed feature: parallel execution capabilities for GraphAgent. The implementation is comprehensive, covering various join strategies, error policies, and advanced patterns like dynamic and nested nodes. The code is well-structured with a clear separation of concerns. The extensive set of examples and tests is particularly commendable. My review focuses on minor improvements in the example files to enhance robustness, clarity, and consistency, as the core logic appears very solid.
def is_valid_json(state) -> bool:
  """Check if JSON is valid from structured output."""
  # Convention: Access via agent.name (auto-defaulted as output_key)
  result = state.data.get("validator", {})
  return result.get("valid", False) is True
For improved robustness and type safety, it's better to use state.get_parsed() to safely access structured data from the agent's output. This method handles cases where the output might not be a valid dictionary, preventing potential AttributeError exceptions.
To use this, you'll need to add from google.adk.agents.graph import GraphState to your imports and add the type hint to the function signature.
Suggested change:

def is_valid_json(state: GraphState) -> bool:
  """Check if JSON is valid from structured output."""
  # Use get_parsed for safe parsing of structured output.
  result = state.get_parsed("validator", ValidationResult)
  return result.valid if result else False
Fixed — now uses state.get_parsed("validator", ValidationResult) for type-safe access.
Fixed — narrowed to except ValidationError: with explicit import.
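A sketch of why get_parsed-style access is more robust than raw dict lookups, using a stand-in helper (`get_parsed` here is a hypothetical re-implementation over a plain dict, not ADK's method; `ValidationResult` mirrors the example above): invalid or non-dict payloads yield `None` instead of raising `AttributeError`:

```python
from pydantic import BaseModel, ValidationError


class ValidationResult(BaseModel):
  valid: bool


def get_parsed(data, key, model):
  """Hypothetical get_parsed-style helper: validated access or None."""
  raw = data.get(key)
  if not isinstance(raw, dict):
    return None
  try:
    return model.model_validate(raw)
  except ValidationError:
    return None


good = get_parsed({"validator": {"valid": True}}, "validator", ValidationResult)
assert good is not None and good.valid is True
# A raw `raw.get("valid")` would raise AttributeError on a string payload;
# the typed helper returns None instead.
assert get_parsed({"validator": "garbage"}, "validator", ValidationResult) is None
```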
def _is_approved(state: GraphState) -> bool:
  """Check if critic approved using structured output."""
  review = state.get_parsed("critic", ReviewResult)
  return review.decision.lower() == "approve" if review else False
Fixed — removed unused _is_approved function.
def _is_complete(state: GraphState) -> bool:
  """Check if ReAct loop is complete using structured output."""
  obs = state.get_parsed("observer", ObservationResult)
  return obs.status.lower() == "complete" if obs else False
Fixed — removed unused _is_complete function.
if not final_counter:
  graph_data_raw = session.state.get("graph_data")
  if graph_data_raw:
    try:
      data = (
          json.loads(graph_data_raw)
          if isinstance(graph_data_raw, str)
          else graph_data_raw
      )
      final_counter = data.get("counter", 0)
    except (json.JSONDecodeError, TypeError):
      final_counter = 0
The check if not final_counter: is a bit brittle because it would incorrectly trigger if the loop could validly finish with counter=0. A more robust approach is to check if the key is missing from the session state before attempting to re-parse from graph_data.
Suggested change:

final_counter = session.state.get("counter")
if final_counter is None:
  graph_data_raw = session.state.get("graph_data")
  if graph_data_raw:
    try:
      data = (
          json.loads(graph_data_raw)
          if isinstance(graph_data_raw, str)
          else graph_data_raw
      )
      final_counter = data.get("counter", 0)
    except (json.JSONDecodeError, TypeError):
      final_counter = 0
final_counter = final_counter or 0
Fixed — changed to if final_counter is None: to correctly handle counter value of 0.
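The underlying pitfall is the classic falsy-versus-missing confusion, easy to show in a few lines (names are illustrative):

```python
session_state = {"counter": 0}  # a loop that validly finished at zero

final_counter = session_state.get("counter")
assert not final_counter          # `if not final_counter:` would re-parse
assert final_counter is not None  # `is None` correctly keeps the 0

assert session_state.get("missing") is None  # only absent keys fall back
```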
"""

import asyncio


async def main():
  print("\n" + "=" * 60)
  print("Example 3: Enhanced Routing")
Fixed — corrected 'Example 3' to 'Example 15'.
Add GraphAgent for building directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, telemetry with OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes samples and design documentation.
Add DynamicNode (runtime agent selection), NestedGraphNode (hierarchical workflow composition), and DynamicParallelGroup (variable-count concurrent execution). Extends CLI visualization with pattern-aware rendering (diamond, parallelogram, sub-cluster shapes). Includes pattern samples, node type reference, and design documentation.
Force-pushed fd2ba4d to 20d6a1f
Please ensure you have read the contribution guide before creating a pull request.
Link to Issue or Description of Change
1. Link to an existing issue (if applicable):
2. Or, if no issue exists, describe the change:
Problem:
GraphAgent executes nodes sequentially by default. Parallel execution is needed for patterns like map-reduce, competitive search, and consensus-based decisions where multiple agents run concurrently and results are merged.
Solution:
Add `ParallelNodeGroup` for concurrent node execution within GraphAgent, with WAIT_ALL, WAIT_ANY, and WAIT_N join strategies and configurable error policies (FAIL_FAST, CONTINUE, COLLECT). Integrates parallel group execution into GraphAgent's run loop with automatic telemetry spans per parallel group.

What's included:
- New `src/google/adk/agents/graph/parallel.py` module
- `graph_agent.py` updated with parallel group execution, `add_parallel_group`, and `_find_parallel_group`
- `graph/__init__.py` updated with parallel exports
- Tests: `test_graph_parallel.py`, `test_parallel_execution.py`, `test_parallel_state_merge.py`, `test_parallel_task_lookup.py`, `test_parallel_unit.py`
- Samples (`graph_agent_multi_agent`, examples 09-11/14)

Part 3 of 5 — see tracking issue #4581. Stacked on #4583.
Testing Plan
Unit Tests:
Manual End-to-End (E2E) Tests:
4 parallel sample agents import and instantiate successfully.
Checklist
Additional context
Part 3 of 5. Depends on #4582 (Core GraphAgent) and #4583 (Graph patterns).