feat: Add GraphAgent for directed-graph workflow orchestration #4582
drahnreb wants to merge 1 commit into google:main
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Response from ADK Triaging Agent
Hello @drahnreb, thank you for your contribution! Before we can merge this pull request, you'll need to sign a Contributor License Agreement (CLA). You can do so by following the instructions at https://cla.developers.google.com/. For more information, please see the contribution guidelines. Thanks!
Code Review
The pull request introduces GraphAgent for directed-graph workflow orchestration, significantly enhancing ADK's capabilities by enabling conditional routing, cyclic execution, and advanced state management. The changes include core agent implementation, configuration schemas, telemetry integration, and a comprehensive set of examples and design documentation. The new GraphAgent addresses limitations of existing sequential, parallel, and loop agents by allowing dynamic decision-making based on runtime state. The design document is well-structured and clearly explains the motivation, use cases, architecture, and capabilities of GraphAgent. The examples cover various features, including basic workflows, conditional routing, cyclic execution, callbacks, and rewind integration, demonstrating the flexibility and power of the new agent. The addition of telemetry and evaluation metrics further strengthens the observability and testability of graph-based workflows. Overall, this is a substantial and well-thought-out addition to the ADK framework.
    new_state.data[self.name] = []
  new_state.data[self.name].append(output)
elif self.reducer == StateReducer.SUM:
  new_state.data[self.name] = new_state.data.get(self.name, 0) + output
The new_state.data.get(self.name, 0) + output operation for StateReducer.SUM assumes output is a numeric type. However, agent outputs are typically strings. If output is a string, this will result in a TypeError. Consider adding a type conversion (e.g., int(output) or float(output)) or explicitly documenting that output_mapper should handle type conversion before reaching the reducer.
Fixed. SUM reducer now infers zero-value defaults via type(output)() and raises TypeError on actual type mismatches.
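A minimal sketch of the behavior this fix describes, for readers following along. `sum_reduce` is a hypothetical helper, not the PR's code; it assumes only what the reply states (zero value inferred via `type(output)()`, `TypeError` on a real mismatch). The early return for `None` is an extra precaution added here, not something the source confirms.

```python
from typing import Any


def sum_reduce(existing: Any, output: Any) -> Any:
  """Hypothetical SUM reducer: adds output to existing, inferring a zero value."""
  if output is None:
    # Assumption: a node that returns nothing leaves the accumulated value alone.
    return existing
  if existing is None:
    # Infer the zero value from the output type: 0 for int, "" for str, [] for list.
    existing = type(output)()
  try:
    return existing + output
  except TypeError as exc:
    # Surface a clearer message for genuine mismatches such as int + str.
    raise TypeError(
        f"SUM reducer cannot combine {type(existing).__name__} "
        f"with {type(output).__name__}"
    ) from exc
```

With this shape, string outputs concatenate and numeric outputs add, while mixing the two fails loudly instead of silently coercing.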
from .base_agent import BaseAgent
from .context import Context
from .graph import END
Fixed: Context is imported and exported in __init__.py.
# Warn if agent name shadows the graph itself (find_agent would
# return the graph, not the sub_agent). Allow it so existing code
# is not broken, but log a warning.
if agent.name == self.name:
Fixed: Now raises ValueError at graph build time (add_node()) if agent name collides with GraphAgent name.
graph.add_node("task_dispatcher", function=dynamic_task_dispatcher)

# Loop back to dispatcher while tasks remain.
# Check task_queue directly (mutated in-place by the function node).
Fixed: Comment updated to 'updated via output_mapper return value'.
[150ms] ✅ Fetched data from products_db (100ms)
[150ms] ✅ Fetched data from users_db (150ms)
[200ms] ✅ Fetched data from orders_db (200ms)

Total: ~200ms (vs 450ms sequential)
Speedup: ~2.25x
```
Acknowledged. These are sample outputs for documentation context.
Returns:
  New graph state with output merged
"""
new_state = GraphState(data=state.data.copy())
The new_state = GraphState(data=state.data.copy()) line performs a shallow copy of the data dictionary. If state.data contains mutable objects (like nested dictionaries or lists), modifying them in new_state will also affect the original state. This can lead to unexpected side effects. Consider using copy.deepcopy(state.data) if nested mutable objects are expected and true isolation is required.
Fixed: Changed from shallow copy to deepcopy(state.data) in _default_output_mapper for proper nested state isolation.
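The difference between the two copy strategies can be reproduced with plain dictionaries. The sketch below assumes only that `GraphState.data` behaves like an ordinary `dict`; the variable names are illustrative.

```python
import copy

# Shallow copy: the outer dict is new, but the nested dict is shared.
original_a = {"nested": {"key": "value"}}
shallow = original_a.copy()
shallow["nested"]["key"] = "modified"
shallow_leaked = original_a["nested"]["key"] == "modified"

# Deep copy: nested containers are duplicated, so the original is untouched.
original_b = {"nested": {"key": "value"}}
deep = copy.deepcopy(original_b)
deep["nested"]["key"] = "modified"
deep_isolated = original_b["nested"]["key"] == "value"
```

Both flags end up `True`: the shallow copy leaks the mutation back into the original, and the deep copy does not.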
- **Core implementation**: ~1,231 lines
- **Tests**: 82 tests, 2,665 lines
# Check if weights are all the same (default behavior: first match)
weights = [e.weight for _, e in matching_edges]
all_same_weight = len(set(weights)) == 1
The all_same_weight check is good, but if all weights are 0, total_weight will be 0, which makes the weighted draw random.random() * total_weight degenerate (it is always 0). The current code handles total_weight == 0 by picking the first matching edge, which is a reasonable fallback. However, it might be clearer to explicitly check for total_weight == 0 earlier or add a comment explaining this edge case.
Fixed: Added guard for total_weight == 0 with clarifying comment.
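One way such a guard might look, as a self-contained sketch. `pick_weighted` and its parameters are hypothetical and operate on a bare list of non-negative weights rather than the PR's edge objects; the "first match" fallback mirrors the behavior described in this thread.

```python
import random
from typing import Optional, Sequence


def pick_weighted(
    weights: Sequence[float], rng: Optional[random.Random] = None
) -> int:
  """Returns an index chosen in proportion to its (non-negative) weight."""
  if not weights:
    raise ValueError("weights must not be empty")
  rng = rng or random.Random()
  total_weight = sum(weights)
  # Equal weights (including all zeros) or a zero total: fall back to first match.
  if len(set(weights)) == 1 or total_weight == 0:
    return 0
  threshold = rng.random() * total_weight
  cumulative = 0.0
  for index, weight in enumerate(weights):
    cumulative += weight
    if threshold < cumulative:
      return index
  return len(weights) - 1  # Guard against floating-point rounding.
```

Checking the degenerate cases before drawing keeps the random path free of special cases.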
def test_validate_ast_keyword_args():
  """Line 149: Keyword arguments in safe method calls are validated."""
  from google.adk.agents.graph.graph_agent import _parse_condition_string

  # .get() with keyword arg: should pass validation
  fn = _parse_condition_string("data.get(key='x')")
  state = GraphState(data={"x": "val"})
  # Python's dict.get() doesn't accept 'key' kwarg, so it will raise at eval
  # but the AST validation itself should succeed
  assert fn(state) is False  # eval error → returns False
The test test_validate_ast_keyword_args correctly identifies that keyword arguments to data.get() are not directly supported by Python's dict.get(). While the AST validation passes, the runtime evaluation will fail. This highlights a potential area for improvement in either the validation logic (to disallow unsupported kwargs) or the GraphState's get methods (to handle kwargs gracefully).
Acknowledged. The test documents that keyword args are allowed in AST validation for dict.get(key, default) patterns.
"""Document shallow copy limitation with nested dicts.

This test documents that GraphNode._default_output_mapper uses .copy()
which is a shallow copy. For nested structures, this can cause issues.
"""
state1 = GraphState(data={"nested": {"key": "value"}})

# Shallow copy (what GraphNode does)
state2 = GraphState(data=state1.data.copy())

# Modify nested structure in state2
state2.data["nested"]["key"] = "modified"

# BUG: state1 is also modified (shared reference)
# This is a known limitation of shallow copy
assert state1.data["nested"]["key"] == "modified"  # Unintended side effect

# NOTE: parallel.py uses deepcopy to avoid this issue
# For regular sequential execution, users should avoid nested mutations
The test test_state_nested_dict_shallow_copy_limitation is crucial for documenting the shallow copy limitation with nested mutable objects. It clearly demonstrates that modifying a nested dictionary in a copied state will unintentionally affect the original state. This highlights a potential source of bugs and emphasizes the need for deep copying in such scenarios.
Fixed: Test updated — now tests deepcopy isolation instead of shallow copy limitation.
8c0afaf to b64a3dd
Addressing review feedback
Force-pushed with the following fixes:
Critical:
High:
Medium:
31de828 to 1414a38
Code Review
The GraphAgent implementation provides a robust foundation for directed-graph workflows in ADK, supporting conditional routing, cyclic execution, and state management. The integration with ADK's resumability and telemetry systems is well-designed. However, there are some security concerns regarding the use of eval in condition parsing, and opportunities to improve performance in path tracking and edge evaluation. Additionally, ensuring deep state isolation between nodes would enhance reliability, especially for complex workflows.
elif isinstance(node, ast.Attribute):
  # Allow attribute access on safe names only
  _validate_condition_ast(node.value)
The AST validation for attribute access is incomplete. It currently only validates the object being accessed (node.value) but does not check the attribute name itself (node.attr). This allows access to potentially dangerous dunder attributes like __class__, which could be used to escape the sandbox via eval (e.g., state.__class__.__init__.__globals__).
Suggested change:
- elif isinstance(node, ast.Attribute):
-   # Allow attribute access on safe names only
-   _validate_condition_ast(node.value)
+ elif isinstance(node, ast.Attribute):
+   if node.attr.startswith("_"):
+     raise ValueError(f"Unsafe attribute access: '{node.attr}'")
+   _validate_condition_ast(node.value)
Fixed: Added dunder attribute blocking in _validate_condition_ast.
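The check can be exercised in isolation with the standard `ast` module. This is a sketch of that one rule only, not the PR's `_validate_condition_ast`; `validate_attributes` is a hypothetical name, and a real sandbox would need further restrictions (allowed names, allowed calls) beyond this.

```python
import ast


def validate_attributes(expression: str) -> None:
  """Raises ValueError if the expression reads an underscore-prefixed attribute."""
  tree = ast.parse(expression, mode="eval")
  for node in ast.walk(tree):
    # Block private and dunder attributes such as __class__ or __globals__.
    if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
      raise ValueError(f"Unsafe attribute access: '{node.attr}'")
```

An ordinary condition such as `state.data.get('x') == 1` passes, while `state.__class__.__init__.__globals__` is rejected before anything is evaluated.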
Returns:
  New graph state with output merged
"""
new_state = GraphState(data=state.data.copy())
The default output mapper uses a shallow copy of the state data. If the state contains nested dictionaries or lists, multiple nodes might inadvertently share and mutate the same objects, leading to unpredictable side effects and race conditions. Using a deep copy ensures proper state isolation between nodes. Since GraphState is a Pydantic model, model_copy(deep=True) is the recommended approach.
Suggested change:
- new_state = GraphState(data=state.data.copy())
+ new_state = state.model_copy(deep=True)
Fixed: Changed to deepcopy(state.data) in _default_output_mapper.
if (
    agent_state.path.count(current_node_name) > 1
    and current_node.agent
):
Using list.count() inside the execution loop results in an O(N) scan of the path on every iteration. Since agent_state.node_invocations already tracks visits per node, you can check its length for an O(1) lookup instead.
Suggested change:
- if (
-     agent_state.path.count(current_node_name) > 1
-     and current_node.agent
- ):
+ if (
+     len(agent_state.node_invocations.get(current_node_name, [])) > 1
+     and current_node.agent
+ ):
Fixed: Replaced path.count() with O(1) node_invocations lookup.
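The two checks are equivalent as long as every visit is recorded in both structures. The sketch below uses hypothetical module-level `path` and `node_invocations` containers in place of the PR's `agent_state` fields to show that equivalence.

```python
from collections import defaultdict

path: list[str] = []
node_invocations: dict[str, list[int]] = defaultdict(list)


def record_visit(node_name: str, step: int) -> None:
  """Records a visit in both the ordered path and the per-node index."""
  path.append(node_name)
  node_invocations[node_name].append(step)


def is_revisit(node_name: str) -> bool:
  """Dictionary lookup plus len(): no scan of the whole path."""
  return len(node_invocations.get(node_name, [])) > 1


for step, name in enumerate(["a", "b", "a"]):
  record_visit(name, step)
```

Here `is_revisit("a")` agrees with `path.count("a") > 1`, but its cost does not grow with the length of the path.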
raise ValueError(
    f"Node {current_node_name} has no outgoing edges and is not"
    " an end node"
effective_config is calculated at the start of _run_async_impl but not passed to _get_next_node_with_telemetry. This means edge evaluation telemetry might not respect inherited configuration from parent agents (e.g., sampling rates or additional attributes).
next_node_name = self._get_next_node_with_telemetry(
    current_node, state, effective_config=effective_config
)
Fixed: All 3 call sites now pass effective_config.
sorted_edges = sorted(indexed_edges, key=lambda x: (-x[1].priority, x[0]))
Fixed: Added _sorted_edges_cache with invalidation on add_edge().
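A cache of this kind might be structured as follows. The `Edge` and `Node` classes are simplified stand-ins rather than the PR's types; only the sort key (descending priority, then insertion order) and the invalidate-on-`add_edge` idea come from the thread.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Edge:
  target: str
  priority: int = 0


@dataclass
class Node:
  name: str
  edges: list = field(default_factory=list)
  _sorted_edges_cache: Optional[list] = field(default=None, repr=False)

  def add_edge(self, edge: Edge) -> None:
    self.edges.append(edge)
    self._sorted_edges_cache = None  # Any new edge invalidates the cache.

  def sorted_edges(self) -> list:
    if self._sorted_edges_cache is None:
      indexed = list(enumerate(self.edges))
      # Higher priority first; ties keep insertion order.
      indexed.sort(key=lambda x: (-x[1].priority, x[0]))
      self._sorted_edges_cache = [edge for _, edge in indexed]
    return self._sorted_edges_cache
```

Code that appends to `edges` directly would bypass the invalidation, which is why routing all mutations through `add_edge` matters for this pattern.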
965a398 to b39cbab
@gemini-code-assist please re-review
Code Review
The PR introduces GraphAgent, a powerful addition to the ADK workflow agents, enabling directed-graph orchestration with conditional routing, cycles, and state management. The implementation is comprehensive, including telemetry, callbacks, and visualization support. However, there are some performance concerns regarding state management (use of deepcopy and full-state iteration for deltas) and some robustness issues in the state reduction logic and configuration parsing that should be addressed to ensure scalability and reliability.
  if (
      not _k.startswith("_")
      and _k not in _GRAPH_INTERNAL_KEYS
      and ctx.session.state.get(_k) != _v
  ):
    delta[_k] = _v
if delta:
Calculating the state delta by iterating over all keys in state.data and performing a comparison with ctx.session.state after every node execution can be inefficient, especially as the state grows or contains complex objects. This O(N) operation per iteration could become a performance bottleneck. Consider tracking modified keys within GraphState to optimize this sync.
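A rough illustration of the "track modified keys" idea follows. `TrackedState` is a hypothetical class, not part of the PR; it assumes all writes go through `set()`, so in-place mutation of a nested value would go unnoticed.

```python
class TrackedState:
  """Hypothetical state wrapper that remembers which keys were written."""

  def __init__(self, data=None):
    self.data = dict(data or {})
    self._dirty = set()

  def set(self, key, value):
    # Only mark the key dirty when the stored value actually changes.
    if key not in self.data or self.data[key] != value:
      self.data[key] = value
      self._dirty.add(key)

  def pop_delta(self):
    """Returns changed public keys and resets the tracking set."""
    delta = {
        k: self.data[k] for k in self._dirty if not k.startswith("_")
    }
    self._dirty.clear()
    return delta
```

The sync step then touches only the keys written since the last call, rather than comparing every key in the state.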
    name=node_config.name,
    agent=sub_agents[0] if sub_agents else None,
    function=function,
)
graph.add_node(node)
The from_config method only uses the first agent from the sub_agents list in the node configuration. If a user provides multiple sub-agents for a single node in the YAML, the others will be silently ignored. It would be better to validate that only one sub-agent is provided or update the logic to handle multiple agents if that was the intent.
References
- Ensure that configuration parsing handles all provided data or explicitly validates against unsupported inputs to prevent silent failures.
if edge_config.source_node in graph.nodes:
  graph.nodes[edge_config.source_node].edges.append(edge)
  graph.nodes[edge_config.source_node]._sorted_edges_cache = None
Appending the edge directly to the node's edges list bypasses the validation logic and duplicate edge checks implemented in the add_edge method. It is safer to use the public API for graph construction.
References
- Prefer using public APIs that encapsulate validation logic over direct member manipulation.
Returns:
  New graph state with output merged
"""
new_state = GraphState(data=deepcopy(state.data))
Using deepcopy on the entire state data for every node execution can be very expensive in terms of memory and CPU, especially for large states or graphs with many iterations. Consider using a shallow copy (state.data.copy()) and ensuring that reducers (like APPEND) create new collection instances instead of mutating them in place to maintain isolation.
References
- Avoid expensive operations like deepcopy in tight loops or frequent execution paths if shallow copies or immutable patterns can achieve the same goal.
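The alternative the reviewer describes, a shallow copy combined with a reducer that builds a new collection, could look like this. `append_reduce` is a hypothetical function over plain dictionaries, not the PR's reducer API.

```python
def append_reduce(data: dict, key: str, output) -> dict:
  """Returns a new dict whose list at `key` has `output` appended."""
  new_data = data.copy()  # Shallow copy: cost scales with the number of keys.
  # Build a fresh list instead of calling .append() on the shared one.
  new_data[key] = [*data.get(key, []), output]
  return new_data


old = {"log": [1], "cfg": {"mode": "fast"}}
new = append_reduce(old, "log", 2)
```

The reduced key is isolated (`old["log"]` stays `[1]`), while untouched values such as `cfg` remain shared between the two states. That sharing is the trade-off: it is cheap, but safe only if nodes never mutate nested values in place.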
existing = new_state.data.get(self.name)
if existing is None:
  # Infer zero-value from output type: "" for str, 0 for int/float, [] for list
  existing = type(output)()
Using type(output)() to infer a zero-value for the SUM reducer is risky. It will raise a TypeError if output is None (which can happen with function nodes) or if it's a type without a default constructor (like some Pydantic models with required fields). Consider adding explicit handling for common types or a null check.
References
- Ensure that dynamic type instantiation handles edge cases like NoneType or types without default constructors.
Add GraphAgent for building directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, telemetry with OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes samples and design documentation.
b39cbab to ef2722d
@gemini-code-assist please re-review
Code Review
This pull request introduces the GraphAgent, a powerful new feature for orchestrating complex, non-linear workflows. The implementation is comprehensive, covering core graph logic, state management, conditional routing, resumability, and telemetry. The addition of numerous well-documented examples and tests is commendable and greatly aids in understanding the new capabilities.
My review focuses on a few minor improvements in the example code for robustness and clarity. The core implementation of GraphAgent and its related components appears solid and well-designed.
- In graph_agent_dynamic_queue/agent.py and graph_agent_react_pattern/agent.py, the logic for retrieving the final session state could be made more robust to avoid using stale data if the session fetch fails.
- In graph_examples/03_cyclic_execution/agent.py, the code to determine the final counter value is a bit complex and could be simplified.
- In graph_examples/15_enhanced_routing/agent.py, the output mapper for storing the score could be improved to parse the value from the agent's output instead of being hardcoded, making the LLM mode example more realistic.
Overall, this is an excellent contribution that significantly enhances the ADK's workflow orchestration capabilities.
final_session = fresh_session or session
final_data = final_session.state.get("graph_data", {})
final_state = GraphState(data=final_data) if final_data else GraphState()
The fallback to the original session object in final_session = fresh_session or session could lead to incorrect final statistics being printed. The session object holds a stale copy of the state from before the runner.run_async call, as noted in the comment on line 245. If session_service.get_session were to fail and return None, this would fall back to the stale data, silently printing wrong results.
A more robust approach would be to handle the None case explicitly.
Suggested change:
- final_session = fresh_session or session
- final_data = final_session.state.get("graph_data", {})
- final_state = GraphState(data=final_data) if final_data else GraphState()
+ if not fresh_session:
+   print("Error: Could not fetch final session state.")
+   return
+ final_data = fresh_session.state.get("graph_data", {})
+ final_state = GraphState(data=final_data) if final_data else GraphState()
final_data = (fresh_session or session).state.get("graph_data", {})
final_state = GraphState(data=final_data)
The fallback to the original session object in (fresh_session or session) could lead to incorrect final state being displayed if get_session fails and returns None. The original session object contains stale data from before the graph execution.
It would be more robust to handle the possibility of fresh_session being None explicitly.
Suggested change:
- final_data = (fresh_session or session).state.get("graph_data", {})
- final_state = GraphState(data=final_data)
+ if not fresh_session:
+   print("\nError: Could not fetch final session state.")
+   return
+ final_data = fresh_session.state.get("graph_data", {})
+ final_state = GraphState(data=final_data)
final_counter = session.state.get("counter")
if final_counter is None:
  graph_data_raw = session.state.get("graph_data")
  if graph_data_raw:
    try:
      data = (
          json.loads(graph_data_raw)
          if isinstance(graph_data_raw, str)
          else graph_data_raw
      )
      final_counter = data.get("counter", 0)
    except (json.JSONDecodeError, TypeError):
      final_counter = 0

if final_counter is None:
  final_counter = 0
The logic to retrieve the final_counter is a bit complex and contains some redundancy, making it hard to follow. It handles multiple cases (direct state, graph_data as string, graph_data as dict) but can be simplified for better readability and maintainability.
Consider refactoring this block to be more direct. For example, you could default final_counter to 0 and then update it if found in either session.state or the parsed graph_data.
final_counter = session.state.get("counter")
if final_counter is None:
  final_counter = 0  # Default value
  graph_data_raw = session.state.get("graph_data")
  if isinstance(graph_data_raw, str):
    try:
      graph_data_raw = json.loads(graph_data_raw)
    except json.JSONDecodeError:
      graph_data_raw = {}
  if isinstance(graph_data_raw, dict):
    final_counter = graph_data_raw.get("counter", 0)

def store_score(output, state):
  new_state = GraphState(data=state.data.copy())
  new_state.data["risk_score"] = 0.85  # Score from analyze agent
  return new_state
The store_score output mapper currently hardcodes the risk_score to 0.85. While this is consistent with the value passed to create_agents_priority, it makes the example less robust. In LLM mode, it ignores the actual output from the model. In deterministic mode, it's redundant because the ScoreAgent already sets risk_score in the session state, which is then automatically synced to the graph state.
A better implementation would parse the score from the agent's output, making the example more realistic, especially for the LLM use case.
def store_score(output, state):
  import re
  new_state = GraphState(data=state.data.copy())
  # For deterministic mode, ScoreAgent sets this in session state, which is synced.
  # For LLM mode, we parse it from the output string.
  match = re.search(r"score:\s*([\d.]+)", str(output))
  if match:
    new_state.data["risk_score"] = float(match.group(1))
  return new_state
Please ensure you have read the contribution guide before creating a pull request.
Link to Issue or Description of Change
1. Link to an existing issue (if applicable):
2. Or, if no issue exists, describe the change:
Problem:
ADK lacks a general-purpose directed-graph orchestrator. Users cannot express conditional branching, cycles, or arbitrary DAG topologies with the existing SequentialAgent/ParallelAgent/LoopAgent.
Solution:
Add GraphAgent engine for directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes 9 sample agents and design documentation.
What's included:
- src/google/adk/agents/graph/ (core files: graph_agent.py, graph_node.py, graph_edge.py, graph_state.py, graph_events.py, graph_export.py, graph_rewind.py, graph_agent_config.py, graph_agent_state.py, graph_telemetry.py, callbacks.py, evaluation_metrics.py, state_utils.py)
- src/google/adk/telemetry/graph_tracing.py
- src/google/adk/cli/agent_graph.py: CLI viz extension (core GraphAgent rendering)
- Sample agents (graph_agent_basic, graph_agent_react_pattern, graph_agent_dynamic_queue, examples 01-03/07-08/15)
- graph_agent_design.md
Part 1 of 5: see tracking issue #4581.
Testing Plan
Unit Tests:
Manual End-to-End (E2E) Tests:
All 9 sample agents import and instantiate successfully.
Checklist
Additional context
This is part 1 of a 5-PR stack introducing GraphAgent. Subsequent PRs add: graph patterns (PR 2), parallel execution (PR 3), interrupt service (PR 4), checkpoint service (PR 5).