From ef2722db6a75b063ecd8070677bf6ff974b7e082 Mon Sep 17 00:00:00 2001 From: drahnreb <25883607+drahnreb@users.noreply.github.com> Date: Sat, 21 Feb 2026 19:26:19 +0100 Subject: [PATCH 1/2] feat: Add GraphAgent for directed-graph workflow orchestration Add GraphAgent for building directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, telemetry with OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes samples and design documentation. --- .../samples/graph_agent_basic/README.md | 36 + .../samples/graph_agent_basic/__init__.py | 0 .../samples/graph_agent_basic/agent.py | 149 + .../samples/graph_agent_basic/root_agent.yaml | 42 + .../graph_agent_dynamic_queue/README.md | 53 + .../graph_agent_dynamic_queue/agent.py | 303 ++ .../graph_agent_react_pattern/README.md | 67 + .../graph_agent_react_pattern/__init__.py | 0 .../graph_agent_react_pattern/agent.py | 210 ++ .../graph_examples/01_basic/__init__.py | 1 + .../samples/graph_examples/01_basic/agent.py | 138 + .../02_conditional_routing/__init__.py | 1 + .../02_conditional_routing/agent.py | 201 ++ .../03_cyclic_execution/__init__.py | 1 + .../03_cyclic_execution/agent.py | 306 ++ .../graph_examples/07_callbacks/__init__.py | 1 + .../graph_examples/07_callbacks/agent.py | 198 ++ .../graph_examples/08_rewind/__init__.py | 1 + .../samples/graph_examples/08_rewind/agent.py | 181 ++ .../15_enhanced_routing/__init__.py | 1 + .../15_enhanced_routing/agent.py | 411 +++ contributing/samples/graph_examples/README.md | 594 ++++ .../samples/graph_examples/__init__.py | 20 + .../samples/graph_examples/example_utils.py | 62 + .../samples/graph_examples/run_example.py | 159 + src/google/adk/agents/__init__.py | 10 + src/google/adk/agents/graph/__init__.py | 79 + src/google/adk/agents/graph/callbacks.py | 218 ++ .../adk/agents/graph/evaluation_metrics.py | 382 +++ 
src/google/adk/agents/graph/graph_agent.py | 1651 +++++++++++ .../adk/agents/graph/graph_agent_config.py | 307 ++ .../adk/agents/graph/graph_agent_state.py | 46 + src/google/adk/agents/graph/graph_edge.py | 97 + src/google/adk/agents/graph/graph_events.py | 90 + src/google/adk/agents/graph/graph_export.py | 210 ++ src/google/adk/agents/graph/graph_node.py | 231 ++ src/google/adk/agents/graph/graph_rewind.py | 92 + src/google/adk/agents/graph/graph_state.py | 118 + .../adk/agents/graph/graph_telemetry.py | 183 ++ src/google/adk/agents/graph/state_utils.py | 127 + src/google/adk/cli/agent_graph.py | 46 + src/google/adk/telemetry/graph_tracing.py | 371 +++ tests/unittests/agents/test_graph_agent.py | 2628 +++++++++++++++++ .../agents/test_graph_agent_config.py | 438 +++ .../agents/test_graph_agent_validation.py | 231 ++ .../unittests/agents/test_graph_callbacks.py | 597 ++++ .../agents/test_graph_convenience_api.py | 425 +++ .../unittests/agents/test_graph_evaluation.py | 263 ++ .../test_graph_evaluation_integration.py | 371 +++ .../agents/test_graph_resumability.py | 681 +++++ tests/unittests/agents/test_graph_rewind.py | 607 ++++ tests/unittests/agents/test_graph_routing.py | 483 +++ tests/unittests/agents/test_graph_state.py | 353 +++ .../agents/test_graph_state_management.py | 563 ++++ .../agents/test_graph_telemetry_config.py | 610 ++++ tests/unittests/cli/test_agent_graph.py | 138 + .../unittests/telemetry/test_graph_tracing.py | 522 ++++ 57 files changed, 16304 insertions(+) create mode 100644 contributing/samples/graph_agent_basic/README.md create mode 100644 contributing/samples/graph_agent_basic/__init__.py create mode 100644 contributing/samples/graph_agent_basic/agent.py create mode 100644 contributing/samples/graph_agent_basic/root_agent.yaml create mode 100644 contributing/samples/graph_agent_dynamic_queue/README.md create mode 100644 contributing/samples/graph_agent_dynamic_queue/agent.py create mode 100644 
contributing/samples/graph_agent_react_pattern/README.md create mode 100644 contributing/samples/graph_agent_react_pattern/__init__.py create mode 100644 contributing/samples/graph_agent_react_pattern/agent.py create mode 100644 contributing/samples/graph_examples/01_basic/__init__.py create mode 100644 contributing/samples/graph_examples/01_basic/agent.py create mode 100644 contributing/samples/graph_examples/02_conditional_routing/__init__.py create mode 100644 contributing/samples/graph_examples/02_conditional_routing/agent.py create mode 100644 contributing/samples/graph_examples/03_cyclic_execution/__init__.py create mode 100644 contributing/samples/graph_examples/03_cyclic_execution/agent.py create mode 100644 contributing/samples/graph_examples/07_callbacks/__init__.py create mode 100644 contributing/samples/graph_examples/07_callbacks/agent.py create mode 100644 contributing/samples/graph_examples/08_rewind/__init__.py create mode 100644 contributing/samples/graph_examples/08_rewind/agent.py create mode 100644 contributing/samples/graph_examples/15_enhanced_routing/__init__.py create mode 100644 contributing/samples/graph_examples/15_enhanced_routing/agent.py create mode 100644 contributing/samples/graph_examples/README.md create mode 100644 contributing/samples/graph_examples/__init__.py create mode 100644 contributing/samples/graph_examples/example_utils.py create mode 100755 contributing/samples/graph_examples/run_example.py create mode 100644 src/google/adk/agents/graph/__init__.py create mode 100644 src/google/adk/agents/graph/callbacks.py create mode 100644 src/google/adk/agents/graph/evaluation_metrics.py create mode 100644 src/google/adk/agents/graph/graph_agent.py create mode 100644 src/google/adk/agents/graph/graph_agent_config.py create mode 100644 src/google/adk/agents/graph/graph_agent_state.py create mode 100644 src/google/adk/agents/graph/graph_edge.py create mode 100644 src/google/adk/agents/graph/graph_events.py create mode 100644 
src/google/adk/agents/graph/graph_export.py create mode 100644 src/google/adk/agents/graph/graph_node.py create mode 100644 src/google/adk/agents/graph/graph_rewind.py create mode 100644 src/google/adk/agents/graph/graph_state.py create mode 100644 src/google/adk/agents/graph/graph_telemetry.py create mode 100644 src/google/adk/agents/graph/state_utils.py create mode 100644 src/google/adk/telemetry/graph_tracing.py create mode 100644 tests/unittests/agents/test_graph_agent.py create mode 100644 tests/unittests/agents/test_graph_agent_config.py create mode 100644 tests/unittests/agents/test_graph_agent_validation.py create mode 100644 tests/unittests/agents/test_graph_callbacks.py create mode 100644 tests/unittests/agents/test_graph_convenience_api.py create mode 100644 tests/unittests/agents/test_graph_evaluation.py create mode 100644 tests/unittests/agents/test_graph_evaluation_integration.py create mode 100644 tests/unittests/agents/test_graph_resumability.py create mode 100644 tests/unittests/agents/test_graph_rewind.py create mode 100644 tests/unittests/agents/test_graph_routing.py create mode 100644 tests/unittests/agents/test_graph_state.py create mode 100644 tests/unittests/agents/test_graph_state_management.py create mode 100644 tests/unittests/agents/test_graph_telemetry_config.py create mode 100644 tests/unittests/cli/test_agent_graph.py create mode 100644 tests/unittests/telemetry/test_graph_tracing.py diff --git a/contributing/samples/graph_agent_basic/README.md b/contributing/samples/graph_agent_basic/README.md new file mode 100644 index 0000000000..7b60001fc6 --- /dev/null +++ b/contributing/samples/graph_agent_basic/README.md @@ -0,0 +1,36 @@ +# GraphAgent Basic Example — Conditional Routing + +This example demonstrates a data validation pipeline using **conditional routing** based on runtime +state. 
The validator checks input quality and branches to either a processor (success path) or an +error handler (failure path), showing how GraphAgent enables state-dependent decision logic that +sequential or parallel agent composition alone cannot achieve. + +## When to Use This Pattern + +- Any workflow requiring "if X then A, else B" branching on agent output +- Input validation before expensive downstream processing +- Quality-gate patterns where the next step depends on a score or classification + +## How to Run + +```bash +adk run contributing/samples/graph_agent_basic +``` + +## Graph Structure + +``` +validate ──(valid=True)──▶ process + ──(valid=False)─▶ error +``` + +## Key Code Walkthrough + +- **`GraphNode(name="validate", agent=validator_agent)`** — wraps an `LlmAgent` as a graph node +- **`add_edge("validate", "process", condition=lambda s: s.data["valid"] == True)`** — conditional + edge that only fires when the validation flag is set +- **Two end nodes** (`process` and `error`) — GraphAgent can have multiple terminal nodes +- **State propagation** — each node's output is written to `state.data[node_name]` and read by + downstream condition functions +- **No cycles** — this is a simple directed acyclic graph; for loops see `graph_agent_dynamic_queue` + diff --git a/contributing/samples/graph_agent_basic/__init__.py b/contributing/samples/graph_agent_basic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_basic/agent.py b/contributing/samples/graph_agent_basic/agent.py new file mode 100644 index 0000000000..f7b72bc83c --- /dev/null +++ b/contributing/samples/graph_agent_basic/agent.py @@ -0,0 +1,149 @@ +"""Basic GraphAgent example demonstrating conditional routing. + +This example shows how GraphAgent enables conditional workflow routing +based on runtime state, which cannot be achieved with SequentialAgent +or ParallelAgent composition. + +Use case: Data validation pipeline with retry logic. 
+- If validation passes -> process data +- If validation fails -> retry validation +- After max retries -> route to error handler +""" + +import asyncio +import os + +from google.adk.agents import GraphAgent +from google.adk.agents import LlmAgent +from google.adk.agents.graph import GraphState +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + + +# --- Validation Result Schema --- +class ValidationResult(BaseModel): + """Validation result structure.""" + + valid: bool + error: str | None = None + + +# --- Validator Agent --- +validator = LlmAgent( + name="validator", + model=_MODEL, + instruction=""" + You validate input data quality. + Check if the input contains valid JSON. + Return {"valid": true} if valid, {"valid": false, "error": "reason"} if invalid. + """, + output_schema=ValidationResult, # Ensures structured JSON output + # output_key auto-defaults to "validator" (agent name) +) + +# --- Processor Agent --- +processor = LlmAgent( + name="processor", + model=_MODEL, + instruction=""" + You process validated data. + Transform the input JSON and return processed results. + """, +) + +# --- Error Handler Agent --- +error_handler = LlmAgent( + name="error_handler", + model=_MODEL, + instruction=""" + You handle validation errors. + Provide helpful error messages and suggestions for fixing invalid data. 
+ """, +) + + +# --- Edge Condition Functions --- +def is_valid_json(state: GraphState) -> bool: + """Check if JSON is valid from structured output.""" + result = state.get_parsed("validator", ValidationResult) + return result.valid if result else False + + +# --- Create GraphAgent with Conditional Routing --- +def build_validation_graph() -> GraphAgent: + """Build the validation pipeline graph.""" + g = GraphAgent(name="validation_pipeline") + + # Add nodes + g.add_node("validate", agent=validator) + g.add_node("process", agent=processor) + g.add_node("error", agent=error_handler) + + # Add conditional edges + # If validation passes (state.data["validator"]["valid"] == True) -> process + g.add_edge( + "validate", + "process", + condition=is_valid_json, + ) + + # If validation fails (state.data["validator"]["valid"] == False) -> error handler + g.add_edge( + "validate", + "error", + condition=lambda state: not is_valid_json(state), + ) + + # Define workflow + g.set_start("validate") + g.set_end("process") # Success path ends at process + g.set_end("error") # Error path ends at error handler + + return g + + +# --- Run the workflow --- + + +async def main(): + graph = build_validation_graph() + runner = Runner( + app_name="validation_pipeline", + agent=graph, + session_service=InMemorySessionService(), + auto_create_session=True, + ) + + # Example: Valid input + print("=== Testing with valid JSON ===") + async for event in runner.run_async( + user_id="user_1", + session_id="session_1", + new_message=types.Content( + role="user", + parts=[types.Part(text='{"name": "John", "age": 30}')], + ), + ): + if event.content and event.content.parts: + print(f"{event.author}: {event.content.parts[0].text}") + + # Example: Invalid input + print("\n=== Testing with invalid JSON ===") + async for event in runner.run_async( + user_id="user_1", + session_id="session_2", + new_message=types.Content( + role="user", + parts=[types.Part(text='{"name": "Invalid data')], + ), + ): + if 
event.content and event.content.parts: + print(f"{event.author}: {event.content.parts[0].text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_basic/root_agent.yaml b/contributing/samples/graph_agent_basic/root_agent.yaml new file mode 100644 index 0000000000..d27a79ab1a --- /dev/null +++ b/contributing/samples/graph_agent_basic/root_agent.yaml @@ -0,0 +1,42 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/google/adk-python/refs/heads/main/src/google/adk/agents/config_schemas/AgentConfig.json + +agent_class: GraphAgent +name: validation_pipeline +description: Data validation pipeline with conditional routing + +# Define start and end nodes +start_node: validate +end_nodes: + - process + - error + +# Maximum iterations for cyclic graphs (default: 20) +max_iterations: 10 + +# Node definitions +nodes: + - name: validate + sub_agents: + - code: agents.validator + + - name: process + sub_agents: + - code: agents.processor + + - name: error + sub_agents: + - code: agents.error_handler + +# Edge definitions with conditional routing +edges: + # If validation passes -> process + - source_node: validate + target_node: process + condition: "data.get('valid', False) is True" + priority: 1 + + # If validation fails -> error handler + - source_node: validate + target_node: error + condition: "data.get('valid', False) is False" + priority: 1 diff --git a/contributing/samples/graph_agent_dynamic_queue/README.md b/contributing/samples/graph_agent_dynamic_queue/README.md new file mode 100644 index 0000000000..c5269443ec --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_queue/README.md @@ -0,0 +1,53 @@ +# GraphAgent Dynamic Task Queue Example + +This example demonstrates the **Dynamic Task Queue** pattern for GraphAgent, enabling AI Co-Scientist and similar workflows where tasks are generated and processed dynamically at runtime. 
+ +## Pattern Overview + +The dynamic task queue pattern uses a function node with runtime agent dispatch: +- **Task Queue**: Maintained in GraphState, grows/shrinks dynamically +- **Agent Dispatch**: Different agents selected based on task type +- **Dynamic Task Generation**: Agents generate new tasks from their outputs +- **State-Based Loop**: Continues until queue is empty + +## What This Example Shows + +1. **Mock Agents**: Three agents (generation, review, experiment) for demonstration +2. **Task Parsing**: Extract TODO items from agent outputs to create new tasks +3. **Dynamic Dispatch**: Select agent based on task type at runtime +4. **Queue Management**: Process tasks until queue is empty + +## Architecture Support + +This pattern enables **95%+ architecture support** for: +- AI Co-Scientist (dynamic hypothesis generation and testing) +- Research paper writing (dynamic outline → research → writing loops) +- Multi-agent task orchestration + +## Running the Example + +```bash +cd /path/to/adk-python +source venv/bin/activate +python contributing/samples/graph_agent_dynamic_queue/agent.py +``` + +The example will: +1. Start with 2 initial tasks (generate hypothesis 1 and 2) +2. Process each task with appropriate agent +3. Parse agent outputs for new tasks (TODO: review X, TODO: experiment Y) +4. Add new tasks to queue dynamically +5. Continue until queue is empty + +## Adapting This Pattern + +Replace the mock agents with real agents: +```python +from your_agents import GenerationAgent, ReviewAgent, ExperimentAgent + +generation_agent = GenerationAgent(name="generation") +review_agent = ReviewAgent(name="review") +experiment_agent = ExperimentAgent(name="experiment") +``` + +Customize task parsing logic in `parse_new_tasks_from_result()` to match your agent outputs. 
diff --git a/contributing/samples/graph_agent_dynamic_queue/agent.py b/contributing/samples/graph_agent_dynamic_queue/agent.py new file mode 100644 index 0000000000..b86abe3ea3 --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_queue/agent.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +"""Dynamic Task Queue Pattern Example + +Demonstrates how to implement AI Co-Scientist pattern using GraphAgent +with a function node that dynamically dispatches to different agents. + +This example shows: +1. Dynamic task queue management +2. Runtime agent dispatch based on task type +3. Dynamic task generation from agent outputs +4. State-based loop control +""" + +import asyncio +import re +from typing import Any +from typing import Dict +from typing import List + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +# Mock agents for demonstration (replace with real agents) +class MockGenerationAgent(BaseAgent): + """Mock agent that generates hypotheses.""" + + async def _run_async_impl(self, ctx: InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate hypothesis generation + hypothesis = f"Hypothesis: {input_text} leads to interesting results" + + # Generate follow-up tasks + output = f"""{hypothesis} + +TODO: review this hypothesis +TODO: experiment with variation A""" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=output)] + ), + ) + + +class MockReviewAgent(BaseAgent): + """Mock agent that reviews hypotheses.""" + + async def _run_async_impl(self, ctx: 
InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate review + review = f"Review: {input_text} - APPROVED with score 8/10" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=review)] + ), + ) + + +class MockExperimentAgent(BaseAgent): + """Mock agent that runs experiments.""" + + async def _run_async_impl(self, ctx: InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate experiment + result = f"Experiment: {input_text} - SUCCESS (confidence: 0.92)" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=result)] + ), + ) + + +# Initialize worker agents +generation_agent = MockGenerationAgent(name="generation_agent") +review_agent = MockReviewAgent(name="review_agent") +experiment_agent = MockExperimentAgent(name="experiment_agent") + + +def parse_new_tasks_from_result(result: str) -> List[Dict[str, str]]: + """Extract TODO tasks from agent output. + + Looks for lines like: + - TODO: review X + - TODO: experiment with Y + + Returns: + List of task dicts: [{"type": "review", "data": "X"}, ...] + """ + tasks = [] + todo_pattern = r"TODO:\s*(review|experiment)\s+(.+)" + + for match in re.finditer(todo_pattern, result, re.IGNORECASE): + task_type = match.group(1).lower() + task_data = match.group(2).strip() + tasks.append({"type": task_type, "data": task_data}) + + return tasks + + +async def dynamic_task_dispatcher( + state: GraphState, ctx: InvocationContext +) -> Dict[str, Any]: + """Dispatch to agents based on dynamic task queue. + + This function: + 1. Reads task queue from state + 2. Pops next task + 3. Dispatches to appropriate agent + 4. Updates queue with any new tasks generated + 5. 
Returns updated state + """ + task_queue = state.data.get("task_queue", []) + + if not task_queue: + print("✅ Task queue empty - all tasks complete!") + return {"all_complete": True, "tasks_remaining": 0} + + # Pop next task + next_task = task_queue.pop(0) + task_type = next_task["type"] + task_data = next_task["data"] + + print(f"\n🔄 Processing task: [{task_type}] {task_data}") + + # Dynamic agent dispatch based on task type + if task_type == "generate": + agent = generation_agent + elif task_type == "review": + agent = review_agent + elif task_type == "experiment": + agent = experiment_agent + else: + raise ValueError(f"Unknown task type: {task_type}") + + # Create context for agent with task data + agent_ctx = ctx.model_copy( + update={ + "user_content": types.Content( + role="user", parts=[types.Part(text=task_data)] + ) + } + ) + + # Execute agent and collect result + result = "" + async for event in agent.run_async(agent_ctx): + if event.content and event.content.parts: + result = event.content.parts[0].text or "" + + print(f" Result: {result[:100]}...") + + # Parse result for new tasks (dynamic task generation!) + new_tasks = parse_new_tasks_from_result(result) + if new_tasks: + print(f" Generated {len(new_tasks)} new tasks: {new_tasks}") + task_queue.extend(new_tasks) + + # Update state + state.data["task_queue"] = task_queue + state.data["last_result"] = result + completed = state.data.setdefault("completed_tasks", []) + completed.append({"type": task_type, "data": task_data, "result": result}) + + print(f" Tasks remaining: {len(task_queue)}") + + return {"tasks_remaining": len(task_queue), "completed_count": len(completed)} + + +def build_dynamic_task_queue_graph() -> GraphAgent: + """Build GraphAgent with dynamic task queue pattern. + + The graph has a single node that loops, processing tasks from a queue + that can grow dynamically based on agent outputs. 
+ """ + graph = GraphAgent( + name="ai_co_scientist", + max_iterations=20, # Prevent infinite loops + description="Dynamic task queue with agent dispatch", + ) + + # Single dispatcher node that processes queue + graph.add_node("task_dispatcher", function=dynamic_task_dispatcher) + + # Loop back to dispatcher while tasks remain. + # Check task_queue directly (mutated in-place by the function node). + # The return dict {"tasks_remaining": N} is stored under state.data["task_dispatcher"] + # by the output mapper, so state.data.get("tasks_remaining") would always be 0. + graph.add_edge( + "task_dispatcher", + "task_dispatcher", + condition=lambda state: len(state.data.get("task_queue", [])) > 0, + ) + + graph.set_start("task_dispatcher") + graph.set_end("task_dispatcher") # Terminal when no edge condition matches + + return graph + + +async def main(): + """Run dynamic task queue example.""" + print("=" * 70) + print("Dynamic Task Queue Pattern - AI Co-Scientist Example") + print("=" * 70) + + # Build graph + graph = build_dynamic_task_queue_graph() + + # Create session service + session_service = InMemorySessionService() + + # Initialize with starting tasks + initial_state = GraphState( + data={ + "task_queue": [ + {"type": "generate", "data": "quantum computing approach"}, + {"type": "generate", "data": "machine learning approach"}, + ], + "completed_tasks": [], + } + ) + + # Seed domain data via create_session so it survives the deepcopy: + # InMemorySessionService always deepcopies on get/create, so setting + # session.state after create_session() would only mutate the returned copy. 
+ session = await session_service.create_session( + app_name="dynamic_queue_demo", + user_id="demo_user", + state=initial_state.data, + ) + + # Create runner + runner = Runner( + app_name="dynamic_queue_demo", + agent=graph, + session_service=session_service, + auto_create_session=False, # Session already created + ) + + # Run graph - dispatcher will process queue dynamically + print("\n📋 Initial task queue:") + for task in initial_state.data["task_queue"]: + print(f" - [{task['type']}] {task['data']}") + + print("\n" + "=" * 70) + print("Starting execution...") + print("=" * 70) + + async for event in runner.run_async( + user_id="demo_user", + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part(text="Start task queue processing")] + ), + ): + if event.content and event.content.parts: + text = event.content.parts[0].text or "" + if text and "final_output" in text.lower(): + print(f"\n📊 {text}") + + # Print final statistics (re-fetch — create_session returned a deepcopy) + fresh_session = await session_service.get_session( + app_name="dynamic_queue_demo", user_id="demo_user", session_id=session.id + ) + final_session = fresh_session or session + final_data = final_session.state.get("graph_data", {}) + final_state = GraphState(data=final_data) if final_data else GraphState() + + print("\n" + "=" * 70) + print("Execution Complete!") + print("=" * 70) + print( + "Total tasks completed:" + f" {len(final_state.data.get('completed_tasks', []))}" + ) + print(f"\nCompleted tasks:") + for i, task in enumerate(final_state.data.get("completed_tasks", []), 1): + print(f"{i}. 
[{task['type']}] {task['data']}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_react_pattern/README.md b/contributing/samples/graph_agent_react_pattern/README.md new file mode 100644 index 0000000000..2f8e3be75f --- /dev/null +++ b/contributing/samples/graph_agent_react_pattern/README.md @@ -0,0 +1,67 @@ +# GraphAgent ReAct Pattern + +Demonstrates the **ReAct (Reasoning + Acting)** loop using GraphAgent. + +## Pattern + +``` +reason → act → observe + ↑ | CONTINUE + └─────────┘ + | COMPLETE + ↓ + END +``` + +Each iteration: +1. **reason** — analyse task + previous observation, decide next action +2. **act** — execute the chosen action, produce a result +3. **observe** — evaluate result; output `COMPLETE:` or `CONTINUE:` + +GraphAgent routes `observe → reason` (continue) or exits (complete) based on +the `observation` state key — a conditional edge only GraphAgent can express. + +## When to Use + +- Multi-step problem solving where the number of iterations is unknown +- Tool-augmented agents (search → reason → act → observe loop) +- Any workflow where routing depends on the *content* of an intermediate output + +## Comparison with Other Workflow Agents + +| Capability | SequentialAgent | LoopAgent | **GraphAgent** | +|------------|----------------|-----------|----------------| +| Execute nodes in order | ✅ | ✅ | ✅ | +| Loop (repeat execution) | ✗ | ✅ | ✅ | +| Route based on state content | ✗ | ✗ (escalate only) | ✅ | +| Conditional exit mid-loop | ✗ | via escalate | ✅ | +| Route to *different* next node | ✗ | ✗ | ✅ | + +**LoopAgent** can repeat, but its only conditional exit is via `escalate` — it +cannot inspect the `observation` field and route back to `reason` vs. exit. +**SequentialAgent** cannot loop at all. 
+ +## Key Code + +```python +# Conditional back-edge: loop if not complete +graph.add_edge("observe", "reason", condition=_should_continue) + +# No forward edge needed: set_end() exits when observe has no matching edge +graph.set_end("observe") +``` + +## How to Run + +```bash +cd /path/to/adk-python +source venv/bin/activate +export GOOGLE_API_KEY= +python -m contributing.samples.graph_agent_react_pattern.agent +``` + +## Related Examples + +- `contributing/samples/graph_examples/02_conditional_routing` — basic conditional edges +- `contributing/samples/graph_examples/03_cyclic_execution` — cyclic loop without LLM +- `contributing/samples/graph_agent_advanced` — full research workflow diff --git a/contributing/samples/graph_agent_react_pattern/__init__.py b/contributing/samples/graph_agent_react_pattern/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_react_pattern/agent.py b/contributing/samples/graph_agent_react_pattern/agent.py new file mode 100644 index 0000000000..0ca366288b --- /dev/null +++ b/contributing/samples/graph_agent_react_pattern/agent.py @@ -0,0 +1,210 @@ +"""GraphAgent ReAct Pattern example. + +Demonstrates the Reasoning + Acting (ReAct) loop using GraphAgent: + reason → act → observe + observe loops back to reason if "CONTINUE" + observe ends if "COMPLETE" + +Why GraphAgent (not LoopAgent/SequentialAgent)? +- SequentialAgent: cannot loop; fixed linear path +- LoopAgent: loops unconditionally or escalates; cannot inspect observation + content to decide direction (reason vs. 
exit) +- GraphAgent: conditional edges read state → route to any node or exit + +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_react_pattern.agent +""" + +import asyncio +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import StateReducer +from google.adk.agents.llm_agent import LlmAgent +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel +from pydantic import ValidationError + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class ObservationResult(BaseModel): + """Structured observation output from observer agent.""" + + status: str # "continue" or "complete" + reasoning: str # Why continue or why complete + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + +reasoner = LlmAgent( + name="reasoner", + model=_MODEL, + instruction=( + "You are a reasoning agent. Analyse the task and any previous " + "observations, then decide what action to take next. " + "Write your reasoning in 1-3 sentences." + ), + output_key="reasoning", +) + +actor = LlmAgent( + name="actor", + model=_MODEL, + instruction=( + "You are an action agent. Based on the reasoning provided, " + "answer the question or perform the requested analysis using your " + "knowledge. Do NOT write code or tool calls — just provide the " + "factual answer or calculation result directly." + ), + output_key="action_result", +) + +observer = LlmAgent( + name="observer", + model=_MODEL, + instruction=( + "You are an observation agent. 
Evaluate whether the action result " + "fully answers the original task. " + 'Return {"status": "complete", "reasoning": "..."} if task is done, ' + 'or {"status": "continue", "reasoning": "what is missing..."} if not.' + ), + output_schema=ObservationResult, # Structured output + # output_key auto-defaults to "observer" (agent name) +) + + +# --------------------------------------------------------------------------- +# Routing predicates +# --------------------------------------------------------------------------- + + +def _should_continue(state: GraphState) -> bool: + """Check if ReAct loop should continue using structured output.""" + obs = state.get_parsed("observer", ObservationResult) + return obs.status.lower() == "continue" if obs else False + + + +# --------------------------------------------------------------------------- +# Graph +# --------------------------------------------------------------------------- + + +def build_react_graph() -> GraphAgent: + graph = GraphAgent( + name="react_agent", + description="ReAct pattern: Reasoning + Acting loop", + max_iterations=10, + ) + + graph.add_node( + "reason", + agent=reasoner, + input_mapper=lambda s: ( + f"Task: {s.data.get('task', '')}\n" + f"Previous observation: {s.data.get('observation', 'none')}" + ), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "act", + agent=actor, + input_mapper=lambda s: s.data.get("reasoning", ""), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "observe", + agent=observer, + input_mapper=lambda s: ( + f"Task: {s.data.get('task', '')}\n" + f"Action result: {s.data.get('action_result', '')}" + ), + reducer=StateReducer.OVERWRITE, + ) + + graph.set_start("reason") + graph.add_edge("reason", "act") + graph.add_edge("act", "observe") + + # Loop back if not yet complete + graph.add_edge("observe", "reason", condition=_should_continue) + + # Exit when complete + graph.set_end("observe") + + return graph + + +# 
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


async def main() -> None:
  """Run the ReAct graph end to end against an in-memory session.

  Creates the session with the task pre-seeded, streams runner events to
  show per-iteration observer status, then re-fetches the session to read
  the final graph state.
  """
  session_service = InMemorySessionService()
  graph = build_react_graph()

  task = "What are the key features of the Google Agent Development Kit (ADK)?"
  print(f"Task: {task}\n")

  # FIX: seed the task through the session service. The previous code
  # mutated `session.state` on the object returned by create_session(),
  # but that object is a deepcopy — the write never reached the service
  # and the graph could not see "task". Passing `state=` at creation time
  # persists it.
  session = await session_service.create_session(
      app_name="react_agent", user_id="user1", state={"task": task}
  )

  # Use Runner instead of manual invocation context
  runner = Runner(
      app_name="react_agent",
      agent=graph,
      session_service=session_service,
      auto_create_session=False,
  )

  iteration = 0
  async for event in runner.run_async(
      user_id="user1",
      session_id=session.id,
      new_message=types.Content(parts=[types.Part(text=task)]),
  ):
    if not (event.content and event.content.parts):
      continue
    author = event.author
    text = event.content.parts[0].text or ""
    if author == "observer":
      iteration += 1
      # output_schema makes the observer emit a JSON string; parse it to
      # display the structured status for this iteration.
      try:
        obs = ObservationResult.model_validate_json(text.strip())
        status = obs.status.upper()
      except ValidationError:
        status = "UNKNOWN (parse error)"
      print(f"[iteration {iteration}] Observer: {status}")
    elif author in ("reasoner", "actor"):
      print(f"  [{author}]: {text[:120]}...")

  # Re-fetch fresh session state (session objects handed out are copies).
  fresh_session = await session_service.get_session(
      app_name="react_agent", user_id="user1", session_id=session.id
  )
  final_data = (fresh_session or session).state.get("graph_data", {})
  final_state = GraphState(data=final_data)

  print("\nFinal observation:")
  obs = final_state.get_parsed("observer", ObservationResult)
  print(f"Status: {obs.status if obs else 'none'}")
  print(f"Reasoning: {obs.reasoning if obs else 'none'}")


if __name__ == "__main__":
  asyncio.run(main())
--git a/contributing/samples/graph_examples/01_basic/__init__.py b/contributing/samples/graph_examples/01_basic/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/01_basic/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/01_basic/agent.py b/contributing/samples/graph_examples/01_basic/agent.py new file mode 100644 index 0000000000..ebda8a0379 --- /dev/null +++ b/contributing/samples/graph_examples/01_basic/agent.py @@ -0,0 +1,138 @@ +"""Example 1: Basic GraphAgent Workflow + +Demonstrates: +- Creating a simple directed graph +- Adding nodes (agents) +- Adding edges (transitions) +- Setting start and end nodes +- Executing the workflow + +Run modes: +- Default: python -m contributing.samples.graph_examples.01_basic.agent +- LLM: python -m contributing.samples.graph_examples.01_basic.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.01_basic.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class SimpleAgent(BaseAgent): + """A simple agent that outputs a message.""" + + def __init__(self, name: str, message: str, **kwargs): + super().__init__(name=name, **kwargs) + self._message = message + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._message)]), + ) + + +# =========================== +# Agent Factory +# =========================== + + 
def create_agents():
  """Create the workflow agents for the current execution mode.

  Returns:
    tuple: (validate, process, complete) agents
  """
  if not use_llm_mode():
    # Deterministic path: fixed-message agents, no model calls.
    print("🎭 Using deterministic agents (BaseAgent)\n")
    return (
        SimpleAgent(name="validate", message="✅ Validation passed"),
        SimpleAgent(name="process", message="⚙️ Processing data"),
        SimpleAgent(name="complete", message="✅ Workflow complete"),
    )

  print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n")
  # LLM path: same three roles, each pinned to an exact response so the
  # demo output matches the deterministic variant.
  specs = {
      "validate": (
          "You are a validation agent. Respond with '✅ Validation passed' to"
          " confirm the workflow started successfully."
      ),
      "process": (
          "You are a processing agent. Respond with '⚙️ Processing data' to"
          " indicate you're processing the workflow."
      ),
      "complete": (
          "You are a completion agent. Respond with '✅ Workflow complete' to"
          " signal successful workflow completion."
      ),
  }
  return tuple(
      create_llm_agent(name=agent_name, instruction=instruction)
      for agent_name, instruction in specs.items()
  )
types.Content(parts=[types.Part(text="Start workflow")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\n✅ Example complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/02_conditional_routing/__init__.py b/contributing/samples/graph_examples/02_conditional_routing/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/02_conditional_routing/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/02_conditional_routing/agent.py b/contributing/samples/graph_examples/02_conditional_routing/agent.py new file mode 100644 index 0000000000..1c0fd5394d --- /dev/null +++ b/contributing/samples/graph_examples/02_conditional_routing/agent.py @@ -0,0 +1,201 @@ +"""Example 2: Conditional Routing + +Demonstrates: +- Conditional edges based on state +- Multiple routing paths +- State-based decision making + +Run modes: +- Default: python -m contributing.samples.graph_examples.02_conditional_routing.agent +- LLM: python -m contributing.samples.graph_examples.02_conditional_routing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.02_conditional_routing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# 
Deterministic Agents (BaseAgent) +# =========================== + + +class ValidatorAgent(BaseAgent): + """Validates input and sets quality score.""" + + def __init__(self, name: str, score: int, **kwargs): + super().__init__(name=name, **kwargs) + self._score = score + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=f"✅ Validation complete (score: {self._score})" + ) + ] + ), + ) + + +class ProcessAgent(BaseAgent): + """Process based on quality.""" + + def __init__(self, name: str, quality: str, **kwargs): + super().__init__(name=name, **kwargs) + self._quality = quality + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"⚙️ {self._quality} quality processing")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(test_score: int): + """Create agents based on USE_LLM mode. + + Args: + test_score: The score to use for validation + + Returns: + tuple: (validate, high_quality, medium_quality, low_quality) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + validate = create_llm_agent( + name="validate", + instruction=( + "You are a validation agent. Respond with 'Validation complete" + f" (score: {test_score})' exactly." + ), + ) + high_quality = create_llm_agent( + name="high_quality", + instruction=( + "You are a high quality processor. Respond with 'HIGH quality" + " processing' exactly." + ), + ) + medium_quality = create_llm_agent( + name="medium_quality", + instruction=( + "You are a medium quality processor. Respond with 'MEDIUM quality" + " processing' exactly." + ), + ) + low_quality = create_llm_agent( + name="low_quality", + instruction=( + "You are a low quality processor. Respond with 'LOW quality" + " processing' exactly." 
async def main():
  """Run the conditional-routing demo across several quality scores.

  For each score the graph routes from ``validate`` to exactly one of the
  quality processors based on a score written into graph state by the
  validate node's output_mapper.
  """
  print("\n" + "=" * 60)
  print("Example 2: Conditional Routing")
  print("=" * 60 + "\n")

  # Test with different scores
  for test_score in [95, 75, 45]:
    print(f"🎯 Testing with score: {test_score}")

    # Create agents (deterministic or LLM based on USE_LLM flag)
    validate, high_quality, medium_quality, low_quality = create_agents(
        test_score
    )

    # Build graph with conditional routing.
    # FIX: bind test_score as a default argument (score=test_score) so the
    # output_mapper captures this iteration's value. A bare closure over
    # the loop variable late-binds and would read whatever test_score is
    # at call time — a classic Python closure pitfall.
    graph = (
        GraphAgent(name="conditional_workflow")
        .add_node(
            "validate",
            agent=validate,
            output_mapper=lambda output, state, score=test_score: GraphState(
                data={**state.data, "score": score},
            ),
        )
        .add_node("high_quality", agent=high_quality)
        .add_node("medium_quality", agent=medium_quality)
        .add_node("low_quality", agent=low_quality)
        # Conditional edges: exactly one matches for any score.
        .add_edge(
            "validate",
            "high_quality",
            condition=lambda s: s.data.get("score", 0) >= 80,
        )
        .add_edge(
            "validate",
            "medium_quality",
            condition=lambda s: 50 <= s.data.get("score", 0) < 80,
        )
        .add_edge(
            "validate",
            "low_quality",
            condition=lambda s: s.data.get("score", 0) < 50,
        )
        .set_start("validate")
        .set_end("high_quality")
        .set_end("medium_quality")
        .set_end("low_quality")
    )

    # Execute
    session_service = InMemorySessionService()
    runner = Runner(
        app_name="routing_demo",
        agent=graph,
        session_service=session_service,
        auto_create_session=True,
    )

    new_message = types.Content(parts=[types.Part(text="Start")])
    async for event in runner.run_async(
        user_id="user1",
        session_id=f"session_{test_score}",
        new_message=new_message,
    ):
      if event.content and event.content.parts:
        for part in event.content.parts:
          if part.text:
            print(f"   {part.text}")

    print()

  print("✅ Example complete!\n")


if __name__ == "__main__":
  asyncio.run(main())
+ + Dynamic instructions (callables) read the current counter from + session.state, providing fresh context per iteration without any + conversation history leakage. + +Run modes: +- Default: python -m contributing.samples.graph_examples.03_cyclic_execution.agent +- LLM: python -m contributing.samples.graph_examples.03_cyclic_execution.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.03_cyclic_execution.agent +""" + +import asyncio +import json +import re + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph.graph_state import GraphState +from google.adk.events.event import Event +from google.adk.events.event import EventActions +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +MAX_CYCLES = 3 + + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class StartAgent(BaseAgent): + """Initializes the counter.""" + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text="Workflow started")]), + actions=EventActions(state_delta={"counter": 0}), + ) + + +class StepAgent(BaseAgent): + """Increments the counter and persists it via state_delta.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + 1 + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Step executed (counter={counter})")] + ), + actions=EventActions(state_delta={"counter": counter}), + ) + + +class CheckAgent(BaseAgent): + """Reads counter and writes routing signal via state_delta.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + status 
= "CONTINUE" if counter < MAX_CYCLES else "DONE" + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part(text=f"Check: counter={counter}, status={status}") + ] + ), + actions=EventActions(state_delta={"status": status}), + ) + + +class EndAgent(BaseAgent): + """Signals workflow completion.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part(text=f"Workflow complete after {counter} cycle(s)") + ] + ), + ) + + +# =========================== +# LLM Dynamic Instructions +# =========================== +# Callables that read current state per iteration — clean context each time. +# Used with include_contents='none' (Ralph Loop pattern). + + +def step_instruction(ctx): + """Dynamic instruction: reads counter from session.state each iteration.""" + counter = ctx.state.get("counter", 0) + return ( + f"The current counter value is {counter}. " + f"Increment it by 1 and respond with ONLY the new number, " + f"nothing else. Just the number." + ) + + +def check_instruction(ctx): + """Dynamic instruction: reads counter from session.state each iteration.""" + counter = ctx.state.get("counter", 0) + return ( + f"The current counter value is {counter} and the threshold is" + f" {MAX_CYCLES}. If {counter} < {MAX_CYCLES}, respond with exactly:" + f" CONTINUE. If {counter} >= {MAX_CYCLES}, respond with exactly:" + f" DONE. One word only." + ) + + +# =========================== +# LLM Output Mappers +# =========================== +# Parse LLM text output into structured state keys for edge conditions. 
def _cloned(state: GraphState) -> GraphState:
  """Return a new GraphState holding a shallow copy of *state*'s data."""
  return GraphState(data=dict(state.data))


def start_output_mapper(output: str, state: GraphState) -> GraphState:
  """Initialize counter=0 in state (LLM agent can't write state_delta)."""
  nxt = _cloned(state)
  nxt.data.update(counter=0, start=output.strip())
  return nxt


def step_output_mapper(output: str, state: GraphState) -> GraphState:
  """Parse LLM number output -> counter in state.

  Falls back to incrementing the previous counter if the LLM reply
  contains no digits, so the loop still makes progress.
  """
  nxt = _cloned(state)
  found = re.search(r"\d+", str(output).strip())
  if found is None:
    counter = state.data.get("counter", 0) + 1
  else:
    counter = int(found.group())
  nxt.data["counter"] = counter
  nxt.data["step"] = f"Step executed (counter={counter})"
  return nxt


def check_output_mapper(output: str, state: GraphState) -> GraphState:
  """Parse LLM CONTINUE/DONE -> status routing signal in state.

  Any reply containing "DONE" (case-insensitive) ends the loop;
  everything else defaults to CONTINUE.
  """
  nxt = _cloned(state)
  status = "DONE" if "DONE" in str(output).strip().upper() else "CONTINUE"
  counter = state.data.get("counter", 0)
  nxt.data["status"] = status
  nxt.data["check"] = f"Check: counter={counter}, status={status}"
  return nxt
+ step = create_llm_agent( + name="step", + instruction=step_instruction, + include_contents="none", + ) + check = create_llm_agent( + name="check", + instruction=check_instruction, + include_contents="none", + ) + end = create_llm_agent( + name="end_node", + instruction=( + "A cyclic workflow just completed. Summarize: the workflow ran" + f" for {MAX_CYCLES} cycles. Respond in one sentence." + ), + ) + + graph.add_node("start", agent=start, output_mapper=start_output_mapper) + graph.add_node("step", agent=step, output_mapper=step_output_mapper) + graph.add_node("check", agent=check, output_mapper=check_output_mapper) + graph.add_node("end_node", agent=end) + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + start = StartAgent(name="start") + step = StepAgent(name="step") + check = CheckAgent(name="check") + end = EndAgent(name="end_node") + + graph.add_node("start", agent=start) + graph.add_node("step", agent=step) + graph.add_node("check", agent=check) + graph.add_node("end_node", agent=end) + + # Edges are identical in both modes — they read from state.data + ( + graph.add_edge("start", "step") + .add_edge("step", "check") + .add_edge( + "check", + "step", + condition=lambda s: s.data.get("status") == "CONTINUE", + ) + .add_edge( + "check", + "end_node", + condition=lambda s: s.data.get("status") == "DONE", + ) + .set_start("start") + .set_end("end_node") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="cyclic_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print( + f"Executing cyclic workflow (max_cycles={MAX_CYCLES}, max_iterations=10)" + ) + print("Graph: start -> step -> check -> (loop back or exit)\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + 
if part.text and "#metadata" not in event.author: + print(f" [{event.author}] {part.text}") + + session = await session_service.get_session( + app_name="cyclic_demo", user_id="user1", session_id="session1" + ) + final_counter = session.state.get("counter") + if final_counter is None: + graph_data_raw = session.state.get("graph_data") + if graph_data_raw: + try: + data = ( + json.loads(graph_data_raw) + if isinstance(graph_data_raw, str) + else graph_data_raw + ) + final_counter = data.get("counter", 0) + except (json.JSONDecodeError, TypeError): + final_counter = 0 + + if final_counter is None: + final_counter = 0 + print(f"\n Final counter value: {final_counter}") + print(f" Completed {final_counter} cycle(s) before exiting loop") + + print("\nExample complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/07_callbacks/__init__.py b/contributing/samples/graph_examples/07_callbacks/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/07_callbacks/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/07_callbacks/agent.py b/contributing/samples/graph_examples/07_callbacks/agent.py new file mode 100644 index 0000000000..a54325379b --- /dev/null +++ b/contributing/samples/graph_examples/07_callbacks/agent.py @@ -0,0 +1,198 @@ +"""Example 7: Node Callbacks (before_node_callback / after_node_callback) + +Demonstrates: +- Registering before_node_callback and after_node_callback on a GraphAgent +- Measuring per-node execution time with time.perf_counter() +- Callbacks store start times in ctx.metadata and compute elapsed on exit +- Timing results are printed in async def main() after the run completes + +Run modes: +- Default: python -m contributing.samples.graph_examples.07_callbacks.agent +- LLM: python -m contributing.samples.graph_examples.07_callbacks.agent --use-llm + or: USE_LLM=1 python -m 
contributing.samples.graph_examples.07_callbacks.agent +""" + +import asyncio +import time + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph.callbacks import NodeCallbackContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# Shared dict to accumulate timing results from callbacks +_timings: dict[str, float] = {} + + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class FetchAgent(BaseAgent): + """Simulates a data fetch step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.02) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="Data fetched from source")] + ), + ) + + +class ProcessAgent(BaseAgent): + """Simulates a data processing step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.05) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="Data processed and transformed")] + ), + ) + + +class SaveAgent(BaseAgent): + """Simulates a data persistence step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.01) + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text="Data saved to storage")]), + ) + + +async def before_cb(ctx: NodeCallbackContext) -> None: + """Record start time in shared timings dict keyed by node name.""" + _timings[f"_start_{ctx.node.name}"] = time.perf_counter() + return None + + +async def after_cb(ctx: NodeCallbackContext) -> None: + """Compute elapsed time and store in shared timings dict.""" + start_key = f"_start_{ctx.node.name}" + start = _timings.get(start_key) + if start is 
not None: + elapsed_ms = (time.perf_counter() - start) * 1000.0 + _timings[ctx.node.name] = elapsed_ms + return None + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (fetch, process, save) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + fetch = create_llm_agent( + name="fetch", + instruction=( + "Respond with 'Data fetched from source' exactly. Respond quickly" + " without delays." + ), + ) + process = create_llm_agent( + name="process", + instruction=( + "Respond with 'Data processed and transformed' exactly. Respond" + " quickly without delays." + ), + ) + save = create_llm_agent( + name="save", + instruction=( + "Respond with 'Data saved to storage' exactly. Respond quickly" + " without delays." + ), + ) + + return fetch, process, save + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + fetch = FetchAgent(name="fetch") + process = ProcessAgent(name="process") + save = SaveAgent(name="save") + + return fetch, process, save + + +async def main(): + print("\n" + "=" * 60) + print("Example 7: Node Callbacks") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + fetch, process, save = create_agents() + + # Build graph with before/after callbacks + graph = ( + GraphAgent( + name="callback_workflow", + before_node_callback=before_cb, + after_node_callback=after_cb, + ) + .add_node("fetch", agent=fetch) + .add_node("process", agent=process) + .add_node("save", agent=save) + .add_edge("fetch", "process") + .add_edge("process", "save") + .set_start("fetch") + .set_end("save") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="callback_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("Executing workflow: fetch -> process -> save") + print("Callbacks will 
record timing for each node\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" [{event.author}] {part.text}") + + # Print timing results collected by callbacks + print("\n Node execution times (measured by callbacks):") + for node_name in ["fetch", "process", "save"]: + elapsed_ms = _timings.get(node_name) + if elapsed_ms is not None: + print(f" [{node_name}] {elapsed_ms:.1f}ms") + else: + print(f" [{node_name}] timing not recorded") + + print("\nExample complete!\n") + print(" before_node_callback: stores perf_counter start per node") + print(" after_node_callback: computes elapsed ms and stores result") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/08_rewind/__init__.py b/contributing/samples/graph_examples/08_rewind/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/08_rewind/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/08_rewind/agent.py b/contributing/samples/graph_examples/08_rewind/agent.py new file mode 100644 index 0000000000..e2f0206bcc --- /dev/null +++ b/contributing/samples/graph_examples/08_rewind/agent.py @@ -0,0 +1,181 @@ +"""Example 8: Rewind Integration + +Demonstrates: +- Invocation tracking per node +- Rewinding to specific node execution +- Re-execution after rewind +- State restoration + +Run modes: +- Default: python -m contributing.samples.graph_examples.08_rewind.agent +- LLM: python -m contributing.samples.graph_examples.08_rewind.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.08_rewind.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph 
import GraphAgent +from google.adk.agents.graph import rewind_to_node +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class CounterAgent(BaseAgent): + """Agent that tracks execution count.""" + + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) + self._count = 0 + + async def _run_async_impl(self, ctx): + self._count += 1 + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=f"✅ {self.name} executed (count: {self._count})" + ) + ] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (step1, step2, step3) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + step1 = create_llm_agent( + name="step1", + instruction=( + "Respond with 'step1 executed (count: X)' where X is the execution" + " count. Track this in your context." + ), + ) + step2 = create_llm_agent( + name="step2", + instruction=( + "Respond with 'step2 executed (count: X)' where X is the execution" + " count. Track this in your context." + ), + ) + step3 = create_llm_agent( + name="step3", + instruction=( + "Respond with 'step3 executed (count: X)' where X is the execution" + " count. Track this in your context." 
+ ), + ) + + return step1, step2, step3 + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + step1 = CounterAgent(name="step1") + step2 = CounterAgent(name="step2") + step3 = CounterAgent(name="step3") + + return step1, step2, step3 + + +async def main(): + print("\n" + "=" * 60) + print("Example 8: Rewind Integration") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + step1, step2, step3 = create_agents() + + # Build graph + graph = ( + GraphAgent(name="rewind_workflow") + .add_node("step1", agent=step1) + .add_node("step2", agent=step2) + .add_node("step3", agent=step3) + .add_edge("step1", "step2") + .add_edge("step2", "step3") + .set_start("step1") + .set_end("step3") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="rewind_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("🚀 First execution...\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Check invocations + session = await session_service.get_session( + app_name="rewind_demo", user_id="user1", session_id="session1" + ) + node_invocations = session.state.get("node_invocations", {}) + + print(f"\n📊 Invocation Tracking:") + for node_name, invocations in node_invocations.items(): + print(f" {node_name}: {len(invocations)} invocation(s)") + + # Rewind to step2 + print(f"\n⏪ Rewinding to 'step2'...") + await rewind_to_node( + graph, + session_service, + app_name="rewind_demo", + user_id="user1", + session_id="session1", + node_name="step2", + invocation_index=-1, # Last invocation + ) + + print(" ✅ Rewind successful! 
State restored to before step2") + + # Re-execute from rewind point + print("\n🚀 Re-execution after rewind...\n") + + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\n✅ Example complete!") + print(" Note: step1 count stays at 1, step2 & step3 executed again\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/15_enhanced_routing/__init__.py b/contributing/samples/graph_examples/15_enhanced_routing/__init__.py new file mode 100644 index 0000000000..491d6cc8cf --- /dev/null +++ b/contributing/samples/graph_examples/15_enhanced_routing/__init__.py @@ -0,0 +1 @@ +"""Example 15: Enhanced Routing (Priority, Weight, Fallback).""" diff --git a/contributing/samples/graph_examples/15_enhanced_routing/agent.py b/contributing/samples/graph_examples/15_enhanced_routing/agent.py new file mode 100644 index 0000000000..6dc04718ca --- /dev/null +++ b/contributing/samples/graph_examples/15_enhanced_routing/agent.py @@ -0,0 +1,411 @@ +"""Example 15: Enhanced Routing (Priority, Weight, Fallback) + +Demonstrates: +- Priority-based routing (higher priority evaluated first) +- Weighted random selection (probabilistic routing) +- Fallback edges (priority=0 always matches) + +Run modes: +- Default: python -m contributing.samples.graph_examples.15_enhanced_routing.agent +- LLM: python -m contributing.samples.graph_examples.15_enhanced_routing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.15_enhanced_routing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.events.event import Event +from google.adk.runners import Runner +from 
google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class SimpleAgent(BaseAgent): + """Agent that outputs a message.""" + + def __init__(self, name: str, message: str, **kwargs): + super().__init__(name=name, **kwargs) + self._message = message + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._message)]), + ) + + +class ScoreAgent(BaseAgent): + """Agent that sets a risk score.""" + + def __init__(self, name: str, score: float, **kwargs): + super().__init__(name=name, **kwargs) + self._score = score + + async def _run_async_impl(self, ctx): + ctx.session.state["risk_score"] = self._score + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Risk score: {self._score}")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents_priority(score: float): + """Create agents for priority routing example. + + Returns: + tuple: (analyze, critical, warning, normal) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + analyze = create_llm_agent( + name="analyze", + instruction=f"Respond with 'Risk score: {score}' exactly.", + ) + critical = create_llm_agent( + name="critical", + instruction=( + "Respond with 'CRITICAL: Immediate action required' exactly." 
+ ), + ) + warning = create_llm_agent( + name="warning", + instruction="Respond with 'WARNING: Review recommended' exactly.", + ) + normal = create_llm_agent( + name="normal", + instruction="Respond with 'NORMAL: No action needed' exactly.", + ) + + return analyze, critical, warning, normal + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + analyze = ScoreAgent(name="analyze", score=score) + critical = SimpleAgent( + name="critical", message="🚨 CRITICAL: Immediate action required" + ) + warning = SimpleAgent( + name="warning", message="⚠️ WARNING: Review recommended" + ) + normal = SimpleAgent(name="normal", message="✅ NORMAL: No action needed") + + return analyze, critical, warning, normal + + +def create_agents_weighted(): + """Create agents for weighted routing example. + + Returns: + tuple: (start, server_a, server_b, server_c) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + start = create_llm_agent( + name="start", + instruction="Respond with 'Starting load balancer...' exactly.", + ) + server_a = create_llm_agent( + name="server_a", + instruction="Respond with ' → Routed to Server A' exactly.", + ) + server_b = create_llm_agent( + name="server_b", + instruction="Respond with ' → Routed to Server B' exactly.", + ) + server_c = create_llm_agent( + name="server_c", + instruction="Respond with ' → Routed to Server C' exactly.", + ) + + return start, server_a, server_b, server_c + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + start = SimpleAgent(name="start", message="Starting load balancer...") + server_a = SimpleAgent(name="server_a", message=" → Routed to Server A") + server_b = SimpleAgent(name="server_b", message=" → Routed to Server B") + server_c = SimpleAgent(name="server_c", message=" → Routed to Server C") + + return start, server_a, server_b, server_c + + +def create_agents_fallback(score: float): + """Create agents for fallback routing example. 
+ + Returns: + tuple: (validate, premium, standard, fallback) agents + """ + if use_llm_mode(): + print("🤖 Using LLM-powered agents (gemini-2.5-flash)\n") + + validate = create_llm_agent( + name="validate", + instruction=f"Respond with 'Risk score: {score}' exactly.", + ) + premium = create_llm_agent( + name="premium", + instruction="Respond with 'Premium path (VIP users)' exactly.", + ) + standard = create_llm_agent( + name="standard", + instruction="Respond with 'Standard path (regular users)' exactly.", + ) + fallback = create_llm_agent( + name="fallback", + instruction="Respond with 'Fallback path (default handler)' exactly.", + ) + + return validate, premium, standard, fallback + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + validate = ScoreAgent(name="validate", score=score) + premium = SimpleAgent(name="premium", message="🌟 Premium path (VIP users)") + standard = SimpleAgent( + name="standard", message="📦 Standard path (regular users)" + ) + fallback = SimpleAgent( + name="fallback", message="🔒 Fallback path (default handler)" + ) + + return validate, premium, standard, fallback + + +async def main(): + print("\n" + "=" * 60) + print("Example 15: Enhanced Routing") + print("=" * 60 + "\n") + + # ===== Example 1: Priority-based Routing ===== + print("📊 Example 1: Priority-based Routing\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + analyze, critical, warning, normal = create_agents_priority(0.85) + + graph1 = ( + GraphAgent(name="priority_routing") + .add_node("analyze", agent=analyze) + .add_node("critical", agent=critical) + .add_node("warning", agent=warning) + .add_node("normal", agent=normal) + ) + + # Set output mapper to persist risk_score in state + def store_score(output, state): + new_state = GraphState(data=state.data.copy()) + new_state.data["risk_score"] = 0.85 # Score from analyze agent + return new_state + + graph1.nodes["analyze"].output_mapper = store_score + + # Priority-based routing: higher 
priority evaluated first + graph1.nodes["analyze"].edges = [ + EdgeCondition( + target_node="critical", + condition=lambda s: s.data.get("risk_score", 0) > 0.9, + priority=10, # Highest priority + ), + EdgeCondition( + target_node="warning", + condition=lambda s: s.data.get("risk_score", 0) > 0.7, + priority=5, # Medium priority - THIS WILL MATCH + ), + EdgeCondition( + target_node="normal", + priority=0, # Fallback (priority=0 always matches if no other matched) + ), + ] + + graph1.set_start("analyze") + graph1.set_end("critical") + graph1.set_end("warning") + graph1.set_end("normal") + + session_service = InMemorySessionService() + runner = Runner( + app_name="routing_demo", + agent=graph1, + session_service=session_service, + auto_create_session=True, + ) + + async for event in runner.run_async( + user_id="user1", + session_id="session1", + new_message=types.Content(parts=[types.Part(text="Analyze")]), + ): + if event.content and event.content.parts and event.content.parts[0].text: + print(f" {event.content.parts[0].text}") + + print("\n 💡 Score was 0.85 → matched 'warning' (priority=5)") + print(" 💡 'critical' didn't match (0.85 < 0.9)") + print(" 💡 'normal' fallback not needed (higher priority matched)\n") + + # ===== Example 2: Weighted Random Routing ===== + print("🎲 Example 2: Weighted Random Routing\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + start, server_a, server_b, server_c = create_agents_weighted() + + graph2 = ( + GraphAgent(name="weighted_routing") + .add_node("start", agent=start) + .add_node("server_a", agent=server_a) + .add_node("server_b", agent=server_b) + .add_node("server_c", agent=server_c) + ) + + # Weighted routing: all at same priority, different weights + graph2.nodes["start"].edges = [ + EdgeCondition( + target_node="server_a", + condition=lambda s: True, # All match + priority=1, # Same priority + weight=0.5, # 50% probability + ), + EdgeCondition( + target_node="server_b", + condition=lambda s: True, + 
priority=1, # Same priority + weight=0.3, # 30% probability + ), + EdgeCondition( + target_node="server_c", + condition=lambda s: True, + priority=1, # Same priority + weight=0.2, # 20% probability + ), + ] + + graph2.set_start("start") + graph2.set_end("server_a") + graph2.set_end("server_b") + graph2.set_end("server_c") + + runner2 = Runner( + app_name="weighted_demo", + agent=graph2, + session_service=session_service, + auto_create_session=True, + ) + + # Run multiple times to show distribution + counts = {"server_a": 0, "server_b": 0, "server_c": 0} + trials = 20 + + print(f" Running {trials} trials with weights (A:50%, B:30%, C:20%):\n") + + for i in range(trials): + async for event in runner2.run_async( + user_id="user1", + session_id=f"session_weighted_{i}", + new_message=types.Content(parts=[types.Part(text="Route")]), + ): + if event.content and event.content.parts and event.author in counts: + text = event.content.parts[0].text + counts[event.author] += 1 + print(f" Trial {i+1:2d}: {text}") + + print(f"\n 📊 Distribution after {trials} trials:") + print( + f" Server A: {counts['server_a']:2d}/{trials}" + f" ({counts['server_a']/trials*100:.0f}%)" + ) + print( + f" Server B: {counts['server_b']:2d}/{trials}" + f" ({counts['server_b']/trials*100:.0f}%)" + ) + print( + f" Server C: {counts['server_c']:2d}/{trials}" + f" ({counts['server_c']/trials*100:.0f}%)\n" + ) + + # ===== Example 3: Fallback Edge ===== + print("🛡️ Example 3: Fallback Edge (priority=0)\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + validate, premium, standard, fallback = create_agents_fallback(0.5) + + graph3 = ( + GraphAgent(name="fallback_routing") + .add_node("validate", agent=validate) + .add_node("premium", agent=premium) + .add_node("standard", agent=standard) + .add_node("fallback", agent=fallback) + ) + + def store_score_fallback(output, state): + new_state = GraphState(data=state.data.copy()) + new_state.data["risk_score"] = 0.5 + # Don't set is_vip or 
is_standard - will fall through to fallback + return new_state + + graph3.nodes["validate"].output_mapper = store_score_fallback + + graph3.nodes["validate"].edges = [ + EdgeCondition( + target_node="premium", + condition=lambda s: s.data.get("is_vip", False), + priority=10, # High priority - won't match + ), + EdgeCondition( + target_node="standard", + condition=lambda s: s.data.get("is_standard", False), + priority=5, # Medium priority - won't match + ), + EdgeCondition( + target_node="fallback", + priority=0, # FALLBACK - always matches if reached + ), + ] + + graph3.set_start("validate") + graph3.set_end("premium") + graph3.set_end("standard") + graph3.set_end("fallback") + + runner3 = Runner( + app_name="fallback_demo", + agent=graph3, + session_service=session_service, + auto_create_session=True, + ) + + async for event in runner3.run_async( + user_id="user1", + session_id="session_fallback", + new_message=types.Content(parts=[types.Part(text="Validate")]), + ): + if event.content and event.content.parts and event.content.parts[0].text: + print(f" {event.content.parts[0].text}") + + print("\n 💡 No is_vip or is_standard flag set") + print(" 💡 All higher priority edges failed to match") + print(" 💡 Fallback (priority=0) caught it!\n") + + print("=" * 60) + print("✅ Enhanced Routing Complete!") + print("=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/README.md b/contributing/samples/graph_examples/README.md new file mode 100644 index 0000000000..7eb7e7ee89 --- /dev/null +++ b/contributing/samples/graph_examples/README.md @@ -0,0 +1,594 @@ +# GraphAgent Examples - All Features + +Comprehensive collection of small, focused examples demonstrating every GraphAgent feature. 
 + +--- + +## API Overview + +GraphAgent supports both **explicit** and **convenience** APIs for building workflows: + +### Convenience API (Recommended) + +```python +# Fluent chaining pattern +graph = ( + GraphAgent(name="workflow") + .add_node("validate", agent=validator) + .add_node("process", agent=processor) + .add_edge("validate", "process") # Positional args + .add_edge("process", "output", priority=10) # With priority + .set_start("validate") + .set_end("output") +) + +# Or step-by-step +graph = GraphAgent(name="workflow") +graph.add_node("step1", agent=agent1) +graph.add_node("step2", function=custom_function) +graph.add_edge("step1", "step2") +``` + +### Explicit API (Also Supported) + +```python +# Using GraphNode and EdgeCondition explicitly +graph.add_node(GraphNode(name="step1", agent=agent1)) +graph.add_edge("source", EdgeCondition( + target_node="target", + priority=10, + condition=lambda s: s.data.get("valid") +)) +``` + +### Key API Features + +- **add_node()**: Convenience syntax `add_node("name", agent=...)` or explicit `add_node(GraphNode(...))` +- **add_edge()**: Positional `add_edge("source", "target")` or keyword `add_edge(source_node="a", target_node="b")` +- **EdgeCondition**: Support for `add_edge("src", EdgeCondition(target_node="tgt", condition=...))` +- **checkpoint_service**: Optional parameter, not required +- **Fluent chaining**: All builder methods return self for method chaining + +--- + +## Quick Start + +Run any example with: +```bash +cd /path/to/adk-python +source venv/bin/activate +python -m contributing.samples.graph_examples.<example_name>.agent +``` + +### Deterministic vs LLM Modes + +Examples support two execution modes: + +**Deterministic Mode (default)** - Uses BaseAgent subclasses with deterministic outputs. No API keys required. +```bash +python -m contributing.samples.graph_examples.01_basic.agent +``` + +**LLM Mode (optional)** - Uses real Gemini LLM endpoints. Requires API credentials. 
+```bash +# Via command-line flag +python -m contributing.samples.graph_examples.01_basic.agent --use-llm + +# Via environment variable +USE_LLM=1 python -m contributing.samples.graph_examples.01_basic.agent +``` + +**Note:** LLM mode is only available for simple examples (01_basic, 02_conditional_routing, etc.). Examples that require precise state management (05_interrupts, parallel execution) use deterministic agents to demonstrate graph mechanics reliably. + +--- + +## Examples Overview + +### 🟢 Core Features + +#### **01_basic** - Basic GraphAgent Workflow +Simple directed graph with nodes and edges. +```bash +python -m contributing.samples.graph_examples.01_basic.agent +``` +**Demonstrates:** +- Creating a graph with fluent API +- Adding nodes with convenience syntax: `add_node("name", agent=...)` +- Adding edges with positional syntax: `add_edge("source", "target")` +- Executing workflow + +--- + +#### **02_conditional_routing** - Conditional Routing +State-based routing decisions. +```bash +python -m contributing.samples.graph_examples.02_conditional_routing.agent +``` +**Demonstrates:** +- Conditional edges with `condition` parameter +- EdgeCondition support: `add_edge("src", EdgeCondition(target_node="tgt", condition=...))` +- State-based decisions +- Multiple routing paths +- Priority-based routing with `priority` parameter + +--- + +#### **03_cyclic_execution** - Cyclic Graph Execution +Loops and iteration control. +```bash +python -m contributing.samples.graph_examples.03_cyclic_execution.agent +``` +**Demonstrates:** +- Cyclic graphs with back-edges +- Writing routing signals via state_delta (ADK pattern) +- GraphState.data auto-sync from session.state +- max_iterations guard + +**Key Pattern:** +Agents write routing signals via `Event.actions.state_delta`. GraphAgent automatically syncs `session.state` into `GraphState.data` before edge evaluation — no `output_mapper` needed. 
+ +--- + +#### **04_checkpointing** - Checkpointing & Resume +Automatic state persistence. +```bash +python -m contributing.samples.graph_examples.04_checkpointing.agent +``` +**Demonstrates:** +- Automatic checkpointing with optional `checkpoint_service` parameter +- State persistence +- Checkpoint metadata +- Execution path tracking +- Resume from checkpoint capability + +--- + +#### **05_interrupts_basic** - Basic Interrupts +Human-in-the-loop interrupts. +```bash +python -m contributing.samples.graph_examples.05_interrupts_basic.agent +``` +**Demonstrates:** +- All 8 interrupt actions: continue, rerun, pause, go_back, skip, defer, update_state, change_condition +- Concurrent injection via asyncio.create_task +- SlowNode (2 sub-steps × 1s) timing pattern +- AFTER interrupt check behavior (queued during execution, consumed after node completes) + +--- + +#### **06_interrupts_reasoning** - Interrupt with Reasoning +Condition-based action selection. +```bash +python -m contributing.samples.graph_examples.06_interrupts_reasoning.agent +``` +**Demonstrates:** +- Interrupt with condition evaluation +- Automated action selection based on state +- Draft-review workflow with interrupt points + +--- + +#### **07_callbacks** - Node Callbacks +Lifecycle hooks for nodes. +```bash +python -m contributing.samples.graph_examples.07_callbacks.agent +``` +**Demonstrates:** +- `before_node_callback` - executed before node runs +- `after_node_callback` - executed after node completes +- Timing and performance tracking +- Telemetry integration patterns + +--- + +#### **08_rewind** - Rewind Integration +Time-travel debugging. +```bash +python -m contributing.samples.graph_examples.08_rewind.agent +``` +**Demonstrates:** +- Invocation tracking +- Rewinding to specific node +- State restoration +- Re-execution after rewind + +--- + +### ⚡ Parallel Execution + +#### **09_parallel_wait_all** - Parallel Execution (WAIT_ALL) +Concurrent node execution, wait for all. 
+```bash +python -m contributing.samples.graph_examples.09_parallel_wait_all.agent +``` +**Demonstrates:** +- Parallel node execution +- WAIT_ALL join strategy +- Speedup vs sequential (2.25x) +- Event streaming from parallel nodes + +**Output (example):** +``` +✅ Fetched data from products_db +✅ Fetched data from users_db +✅ Fetched data from orders_db + +All three fetched in parallel. +``` + +--- + +#### **10_parallel_wait_any** - Parallel Execution (WAIT_ANY) +Race condition, first-to-complete wins. +```bash +python -m contributing.samples.graph_examples.10_parallel_wait_any.agent +``` +**Demonstrates:** +- Racing multiple data sources +- WAIT_ANY join strategy +- Automatic cancellation of slower nodes +- Cache-DB-API fallback pattern + +**Output (example):** +``` +✅ Data from CACHE + +Winner: Cache +Cancelled: Database, API +``` + +--- + +#### **11_parallel_wait_n** - Parallel Execution (WAIT_N) +Continue after N of M complete. +```bash +python -m contributing.samples.graph_examples.11_parallel_wait_n.agent +``` +**Demonstrates:** +- WAIT_N join strategy (e.g., 2 out of 3) +- ML model ensemble pattern +- Partial completion workflows +- Automatic cancellation of remaining nodes + +--- + +#### **12_parallel_checkpointing** - Parallel + Checkpointing +State persistence across parallel execution. +```bash +python -m contributing.samples.graph_examples.12_parallel_checkpointing.agent +``` +**Demonstrates:** +- Parallel execution with automatic checkpointing +- State recovery after interruption +- Checkpoint metadata tracking +- Resume from mid-parallel execution + +--- + +#### **13_parallel_interrupts** - Parallel + Interrupts +Interrupt handling inside parallel branches. 
+```bash +python -m contributing.samples.graph_examples.13_parallel_interrupts.agent +``` +**Demonstrates:** +- Interrupts within parallel node execution +- Branch-specific interrupt handling +- Pause/resume in parallel context +- Interrupt isolation across branches + +--- + +### 🔗 Combined Features + +#### **14_parallel_rewind** - Parallel Execution + Rewind +Rewind works with parallel workflows! +```bash +python -m contributing.samples.graph_examples.14_parallel_rewind.agent +``` +**Demonstrates:** +- Parallel + Rewind integration +- Invocation tracking in parallel groups +- Re-execution of entire parallel group +- State consistency across rewind + +**Key Insight:** +- Rewind to parallel node → entire parallel group re-executes +- All branches get new invocations +- Deterministic re-execution + +--- + +#### **15_enhanced_routing** - Enhanced Routing Patterns +Priority, weighted, and fallback routing. +```bash +python -m contributing.samples.graph_examples.15_enhanced_routing.agent +``` +**Demonstrates:** +- Priority-based routing (higher priority evaluated first) +- Weighted random selection (probabilistic routing) +- Fallback edges (priority=0 always matches) +- Three routing patterns in one example + +--- + +## Feature Matrix + +| Example | Parallel | Rewind | Checkpoints | Interrupts | Callbacks | Cyclic | Routing | +|---------|----------|--------|-------------|------------|-----------|--------|---------| +| 01_basic | - | - | - | - | - | - | Simple | +| 02_conditional_routing | - | - | - | - | - | - | Conditional | +| 03_cyclic_execution | - | - | - | - | - | ✅ | Conditional | +| 04_checkpointing | - | - | ✅ | - | - | - | Simple | +| 05_interrupts_basic | - | - | - | ✅ | - | - | Simple | +| 06_interrupts_reasoning | - | - | - | ✅ | - | - | Conditional | +| 07_callbacks | - | - | - | - | ✅ | - | Simple | +| 08_rewind | - | ✅ | - | - | - | - | Simple | +| 09_parallel_wait_all | ✅ | - | - | - | - | - | Parallel | +| 10_parallel_wait_any | ✅ | - | - | - | - | - | 
Parallel | +| 11_parallel_wait_n | ✅ | - | - | - | - | - | Parallel | +| 12_parallel_checkpointing | ✅ | - | ✅ | - | - | - | Parallel | +| 13_parallel_interrupts | ✅ | - | - | ✅ | - | - | Parallel | +| 14_parallel_rewind | ✅ | ✅ | - | - | - | - | Parallel | +| 15_enhanced_routing | - | - | - | - | - | - | Advanced | + +--- + +## Architectural Insights + +### Parallel Execution Architecture + +``` +┌─────────────┐ +│ validate │ +└──────┬──────┘ + │ + ├──────────────┬──────────────┐ + │ │ │ + ▼ ▼ ▼ +┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ fetch_A │ │ fetch_B │ │ fetch_C │ +│ (isolated) │ │ (isolated) │ │ (isolated) │ +└──────┬───────┘ └──────┬───────┘ └──────┬───────┘ + │ │ │ + └──────────────┴──────────────┘ + │ + ▼ + ┌──────────────┐ + │ aggregate │ + │(merged state)│ + └──────────────┘ +``` + +**Key Points:** +- Each branch has **isolated state** during execution +- No race conditions possible +- State **merged** after all branches complete +- Events **streamed** as branches complete (FIRST_COMPLETED) + +--- + +### Rewind with Parallel Execution + +``` +1. Initial Execution: + validate → (fetch_A || fetch_B || fetch_C) → aggregate + + Invocations created: + - validate: ["inv_1"] + - fetch_A: ["inv_2"] + - fetch_B: ["inv_3"] + - fetch_C: ["inv_4"] + - aggregate: ["inv_5"] + +2. Rewind to fetch_A (inv_2): + Session state restored to BEFORE inv_2 + +3. 
Re-execution: + (fetch_A || fetch_B || fetch_C) → aggregate + + New invocations: + - fetch_A: ["inv_2", "inv_6"] + - fetch_B: ["inv_3", "inv_7"] + - fetch_C: ["inv_4", "inv_8"] + - aggregate: ["inv_5", "inv_9"] +``` + +**Key Points:** +- Rewind works seamlessly with parallel groups +- Entire parallel group re-executes +- New invocations created on re-execution +- Deterministic behavior guaranteed + +--- + +### State Isolation + +**Problem:** Multiple nodes modifying same state → race conditions + +**Solution:** Isolated state copies per branch + +```python +# During parallel execution +for node in parallel_group.nodes: + # Each branch gets ISOLATED copy + branch_state = state.copy() + + # Modify branch state + execute_node(node, branch_state) + +# After all complete +merged_state = merge(all_branch_states) +``` + +**Benefits:** +- No race conditions +- Deterministic results +- Safe concurrent execution + +--- + +## Performance Comparison + +### Sequential vs Parallel (WAIT_ALL) + +**Scenario:** Fetch from 3 sources (100ms, 150ms, 200ms each) + +**Sequential:** +``` +Total time = 100 + 150 + 200 = 450ms +``` + +**Parallel (WAIT_ALL):** +``` +Total time = max(100, 150, 200) = 200ms +Speedup: 450ms / 200ms = 2.25x +``` + +**Parallel (WAIT_ANY):** +``` +Total time = min(100, 150, 200) = 100ms +Speedup: 450ms / 100ms = 4.5x +``` + +--- + +## Common Patterns + +### 1. Data Pipeline (WAIT_ALL) +Fetch data from multiple sources concurrently. +```python +ParallelNodeGroup( + nodes=["fetch_users", "fetch_products", "fetch_orders"], + join_strategy=JoinStrategy.WAIT_ALL +) +``` + +### 2. Cache-DB-API Fallback (WAIT_ANY) +Race multiple data sources, use fastest. +```python +ParallelNodeGroup( + nodes=["from_cache", "from_db", "from_api"], + join_strategy=JoinStrategy.WAIT_ANY +) +``` + +### 3. ML Model Ensemble (WAIT_N) +Run multiple models, proceed when N complete. 
+```python +ParallelNodeGroup( + nodes=["model1", "model2", "model3"], + join_strategy=JoinStrategy.WAIT_N, + wait_n=2 # 2 out of 3 +) +``` + +### 4. Interrupt-Driven Review +Human review after key nodes. +```python +InterruptConfig( + mode=InterruptMode.AFTER, + nodes=["draft", "review"] +) +``` + +### 5. Checkpoint-Resume Workflow +Long-running workflows with state persistence. +```python +GraphAgent( + name="workflow", + checkpoint_service=checkpoint_service # Optional parameter +) +``` + +--- + +## Error Handling + +### Parallel Error Policies + +#### FAIL_FAST (default) +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.FAIL_FAST +) +# One error → cancel all → raise exception +``` + +#### CONTINUE +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.CONTINUE +) +# One error → continue others → log error +``` + +#### COLLECT +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.COLLECT +) +# All errors → collect all → raise at end +``` + +--- + +## Testing + +Run tests: +```bash +pytest tests/unittests/agents/test_graph_*.py -v +``` + +--- + +## Next Steps + +1. **Try the examples** - Run each one to see features in action +2. **Modify examples** - Change parameters, add nodes, experiment +3. **Combine features** - Mix parallel + rewind + checkpoints +4. 
**Build your workflow** - Use patterns for your use case + +--- + +## Related Samples: graph_agent_* (Complex Real-World Examples) + +In addition to the numbered graph_examples, there are advanced samples at `contributing/samples/graph_agent_*`: + +| Sample | Description | Pattern | +|--------|-------------|---------| +| **graph_agent_basic** | Basic research workflow | LLM-powered, `agents.py` + `agent.py` | +| **graph_agent_advanced** | Complex research paper workflow | Multi-phase with review loop | +| **graph_agent_react_pattern** | ReAct pattern (Reason + Act) | Thought-action-observation cycle | +| **graph_agent_multi_agent** | Multiple specialized agents | Delegation and collaboration | +| **graph_agent_dynamic_queue** | Dynamic node queueing | Runtime graph modification | +| **graph_agent_parallel_features** | Parallel feature demonstrations | Showcases parallel capabilities | +| **graph_agent_pattern_dynamic_node** | Dynamic node creation | Runtime node injection | +| **graph_agent_pattern_nested_graph** | Nested GraphAgents | Graph as node pattern | +| **graph_agent_pattern_parallel_group** | Parallel group pattern | Advanced parallel workflows | + +**Key Differences from graph_examples**: +- **LLM-Required**: Use LlmAgent, require API credentials (no deterministic mode) +- **Structure**: Single `agent.py` with inline agent definitions +- **Complexity**: Multi-phase workflows, real-world use cases +- **Purpose**: Demonstrate production patterns, not individual features +- **Note**: Some samples have `__init__.py` for package structure + +**Run Example**: +```bash +cd /path/to/adk-python +source venv/bin/activate +python -m contributing.samples.graph_agent_advanced.agent +``` + +--- + +## Support + +Questions? 
+Run any example:
+    python -m contributing.samples.graph_examples.<example_name>.agent
+        Requires valid API credentials configured via:
+        - GOOGLE_API_KEY (or GEMINI_API_KEY) environment variable, or
+        - gcloud auth application-default login
+ +Usage: + # Default mode (deterministic agents) + python run_example.py 01_basic + + # LLM mode + python run_example.py 01_basic --use-llm + + # With trace logs + python run_example.py 01_basic --trace + + # Both + python run_example.py 01_basic --use-llm --trace + + # List all examples + python run_example.py --list +""" + +import argparse +import importlib +import logging +from pathlib import Path +import sys + + +def setup_logging(trace: bool = False): + """Configure logging based on trace flag.""" + level = logging.DEBUG if trace else logging.INFO + format_str = ( + "%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s" + if trace + else "%(message)s" + ) + + logging.basicConfig( + level=level, + format=format_str, + datefmt="%H:%M:%S", + ) + + # Enable ADK trace logging + if trace: + logging.getLogger("google_adk").setLevel(logging.DEBUG) + logging.getLogger("google.adk").setLevel(logging.DEBUG) + + +def list_examples(): + """List all available examples.""" + examples_dir = Path(__file__).parent + examples = sorted([ + d.name + for d in examples_dir.iterdir() + if d.is_dir() and (d / "agent.py").exists() and not d.name.startswith("_") + ]) + + print("\n📚 Available graph_examples:\n") + for ex in examples: + agent_file = examples_dir / ex / "agent.py" + # Read first docstring line + with open(agent_file) as f: + lines = f.readlines() + desc = "" + for line in lines: + if line.strip().startswith('"""'): + desc = line.strip('"""').strip() + break + print(f" {ex:30s} - {desc}") + + print("\n") + + +def run_example(example_name: str, use_llm: bool = False, trace: bool = False): + """Run a specific example.""" + setup_logging(trace=trace) + + # Set USE_LLM env var if needed + if use_llm: + import os + + os.environ["USE_LLM"] = "1" + + # Verify example exists + example_dir = Path(__file__).parent / example_name + if not example_dir.exists() or not (example_dir / "agent.py").exists(): + print(f"❌ Example '{example_name}' not found") + print("\nRun with 
--list to see available examples") + sys.exit(1) + + # Run via subprocess to handle module names starting with numbers + import os + import subprocess + + env = os.environ.copy() + if use_llm: + env["USE_LLM"] = "1" + + # Run from adk-python root + adk_root = Path(__file__).parent.parent.parent.parent + module_path = f"contributing.samples.graph_examples.{example_name}.agent" + + print(f"\n{'='*70}") + print(f"Running: {example_name}") + print(f"Mode: {'🤖 LLM' if use_llm else '🎭 Deterministic'}") + print(f"Trace: {'✓ Enabled' if trace else '✗ Disabled'}") + print(f"{'='*70}\n") + + try: + result = subprocess.run( + [sys.executable, "-m", module_path], + cwd=str(adk_root), + env=env, + capture_output=False, + text=True, + ) + sys.exit(result.returncode) + except Exception as e: + print(f"\n❌ Error running example: {e}") + if trace: + import traceback + + traceback.print_exc() + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Run graph_examples with optional trace logging and LLM mode" + ) + parser.add_argument( + "example", nargs="?", help="Example name (e.g., 01_basic)" + ) + parser.add_argument( + "--use-llm", + action="store_true", + help="Use real LLM endpoints instead of deterministic agents", + ) + parser.add_argument( + "--trace", action="store_true", help="Enable detailed trace logging" + ) + parser.add_argument("--list", action="store_true", help="List all examples") + + args = parser.parse_args() + + if args.list: + list_examples() + return + + if not args.example: + parser.print_help() + print("\nRun with --list to see available examples") + sys.exit(1) + + run_example(args.example, use_llm=args.use_llm, trace=args.trace) + + +if __name__ == "__main__": + main() diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index fbd1808f3f..4735379bd0 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -14,6 +14,11 @@ from .base_agent import BaseAgent from .context 
import Context +from .graph import END +from .graph import GraphAgent +from .graph import GraphNode +from .graph import GraphState +from .graph import START from .invocation_context import InvocationContext from .live_request_queue import LiveRequest from .live_request_queue import LiveRequestQueue @@ -29,6 +34,11 @@ 'Agent', 'BaseAgent', 'Context', + 'GraphAgent', + 'GraphNode', + 'GraphState', + 'START', + 'END', 'LlmAgent', 'LoopAgent', 'McpInstructionProvider', diff --git a/src/google/adk/agents/graph/__init__.py b/src/google/adk/agents/graph/__init__.py new file mode 100644 index 0000000000..5ccc0f0c4d --- /dev/null +++ b/src/google/adk/agents/graph/__init__.py @@ -0,0 +1,79 @@ +"""Graph-based agent components. + +This module contains components for graph-based workflow orchestration: +- GraphAgent: Main graph workflow agent +- GraphState: Domain data container +- GraphAgentState: Execution tracking state (BaseAgentState) +- GraphNode: Node wrapper for agents and functions +- EdgeCondition: Conditional routing between nodes +- StateReducer: State merge strategies +- GraphEvent: Typed events for streaming +- GraphEventType: Event type enumeration +- GraphStreamMode: Stream mode enumeration +- NodeCallbackContext: Context for node lifecycle callbacks +- EdgeCallbackContext: Context for edge condition callbacks +- NodeCallback: Type for node lifecycle callbacks +- EdgeCallback: Type for edge condition callbacks +""" + +from __future__ import annotations + +from .callbacks import create_nested_observability_callback +from .callbacks import EdgeCallback +from .callbacks import EdgeCallbackContext +from .callbacks import NodeCallback +from .callbacks import NodeCallbackContext +from .evaluation_metrics import graph_path_match +from .evaluation_metrics import node_execution_count +from .evaluation_metrics import state_contains_keys +from .graph_agent import GraphAgent +from .graph_agent_config import GraphAgentConfig +from .graph_agent_config import GraphEdgeConfig 
+from .graph_agent_config import GraphNodeConfig +from .graph_agent_state import GraphAgentState +from .graph_edge import EdgeCondition +from .graph_events import GraphEvent +from .graph_events import GraphEventType +from .graph_events import GraphStreamMode +from .graph_export import export_execution_timeline +from .graph_export import export_graph_structure +from .graph_export import export_graph_with_execution +from .graph_node import GraphNode +from .graph_rewind import rewind_to_node +from .graph_state import GraphState +from .graph_state import PydanticJSONEncoder +from .graph_state import StateReducer + +# Sentinel constants for graph boundaries +START = "__start__" +END = "__end__" + +__all__ = [ + "GraphAgent", + "GraphAgentConfig", + "GraphAgentState", + "GraphNodeConfig", + "GraphEdgeConfig", + "GraphState", + "GraphNode", + "EdgeCondition", + "StateReducer", + "PydanticJSONEncoder", + "GraphEvent", + "GraphEventType", + "GraphStreamMode", + "NodeCallbackContext", + "EdgeCallbackContext", + "NodeCallback", + "EdgeCallback", + "create_nested_observability_callback", + "graph_path_match", + "state_contains_keys", + "node_execution_count", + "export_graph_structure", + "export_graph_with_execution", + "export_execution_timeline", + "rewind_to_node", + "START", + "END", +] diff --git a/src/google/adk/agents/graph/callbacks.py b/src/google/adk/agents/graph/callbacks.py new file mode 100644 index 0000000000..6660932b36 --- /dev/null +++ b/src/google/adk/agents/graph/callbacks.py @@ -0,0 +1,218 @@ +"""Callback infrastructure for graph observability and extensibility. 
+Example:
+    ```python
+    from google.adk.agents.graph import GraphAgent
+    from google.adk.agents.graph.callbacks import NodeCallbackContext
+    from google.adk.events.event import Event
+    from google.adk.events.event_actions import EventActions
+    from google import genai
+
+    async def my_observability(ctx: NodeCallbackContext):
+        '''Custom observability callback.'''
+        # Automatic Pydantic serialization
+        return Event(
+            author="observability",
+            content=genai.types.Content(parts=[
+                genai.types.Part(text=f"→ Executing: {ctx.node.name}"),
+                genai.types.Part(text=f"State:\n{ctx.state.data_to_json()}"),
+            ]),
+            actions=EventActions(
+                escalate=False,
+                state_delta={
+                    "observability_node": ctx.node.name,
+                    "observability_iteration": ctx.iteration,
+                }
+            )
+        )
+  iteration: int
+  invocation_context: Any  # InvocationContext (ADK ..invocation_context; Any avoids circular import)
+  metadata: Dict[str, Any]
+            return Event(
+                author="edge_callback",
+                content=Content(parts=[
+                    Part(text=f"Routing: {ctx.source_node} → {ctx.target_node}")
+                ])
+            )
genai.types.Part(text=f"Iteration: {ctx.iteration}"), + ] + ), + actions=EventActions( + escalate=False, + state_delta={ + "observability_hierarchy": hierarchy, + "observability_level": len(agent_path), + "observability_node": ctx.node.name, + "observability_node_type": node_type, + "observability_debug": debug_info, + }, + ), + ) + + return nested_callback diff --git a/src/google/adk/agents/graph/evaluation_metrics.py b/src/google/adk/agents/graph/evaluation_metrics.py new file mode 100644 index 0000000000..ffa9b8c6d5 --- /dev/null +++ b/src/google/adk/agents/graph/evaluation_metrics.py @@ -0,0 +1,382 @@ +"""Custom evaluation metrics for GraphAgent workflows. + +These metrics enable evaluating graph execution paths, state transitions, +and workflow behavior in ADK's evaluation framework. + +Example usage: + ```python + from google.adk.evaluation import EvalMetric + from google.adk.agents.graph.evaluation_metrics import ( + graph_path_match, + state_contains_keys, + node_execution_count, + ) + + # In eval config: + metrics = [ + EvalMetric( + name="graph_path", + custom_function_path="google.adk.agents.graph.evaluation_metrics.graph_path_match", + ), + ] + ``` +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from ...evaluation.eval_case import ConversationScenario +from ...evaluation.eval_case import Invocation +from ...evaluation.eval_metrics import EvalMetric +from ...evaluation.eval_metrics import EvalStatus +from ...evaluation.evaluator import EvaluationResult +from ...evaluation.evaluator import PerInvocationResult + + +def graph_path_match( + eval_metric: EvalMetric, + actual_invocations: List[Invocation], + expected_invocations: Optional[List[Invocation]], + conversation_scenario: Optional[ConversationScenario] = None, +) -> EvaluationResult: + """Evaluate if graph execution path matches expected path. 
+ + Checks if the sequence of nodes executed matches the expected path. + Looks for 'graph_path' in session state. + + Args: + eval_metric: Metric configuration + actual_invocations: Actual agent invocations + expected_invocations: Expected invocations (unused, path comes from scenario) + conversation_scenario: Test scenario with expected path in metadata + + Returns: + EvaluationResult with scores based on path matching + + Expected format in conversation_scenario.metadata: + { + "expected_graph_path": ["node1", "node2", "node3"] + } + """ + results = [] + overall_score = 0.0 + overall_status = EvalStatus.PASSED + + # Get expected path from eval_metric custom fields (for testing) + # In production, would come from scenario or expected_invocations + expected_path = getattr(eval_metric, "expected_graph_path", None) + + for actual_inv in actual_invocations: + # Extract actual path from intermediate_data (production) + # or from eval_metric custom fields (for testing) + actual_path = getattr(eval_metric, "actual_graph_path", None) + + if actual_path is None and actual_inv.intermediate_data: + # Extract from InvocationEvents + from ...evaluation.eval_case import InvocationEvents + + if isinstance(actual_inv.intermediate_data, InvocationEvents): + # Parse graph metadata from intermediate events + # Get the LAST/latest metadata event (final graph state) + for event in reversed(actual_inv.intermediate_data.invocation_events): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text and "[GraphMetadata]" in part.text: + # Extract metadata from text + import ast + + try: + # Parse the dict from text like "[GraphMetadata] {'graph_path': [...]}" + metadata_str = part.text.split("[GraphMetadata]", 1)[ + 1 + ].strip() + metadata = ast.literal_eval(metadata_str) + actual_path = metadata.get("graph_path") + if actual_path: + break + except Exception: + continue + if actual_path: + break + + # Compute score + score = 0.0 + status = 
EvalStatus.NOT_EVALUATED + + if expected_path is None: + # No expected path specified + status = EvalStatus.NOT_EVALUATED + elif actual_path is None: + # No actual path found + status = EvalStatus.FAILED + score = 0.0 + elif actual_path == expected_path: + # Exact match + status = EvalStatus.PASSED + score = 1.0 + else: + # Partial match - score based on how many nodes match + matched = sum(1 for a, e in zip(actual_path, expected_path) if a == e) + max_len = max(len(actual_path), len(expected_path)) + score = matched / max_len if max_len > 0 else 0.0 + status = EvalStatus.FAILED if score < 0.5 else EvalStatus.PASSED + + results.append( + PerInvocationResult( + actual_invocation=actual_inv, + score=score, + eval_status=status, + ) + ) + + if status == EvalStatus.FAILED: + overall_status = EvalStatus.FAILED + + overall_score += score + + # Average score across invocations + if results: + overall_score /= len(results) + + # If all results are NOT_EVALUATED, set overall status to NOT_EVALUATED + if results and all( + r.eval_status == EvalStatus.NOT_EVALUATED for r in results + ): + overall_status = EvalStatus.NOT_EVALUATED + + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=overall_status, + per_invocation_results=results, + ) + + +def state_contains_keys( + eval_metric: EvalMetric, + actual_invocations: List[Invocation], + expected_invocations: Optional[List[Invocation]], + conversation_scenario: Optional[ConversationScenario] = None, +) -> EvaluationResult: + """Evaluate if final state contains expected keys. + + Checks if session state contains all expected keys with correct values. 
+ + Args: + eval_metric: Metric configuration + actual_invocations: Actual agent invocations + expected_invocations: Expected invocations (unused) + conversation_scenario: Test scenario with expected state in metadata + + Returns: + EvaluationResult with scores based on state matching + + Expected format in conversation_scenario.metadata: + { + "expected_state": {"key1": "value1", "key2": 42}, + "actual_state": {"key1": "value1", "key2": 42} # For testing + } + """ + results = [] + overall_score = 0.0 + overall_status = EvalStatus.PASSED + + # Get expected state from eval_metric custom fields (for testing) + expected_state = getattr(eval_metric, "expected_state", None) + + for actual_inv in actual_invocations: + # Extract actual state from eval_metric custom fields (for testing) + # or from intermediate_data (production) + actual_state = getattr(eval_metric, "actual_state", None) + + if actual_state is None and actual_inv.intermediate_data: + # Extract from InvocationEvents + from ...evaluation.eval_case import InvocationEvents + + if isinstance(actual_inv.intermediate_data, InvocationEvents): + # Parse graph_state from metadata events + # Get the LAST/latest metadata event (final state) + for event in reversed(actual_inv.intermediate_data.invocation_events): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text and "[GraphMetadata]" in part.text: + import ast + + try: + metadata_str = part.text.split("[GraphMetadata]", 1)[ + 1 + ].strip() + metadata = ast.literal_eval(metadata_str) + actual_state = metadata.get("graph_state") + if actual_state: + break + except Exception: + continue + if actual_state: + break + + # Compute score + score = 0.0 + status = EvalStatus.NOT_EVALUATED + + if expected_state is None: + status = EvalStatus.NOT_EVALUATED + elif actual_state is None: + status = EvalStatus.FAILED + score = 0.0 + else: + # Check each expected key + total_keys = len(expected_state) + matched_keys = 0 + + for key, 
expected_value in expected_state.items(): + if key in actual_state and actual_state[key] == expected_value: + matched_keys += 1 + + score = matched_keys / total_keys if total_keys > 0 else 0.0 + status = EvalStatus.PASSED if score >= 1.0 else EvalStatus.FAILED + + results.append( + PerInvocationResult( + actual_invocation=actual_inv, + score=score, + eval_status=status, + ) + ) + + if status == EvalStatus.FAILED: + overall_status = EvalStatus.FAILED + + overall_score += score + + # Average score across invocations + if results: + overall_score /= len(results) + + # If all results are NOT_EVALUATED, set overall status to NOT_EVALUATED + if results and all( + r.eval_status == EvalStatus.NOT_EVALUATED for r in results + ): + overall_status = EvalStatus.NOT_EVALUATED + + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=overall_status, + per_invocation_results=results, + ) + + +def node_execution_count( + eval_metric: EvalMetric, + actual_invocations: List[Invocation], + expected_invocations: Optional[List[Invocation]], + conversation_scenario: Optional[ConversationScenario] = None, +) -> EvaluationResult: + """Evaluate if nodes executed expected number of times. + + Checks node_invocations tracking in session state. 
+ + Args: + eval_metric: Metric configuration + actual_invocations: Actual agent invocations + expected_invocations: Expected invocations (unused) + conversation_scenario: Test scenario with expected counts in metadata + + Returns: + EvaluationResult with scores based on execution counts + + Expected format in conversation_scenario.metadata: + { + "expected_node_counts": {"node1": 1, "node2": 3}, + "actual_node_counts": {"node1": 1, "node2": 3} # For testing + } + """ + results = [] + overall_score = 0.0 + overall_status = EvalStatus.PASSED + + # Get expected counts from eval_metric custom fields (for testing) + expected_counts = getattr(eval_metric, "expected_node_counts", None) + + for actual_inv in actual_invocations: + # Extract actual counts from eval_metric custom fields (for testing) + # In production, would come from intermediate_data + actual_counts = getattr(eval_metric, "actual_node_counts", {}) + + if not actual_counts and actual_inv.intermediate_data: + # Extract from InvocationEvents + from ...evaluation.eval_case import InvocationEvents + + if isinstance(actual_inv.intermediate_data, InvocationEvents): + # Parse node_invocations from graph metadata events + for event in actual_inv.intermediate_data.invocation_events: + if event.content and event.content.parts: + for part in event.content.parts: + if part.text and "[GraphMetadata]" in part.text: + import ast + + try: + metadata_str = part.text.split("[GraphMetadata]", 1)[ + 1 + ].strip() + metadata = ast.literal_eval(metadata_str) + node_invocs = metadata.get("node_invocations", {}) + if node_invocs: + actual_counts = node_invocs + # Continue to get the latest counts + except Exception: + continue + + # Compute score + score = 0.0 + status = EvalStatus.NOT_EVALUATED + + if expected_counts is None: + status = EvalStatus.NOT_EVALUATED + elif not actual_counts: + status = EvalStatus.FAILED + score = 0.0 + else: + # Check each expected node count + total_nodes = len(expected_counts) + matched_nodes = 0 + 
+ for node_name, expected_count in expected_counts.items(): + actual_count = actual_counts.get(node_name, 0) + if actual_count == expected_count: + matched_nodes += 1 + + score = matched_nodes / total_nodes if total_nodes > 0 else 0.0 + status = EvalStatus.PASSED if score >= 1.0 else EvalStatus.FAILED + + results.append( + PerInvocationResult( + actual_invocation=actual_inv, + score=score, + eval_status=status, + ) + ) + + if status == EvalStatus.FAILED: + overall_status = EvalStatus.FAILED + + overall_score += score + + # Average score across invocations + if results: + overall_score /= len(results) + + # If all results are NOT_EVALUATED, set overall status to NOT_EVALUATED + if results and all( + r.eval_status == EvalStatus.NOT_EVALUATED for r in results + ): + overall_status = EvalStatus.NOT_EVALUATED + + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=overall_status, + per_invocation_results=results, + ) diff --git a/src/google/adk/agents/graph/graph_agent.py b/src/google/adk/agents/graph/graph_agent.py new file mode 100644 index 0000000000..3ebde3012f --- /dev/null +++ b/src/google/adk/agents/graph/graph_agent.py @@ -0,0 +1,1651 @@ +"""Graph-based workflow orchestration for ADK. + +GraphAgent is ADK's fourth workflow agent type (alongside Sequential, Loop, Parallel), +enabling directed graph-based orchestration with conditional routing and complex branching. 
+ +GraphAgent enables workflow creation using directed graphs where: +- Nodes are agents or functions +- Edges define allowed transitions with optional conditions +- State flows through the graph with configurable reducers + +Key features: +- Directed graph workflows with conditional routing +- State management with custom reducers (OVERWRITE, APPEND, SUM, CUSTOM) +- Always-on observability: node lifecycle events for every execution +- DatabaseSessionService support for persistence +- Cyclic execution with max_iterations +- Event-based state persistence (ADK-native) +- ADK resumability integration (pause/resume long-running workflows) + +Inspired by adk-graph (Rust) and LangGraph patterns. +""" + +from __future__ import annotations + +import ast +import asyncio +import json +import logging +import time +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING + +from google.genai import types +from pydantic import ConfigDict +from pydantic import Field +from typing_extensions import override + +from ...events.event import Event +from ...events.event_actions import EventActions +from ...telemetry import graph_tracing +from ...telemetry.tracing import tracer +from ...utils.feature_decorator import experimental +from ..base_agent import BaseAgent +from ..invocation_context import InvocationContext +from ..llm_agent import LlmAgent +from .callbacks import EdgeCallback +from .callbacks import NodeCallback +from .graph_agent_state import GraphAgentState +from .graph_edge import EdgeCondition +from .graph_node import GraphNode +from .graph_state import GraphState +from .graph_state import StateReducer +from .graph_telemetry import GraphTelemetryMixin + +if TYPE_CHECKING: + from .graph_agent_config import TelemetryConfig + +logger = logging.getLogger("google_adk." 
+  Only allows: comparisons, boolean ops, unary not, attribute access,
+  safe method calls (.get, .get_parsed, .get_str, .get_dict),
+  whitelisted builtin calls, subscripts, list/tuple literals,
+  constants, and whitelisted names.
+ """ + if isinstance(node, ast.Expression): + _validate_condition_ast(node.body) + elif isinstance(node, ast.BoolOp): + for value in node.values: + _validate_condition_ast(value) + elif isinstance(node, ast.UnaryOp): + if not isinstance(node.op, ast.Not): + raise ValueError(f"Unsafe unary operator: {type(node.op).__name__}") + _validate_condition_ast(node.operand) + elif isinstance(node, ast.Compare): + _validate_condition_ast(node.left) + for comparator in node.comparators: + _validate_condition_ast(comparator) + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + if node.func.attr not in _SAFE_METHODS: + raise ValueError(f"Unsafe method call: .{node.func.attr}()") + _validate_condition_ast(node.func.value) + elif isinstance(node.func, ast.Name) and node.func.id in _SAFE_BUILTINS: + pass # Safe builtin call + else: + raise ValueError(f"Unsafe call: {ast.dump(node.func)}") + for arg in node.args: + _validate_condition_ast(arg) + for kw in node.keywords: + _validate_condition_ast(kw.value) + elif isinstance(node, ast.Attribute): + # Block dunder attribute access to prevent sandbox escape + # (e.g., state.__class__.__init__.__globals__) + if node.attr.startswith("_"): + raise ValueError(f"Unsafe attribute access: '{node.attr}'") + _validate_condition_ast(node.value) + elif isinstance(node, ast.Subscript): + _validate_condition_ast(node.value) + _validate_condition_ast(node.slice) + elif isinstance(node, ast.Name): + if node.id not in _SAFE_NAMES and node.id not in _SAFE_BUILTINS: + raise ValueError(f"Unsafe name: '{node.id}'") + elif isinstance(node, ast.Constant): + pass # string, int, float, bool, None literals are safe + elif isinstance(node, (ast.List, ast.Tuple)): + for elt in node.elts: + _validate_condition_ast(elt) + else: + raise ValueError(f"Unsafe expression node: {type(node).__name__}") + + +def _parse_condition_string(condition_str: str) -> Callable[[GraphState], bool]: + """Parse YAML condition string to a safe callable 
function. + + Conditions are parsed via AST and validated against a whitelist of + safe operations before compilation. This prevents arbitrary code + execution while supporting common condition expressions. + + Allowed in conditions: + - Names: state, data, metadata, True, False, None + - Methods: .get(), .get_parsed(), .get_str(), .get_dict() + - Builtins: len, min, max, abs, bool, int, float, str, isinstance, type + - Operators: ==, !=, <, >, <=, >=, is, is not, in, not in + - Boolean: and, or, not + - Literals: strings, numbers, booleans, None + + Examples: + "data.get('approved') is True" + "data.get('count', 0) < 10" + "'CONTINUE' in data.get('status', '')" + + Args: + condition_str: Python expression string to evaluate safely + + Returns: + Callable that takes GraphState and returns bool + + Raises: + ValueError: If condition contains unsafe expressions + """ + # Parse and validate at definition time (fail fast) + tree = ast.parse(condition_str, mode="eval") + _validate_condition_ast(tree.body) + code = compile(tree, "", "eval") + + def condition_func(state: GraphState) -> bool: + import builtins as _builtins_mod + + safe_builtins = { + name: getattr(_builtins_mod, name) for name in _SAFE_BUILTINS + } + namespace = { + "state": state, + "data": state.data, + } + try: + result = eval(code, {"__builtins__": safe_builtins}, namespace) # noqa: S307 + return bool(result) + except Exception as e: + logger.error( + f"Condition evaluation failed: '{condition_str}' - {e}", + exc_info=True, + ) + return False + + return condition_func + + +# Sentinel constants for graph boundaries +START = "__start__" +END = "__end__" + + +@experimental +class GraphAgent(GraphTelemetryMixin, BaseAgent): # type: ignore[misc] + """Graph-based workflow agent for ADK. + + GraphAgent is the fourth workflow agent type in ADK (alongside SequentialAgent, + LoopAgent, and ParallelAgent), enabling directed graph-based orchestration with + conditional routing and state management. 
+ + Workflow agents control execution flow through deterministic logic rather than LLM + reasoning, providing predictable, reliable, and structured agent orchestration. + + Features: + - Directed graph workflow with nodes (agents/functions) and edges + - Conditional routing based on state predicates + - Cyclic execution support (loops, iterative refinement, ReAct pattern) + - Always-on observability: node lifecycle events emitted for every execution + - DatabaseSessionService support for persistence + - Full ADK event system integration + - ADK resumability (pause/resume via agent state) + + Example: + >>> from google.adk.agents.graph import GraphAgent, GraphNode + >>> from google.adk.agents import LlmAgent + >>> from google.adk.runners import Runner + >>> + >>> graph = GraphAgent(name="workflow") + >>> graph.add_node(GraphNode(name="analyze", agent=LlmAgent(...))) + >>> graph.add_node(GraphNode(name="process", agent=LlmAgent(...))) + >>> graph.add_edge("analyze", "process") + >>> graph.set_start("analyze") + >>> graph.set_end("process") + >>> + >>> runner = Runner(app_name="app", agent=graph) + >>> async for event in runner.run_async(...): + ... 
print(event) + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + nodes: Dict[str, GraphNode] = Field(default_factory=dict) + start_node: Optional[str] = None + end_nodes: List[str] = Field(default_factory=list) + max_iterations: int = 50 # Prevent infinite loops + checkpointing: bool = False + telemetry_config: Optional[Any] = Field( + default=None, + description="Configuration for OpenTelemetry instrumentation", + ) + before_node_callback: Optional[NodeCallback] = Field( + default=None, + description="Callback invoked before each node execution", + ) + after_node_callback: Optional[NodeCallback] = Field( + default=None, + description="Callback invoked after each node execution", + ) + on_edge_condition_callback: Optional[EdgeCallback] = Field( + default=None, + description="Callback invoked when evaluating edge conditions", + ) + + def __init__( + self, + name: str, + description: str = "", + max_iterations: int = 50, + checkpointing: bool = False, + telemetry_config: Optional[Any] = None, + before_node_callback: Optional[NodeCallback] = None, + after_node_callback: Optional[NodeCallback] = None, + on_edge_condition_callback: Optional[EdgeCallback] = None, + **kwargs: Any, + ) -> None: + """Initialize GraphAgent. 
+ + Args: + name: Agent name + description: Agent description + max_iterations: Max iterations to prevent infinite loops + checkpointing: Enable state checkpointing after each node + telemetry_config: Configuration for OpenTelemetry instrumentation + before_node_callback: Callback invoked before each node execution + after_node_callback: Callback invoked after each node execution + on_edge_condition_callback: Callback invoked when evaluating edge conditions + """ + super().__init__(name=name, description=description, **kwargs) + self.nodes = {} + self.start_node = None + self.end_nodes = [] + self.max_iterations = max_iterations + self.telemetry_config = telemetry_config + self.before_node_callback = before_node_callback + self.after_node_callback = after_node_callback + self.on_edge_condition_callback = on_edge_condition_callback + self.checkpointing = checkpointing + + def add_node( + self, + node: GraphNode | str, + agent: Optional[BaseAgent] = None, + function: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> "GraphAgent": + """Add a node to the graph. + + Supports two usage patterns: + 1. Pass a GraphNode directly + 2. Pass node name and agent/function (convenience method) + + Args: + node: GraphNode to add, or string name (convenience) + agent: Optional agent for convenience pattern + function: Optional function for convenience pattern + **kwargs: Additional GraphNode parameters (output_mapper, state_reducer, etc.) 
+ + Returns: + Self for chaining + + Examples: + >>> # Pattern 1: GraphNode (explicit) + >>> graph.add_node(GraphNode(name="validate", agent=validator)) + + >>> # Pattern 2: Convenience (name + agent) + >>> graph.add_node("validate", agent=validator) + + >>> # Pattern 3: Convenience with kwargs + >>> graph.add_node("validate", agent=validator, state_reducer=StateReducer.OVERWRITE) + """ + if isinstance(node, GraphNode): + # Pattern 1: Direct GraphNode + if agent is not None or function is not None or kwargs: + raise ValueError( + "When passing a GraphNode, do not specify agent, function, or" + " kwargs" + ) + if node.name in self.nodes: + raise ValueError(f"Node '{node.name}' already exists in graph") + self._validate_node_configuration(node) + self.nodes[node.name] = node + elif isinstance(node, str): + # Pattern 2: Convenience (name + agent/function) + if agent is None and function is None: + raise ValueError( + "When passing node name as string, must specify agent or function" + ) + if agent is not None and function is not None: + raise ValueError("Cannot specify both agent and function") + + if node in self.nodes: + raise ValueError(f"Node '{node}' already exists in graph") + + graph_node = GraphNode( + name=node, agent=agent, function=function, **kwargs + ) + self._validate_node_configuration(graph_node) + self.nodes[node] = graph_node + else: + raise TypeError( + f"node must be GraphNode or str, got {type(node).__name__}" + ) + + # Register statically-known agents from the node in sub_agents + graph_node = self.nodes[node.name if isinstance(node, GraphNode) else node] + self._register_node_agents(graph_node) + + return self + + def _get_node_agent(self, node: "GraphNode") -> Optional[BaseAgent]: + """Extract the primary agent from a graph node, including pattern nodes. 
+ + Args: + node: GraphNode (or pattern subclass) to inspect + + Returns: + The agent associated with this node, or None + """ + if node.agent is not None: + return node.agent + return None + + def _register_node_agents(self, node: "GraphNode") -> None: + """Register statically-known agents from a node into sub_agents. + + Handles regular agent nodes and pattern nodes (NestedGraphNode, + DynamicNode). Skips DynamicParallelGroup (runtime-only agents). + + Args: + node: GraphNode to extract agents from + + Raises: + ValueError: If agent name collides with graph name, duplicates + an existing sub_agent name, or agent already has a parent. + """ + agent = self._get_node_agent(node) + if agent is None: + return + + # Identity check: same instance already registered (e.g., shared across nodes) + if any(agent is sa for sa in self.sub_agents): + return + + # Reject agent name that shadows the graph itself — find_agent() would + # return the graph instead of the sub-agent, causing silent bugs. + if agent.name == self.name: + raise ValueError( + f"Node agent name '{agent.name}' collides with GraphAgent name." + " Rename the agent to avoid find_agent() ambiguity." + ) + + # Validate unique name among existing sub_agents + for sa in self.sub_agents: + if sa.name == agent.name: + raise ValueError( + f"Duplicate sub_agent name '{agent.name}'. Another node already" + " registered an agent with this name." + ) + + # Single-parent constraint (matches BaseAgent behavior) + existing_parent = getattr(agent, "parent_agent", None) + if existing_parent is not None: + raise ValueError( + f"Agent '{agent.name}' already has a parent agent" + f" '{existing_parent.name}', cannot add to '{self.name}'" + ) + + agent.parent_agent = self + self.sub_agents.append(agent) + + @override + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Find agent by name, searching sub_agents then graph nodes as fallback. 
+ + Overrides BaseAgent.find_sub_agent to also search graph node agents + that may not be in sub_agents (e.g., agents added to nodes before + registration). NestedGraphNode agents are recursively searched. + + Note: DynamicNode runtime-selected agents (chosen by agent_selector + at execution time) are NOT discoverable via find_sub_agent because + they don't exist until the graph runs. Only DynamicNode.fallback_agent + (if set) is registered and searchable. + + Args: + name: The agent name to find + + Returns: + The matching agent, or None + """ + # Standard sub_agents search first + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + # Fallback: search graph nodes for agents not in sub_agents + for node in self.nodes.values(): + agent = self._get_node_agent(node) + if agent is not None: + if agent.name == name: + return agent + if result := agent.find_agent(name): + return result + return None + + def _validate_node_configuration(self, node: "GraphNode") -> None: + """Validate node configuration before adding to graph. + + Emits warnings for potential misconfiguration issues. + + Args: + node: GraphNode to validate + """ + # Warn if output_schema present but was auto-defaulted + if isinstance(node.agent, LlmAgent): + if node.agent.output_schema and node.agent.output_key: + # Check if it looks like it was auto-defaulted (matches agent name) + if node.agent.output_key == node.agent.name: + logger.warning( + f"Node '{node.name}': Using auto-defaulted" + f" output_key='{node.agent.output_key}'. To silence this warning," + " explicitly set output_key on the LlmAgent." + ) + + def set_start(self, node_name: str) -> "GraphAgent": + """Set the starting node. 
+ + Args: + node_name: Name of the start node + + Returns: + Self for chaining + + Raises: + ValueError: If node not found in graph + """ + if node_name not in self.nodes: + raise ValueError(f"Node {node_name} not found in graph") + self.start_node = node_name + return self + + def set_end(self, node_name: str) -> "GraphAgent": + """Mark a node as an end node. + + Args: + node_name: Name of the end node + + Returns: + Self for chaining + + Raises: + ValueError: If node not found in graph + """ + if node_name not in self.nodes: + raise ValueError(f"Node {node_name} not found in graph") + if node_name not in self.end_nodes: + self.end_nodes.append(node_name) + return self + + def add_edge( + self, + source_node: str, + target_node: str | EdgeCondition, + condition: Optional[Callable[[GraphState], bool]] = None, + priority: Optional[int] = None, + weight: Optional[float] = None, + ) -> "GraphAgent": + """Add an edge from source node to target node. + + Supports two usage patterns: + 1. Pass EdgeCondition as target_node (explicit) + 2. Pass target node name with optional params (convenience) + + Advanced routing features: + - Priority-based routing (higher priority evaluated first) + - Weighted random selection (probabilistic routing) + - Fallback edges (priority=0 always matches) + + Args: + source_node: Source node name + target_node: Target node name OR EdgeCondition object + condition: Optional condition (ignored if target_node is EdgeCondition) + priority: Optional priority (ignored if target_node is EdgeCondition) + weight: Optional weight (ignored if target_node is EdgeCondition) + + Returns: + Self for chaining + + Raises: + ValueError: If nodes not found in graph + TypeError: If target_node is not str or EdgeCondition + + Examples: + >>> # Pattern 1: EdgeCondition (explicit) + >>> graph.add_edge("validate", EdgeCondition( + ... target_node="process", + ... condition=lambda s: s.data.get("valid"), + ... priority=10 + ... 
)) + + >>> # Pattern 2: Convenience - simple edge + >>> graph.add_edge("validate", "process") + + >>> # Pattern 2: Convenience - conditional edge + >>> graph.add_edge("validate", "process", condition=lambda s: s.data.get("valid")) + + >>> # Pattern 2: Convenience - priority-based routing + >>> graph.add_edge("check", "critical", condition=lambda s: s.data["score"] > 0.9, priority=10) + >>> graph.add_edge("check", "normal", priority=0) # Fallback + + >>> # Pattern 2: Convenience - weighted random routing + >>> graph.add_edge("start", "server_a", condition=lambda s: True, priority=1, weight=0.5) + >>> graph.add_edge("start", "server_b", condition=lambda s: True, priority=1, weight=0.3) + """ + if source_node not in self.nodes: + raise ValueError(f"Source node {source_node} not found") + + if isinstance(target_node, EdgeCondition): + # Pattern 1: EdgeCondition + if condition is not None or priority is not None or weight is not None: + raise ValueError( + "When passing EdgeCondition, do not specify condition, priority, or" + " weight" + ) + if target_node.target_node not in self.nodes: + raise ValueError(f"Target node {target_node.target_node} not found") + + # Check for duplicate edge + if ( + hasattr(self.nodes[source_node], "edges") + and self.nodes[source_node].edges is not None + ): + for existing_edge in self.nodes[source_node].edges: + if existing_edge.target_node == target_node.target_node: + raise ValueError( + f"Edge from '{source_node}' to '{target_node.target_node}'" + " already exists. Cannot add duplicate edge." 
+ ) + + self.nodes[source_node].edges.append(target_node) + self.nodes[source_node]._sorted_edges_cache = None + + elif isinstance(target_node, str): + # Pattern 2: Convenience + if target_node not in self.nodes: + raise ValueError(f"Target node {target_node} not found") + + # Check for duplicate edge + if ( + hasattr(self.nodes[source_node], "edges") + and self.nodes[source_node].edges is not None + ): + for existing_edge in self.nodes[source_node].edges: + if existing_edge.target_node == target_node: + raise ValueError( + f"Edge from '{source_node}' to '{target_node}' already exists." + " Cannot add duplicate edge." + ) + + # If priority or weight specified, create EdgeCondition + if priority is not None or weight is not None: + edge_condition = EdgeCondition( + target_node=target_node, + condition=condition, + priority=priority if priority is not None else 1, + weight=weight if weight is not None else 1.0, + ) + self.nodes[source_node].edges.append(edge_condition) + self.nodes[source_node]._sorted_edges_cache = None + else: + # Simple edge (no priority/weight) + self.nodes[source_node].add_edge(target_node, condition) + else: + raise TypeError( + "target_node must be str or EdgeCondition, got" + f" {type(target_node).__name__}" + ) + + return self + + # Export methods moved to graph_export.py + # rewind_to_node moved to graph_rewind.py + + # Telemetry methods inherited from GraphTelemetryMixin + + async def _execute_node( + self, + node: GraphNode, + state: GraphState, + ctx: InvocationContext, + effective_config: Optional[TelemetryConfig] = None, + output_holder: Optional[Dict[str, Any]] = None, + iteration: int = 0, + ) -> AsyncGenerator[Event, None]: + """Execute a single node and yield events. + + Output is stored in output_holder["output"] for the caller. 

    Args:
      node: GraphNode to execute
      state: Current graph state
      ctx: Invocation context
      effective_config: Effective telemetry config (merged parent + own)
      output_holder: Mutable dict to store node output
      iteration: Current iteration number

    Yields:
      Events from node execution
    """
    # Determine node type (function nodes take precedence in labeling)
    node_type = "function" if node.function else "agent"
    start_time = time.time()

    # Create telemetry span for node execution
    with graph_tracing.tracer.start_as_current_span(
        f"graph_node {node.name}"
    ) as span:
      # Add attributes with additional_attributes support
      attrs = self._get_telemetry_attributes(
          {
              graph_tracing.GRAPH_NODE_NAME: node.name,
              graph_tracing.GRAPH_NODE_TYPE: node_type,
              graph_tracing.GRAPH_NODE_ITERATION: iteration,
              graph_tracing.GRAPH_AGENT_NAME: self.name,
          },
          effective_config=effective_config,
      )
      for key, value in attrs.items():
        span.set_attribute(key, value)

      try:
        # Map state to node input with telemetry
        mapper_start = time.time()
        node_input = node.input_mapper(state)
        mapper_latency_ms = (time.time() - mapper_start) * 1000

        # Record input mapper telemetry (check sampling)
        if self._should_sample(effective_config=effective_config):
          is_default_mapper = (
              node.input_mapper.__name__ == "_default_input_mapper"
          )
          graph_tracing.record_mapper(
              node_name=node.name,
              mapper_type="input",
              agent_name=self.name,
              latency_ms=mapper_latency_ms,
              is_default=is_default_mapper,
          )

        # Execute node (agent or function)
        output = ""
        if node.agent:
          # Create new context with updated user_content for this node
          node_content = types.Content(
              role="user", parts=[types.Part(text=node_input)]
          )
          node_ctx = ctx.model_copy(update={"user_content": node_content})

          # Execute ADK agent with updated context
          # (BaseAgent will create invoke_agent span automatically)
          async for event in node.agent.run_async(node_ctx):
            # Extract output from final response.
            # NOTE(review): this overwrites `output` with the FIRST text
            # part of EVERY event, so the captured value is whatever the
            # last content-bearing event carried — confirm intermediate
            # events cannot clobber the final answer.
            if event.content and event.content.parts:
              output = event.content.parts[0].text or ""
            yield event
            # ADK resumability: pause when long-running tool detected.
            # Function nodes don't set "pause" (they run synchronously,
            # no tool calls) so output_holder.get("pause") is always
            # falsy for them — safe to check unconditionally in caller.
            if ctx.should_pause_invocation(event):
              if output_holder is not None:
                output_holder["output"] = output
                output_holder["pause"] = True
              return
        elif node.function:
          # Execute custom function (CRITICAL: no automatic span)
          if asyncio.iscoroutinefunction(node.function):
            output = await node.function(state, ctx)
          else:
            output = node.function(state, ctx)
        else:  # pragma: no cover
          # Defensive: This should never happen due to GraphNode validation
          raise ValueError(f"Node {node.name} has no agent or function")

        # Store output for caller retrieval
        if output_holder is not None:
          output_holder["output"] = output

        # Record success metrics (check sampling)
        latency_ms = (time.time() - start_time) * 1000
        span.set_attribute("graph.node.success", True)
        if self._should_sample(effective_config=effective_config):
          graph_tracing.record_node_execution(
              node_name=node.name,
              node_type=node_type,
              agent_name=self.name,
              latency_ms=latency_ms,
              success=True,
          )

      except Exception as e:
        # Record failure metrics (check sampling), then re-raise so the
        # graph-level handler sees the original exception.
        latency_ms = (time.time() - start_time) * 1000
        span.set_attribute("graph.node.success", False)
        span.set_attribute("graph.node.error", str(e))
        if self._should_sample(effective_config=effective_config):
          graph_tracing.record_node_execution(
              node_name=node.name,
              node_type=node_type,
              agent_name=self.name,
              latency_ms=latency_ms,
              success=False,
          )
        raise

  def _get_next_node_with_telemetry(
      self,
      current_node: GraphNode,
      state: GraphState,
      effective_config: Optional[TelemetryConfig] = None,
  ) -> Optional[str]:
    """Get next node with edge evaluation telemetry.

    Args:
      current_node: Current graph node
      state: Current graph state
      effective_config: Effective telemetry config (merged parent + own)

    Returns:
      Name of next node, or None if no edge matches
    """
    # Track all condition results for detailed telemetry
    condition_results = []

    # Evaluate each edge with telemetry
    for edge in current_node.edges:
      start_time = time.time()

      # Create span for edge evaluation
      with graph_tracing.tracer.start_as_current_span(
          f"edge_condition {edge.target_node}"
      ) as span:
        # Add attributes with additional_attributes support
        attrs = self._get_telemetry_attributes(
            {
                graph_tracing.GRAPH_EDGE_SOURCE: current_node.name,
                graph_tracing.GRAPH_EDGE_TARGET: edge.target_node,
                graph_tracing.GRAPH_EDGE_PRIORITY: edge.priority,
                graph_tracing.GRAPH_AGENT_NAME: self.name,
            },
            effective_config=effective_config,
        )
        for key, value in attrs.items():
          span.set_attribute(key, value)

        try:
          # Evaluate condition
          result = edge.should_route(state)
          span.set_attribute(
              graph_tracing.GRAPH_EDGE_CONDITION_RESULT, str(result)
          )

          # Track condition result details for debugging
          condition_results.append({
              "target_node": edge.target_node,
              "condition_matched": result,
              "condition_name": getattr(edge.condition, "__name__", ""),
              "priority": edge.priority,
          })

          # Record metrics (check sampling)
          if self._should_sample(effective_config=effective_config):
            latency_ms = (time.time() - start_time) * 1000
            graph_tracing.record_edge_evaluation(
                source_node=current_node.name,
                target_node=edge.target_node,
                agent_name=self.name,
                condition_result=result,
                latency_ms=latency_ms,
                priority=edge.priority,
            )

        except Exception as e:
          # Mark the span as a failed/false evaluation and propagate.
          span.set_attribute("graph.edge.error", str(e))
          span.set_attribute(graph_tracing.GRAPH_EDGE_CONDITION_RESULT, "false")
          raise

    # Add detailed condition results to GraphState for debugging
    # This helps identify routing issues by showing ALL edge evaluations
    if condition_results:
      state.data["_debug_edge_evaluations"] = {
          "source_node": current_node.name,
          "evaluations": condition_results,
          "timestamp": time.time(),
      }

    # Use original get_next_node for routing logic
    selected_node = current_node.get_next_node(state)

    # Log final node selection decision with all context
    if selected_node:
      # Find which edge was selected (if any)
      selected_edge_info = next(
          (
              r
              for r in condition_results
              if r["target_node"] == selected_node and r["condition_matched"]
          ),
          None,
      )

      # Add node selection to debug info
      state.data.setdefault("_debug_node_selections", []).append({
          "from_node": current_node.name,
          "to_node": selected_node,
          "selected_edge": selected_edge_info,
          "num_edges_evaluated": len(condition_results),
          "timestamp": time.time(),
      })

      # Log structured selection event
      graph_tracing.logger.debug(
          f"Node selected: {current_node.name} -> {selected_node}",
          extra={
              "source_node": current_node.name,
              "selected_node": selected_node,
              "condition_name": (
                  selected_edge_info["condition_name"]
                  if selected_edge_info
                  else None
              ),
              "priority": (
                  selected_edge_info["priority"] if selected_edge_info else None
              ),
              "edges_evaluated": len(condition_results),
              "agent_name": self.name,
          },
      )

    return selected_node

  def _get_resume_state(
      self, agent_state: GraphAgentState
  ) -> Tuple[Optional[str], int, bool]:
    """Get resume point from loaded agent state.

    Mirrors SequentialAgent._get_start_index() pattern.

    Args:
      agent_state: Loaded agent state (may have current_node from prior run)

    Returns:
      Tuple of (start_node_name, start_iteration, is_resuming)
    """
    if agent_state.current_node and agent_state.current_node in self.nodes:
      return agent_state.current_node, agent_state.iteration, True
    # Saved node may have been removed by a graph redefinition between runs.
    if agent_state.current_node and agent_state.current_node not in self.nodes:
      logger.warning(
          "Saved node '%s' no longer exists in graph. Restarting from '%s'.",
          agent_state.current_node,
          self.start_node,
      )
    return self.start_node, 0, False

  async def _execute_callback(
      self,
      callback: Callable[..., Any],
      callback_type: str,
      current_node: GraphNode,
      current_node_name: str,
      state: GraphState,
      iteration: int,
      ctx: InvocationContext,
      agent_state: GraphAgentState,
      effective_config: Optional["TelemetryConfig"] = None,
      output: str = "",
  ) -> Optional[Event]:
    """Execute a node callback (before_node or after_node) with telemetry.

    Args:
      callback: The callback function to execute
      callback_type: "before_node" or "after_node"
      current_node: The current GraphNode
      current_node_name: Name of the current node
      state: Current graph state
      iteration: Current iteration number
      ctx: Invocation context
      agent_state: Execution tracking state
      effective_config: Effective telemetry config
      output: Node output (only for after_node callbacks)

    Returns:
      Event from callback, or None
    """
    # Imported locally to avoid a module-level import cycle with callbacks.
    from .callbacks import NodeCallbackContext

    # Snapshot paths (list() copies) so the callback sees a stable view.
    metadata: Dict[str, Any] = {
        "agent_path": list(agent_state.agent_path),
        "path": list(agent_state.path),
    }
    if callback_type == "after_node":
      metadata["output"] = output

    callback_ctx = NodeCallbackContext(
        node=current_node,
        state=state,
        iteration=iteration,
        invocation_context=ctx,
        metadata=metadata,
    )

    callback_start_time = time.time()
    with graph_tracing.tracer.start_as_current_span(
        f"graph_callback {callback_type}"
    ) as cb_span:
      attrs = self._get_telemetry_attributes(
          {
              graph_tracing.GRAPH_CALLBACK_TYPE: callback_type,
              graph_tracing.GRAPH_AGENT_NAME: self.name,
              graph_tracing.GRAPH_NODE_NAME: current_node_name,
          },
          effective_config=effective_config,
      )
      for key, value in attrs.items():
        cb_span.set_attribute(key, value)

      try:
        event = await callback(callback_ctx)
        cb_span.set_attribute("graph.callback.success", True)
        if self._should_sample(effective_config=effective_config):
          callback_latency_ms = (time.time() - callback_start_time) * 1000
          graph_tracing.record_callback_execution(
              callback_type=callback_type,
              agent_name=self.name,
              latency_ms=callback_latency_ms,
              success=True,
          )
        return event

      except Exception as e:
        cb_span.set_attribute("graph.callback.success", False)
        cb_span.set_attribute("graph.callback.error", str(e))
        if self._should_sample(effective_config=effective_config):
          callback_latency_ms = (time.time() - callback_start_time) * 1000
          graph_tracing.record_callback_execution(
              callback_type=callback_type,
              agent_name=self.name,
              latency_ms=callback_latency_ms,
              success=False,
          )
        # Callback failures are logged and swallowed (return None) so a
        # broken user callback cannot abort the graph run.
        logger.error(
            "%s_callback failed for node '%s': %s",
            callback_type,
            current_node_name,
            e,
            exc_info=True,
        )
        return None

  def _sync_state_and_reduce(
      self,
      current_node: GraphNode,
      current_node_name: str,
      state: GraphState,
      ctx: InvocationContext,
      output: str,
      effective_config: Optional["TelemetryConfig"] = None,
  ) -> GraphState:
    """Sync session state into GraphState and apply output_mapper + reducer.

    Args:
      current_node: The current GraphNode
      current_node_name: Name of the current node
      state: Current graph state (mutated in-place for session sync)
      ctx: Invocation context
      output: Node output string
      effective_config: Effective telemetry config

    Returns:
      Updated GraphState after sync and reduction
    """
    # Sync session state into GraphState.data, skipping private ("_"-prefixed)
    # keys and graph-internal bookkeeping keys.
    for _sk, _sv in ctx.session.state.items():
      if not _sk.startswith("_") and _sk not in _GRAPH_INTERNAL_KEYS:
        state.data[_sk] = _sv

    # Apply output mapper with reducer
    if output:
      had_previous_value = current_node.name in state.data
      reducer_start = time.time()

      # Defensive: if a custom output_mapper returns None, keep prior state.
      prev_state = state
      state = current_node.output_mapper(output, state)
      if state is None:
        state = prev_state

      reducer_latency_ms = (time.time() - reducer_start) * 1000
      if self._should_sample(effective_config=effective_config):
        graph_tracing.record_state_reducer(
            node_name=current_node.name,
            reducer_type=current_node.reducer.name,
            state_key=current_node.name,
            agent_name=self.name,
            latency_ms=reducer_latency_ms,
            had_previous_value=had_previous_value,
        )
        is_default_mapper = (
            current_node.output_mapper.__name__ == "_default_output_mapper"
        )
        graph_tracing.record_mapper(
            node_name=current_node.name,
            mapper_type="output",
            agent_name=self.name,
            latency_ms=reducer_latency_ms,
            is_default=is_default_mapper,
        )

    return state

  def _build_cancellation_events(
      self,
      ctx: InvocationContext,
      agent_state: GraphAgentState,
      current_node_name: str,
      state: GraphState,
      *,
      state_key: str = "graph_cancelled",
      message: str,
      iteration: Optional[int] = None,
      partial_output: Optional[str] = None,
  ) -> List[Event]:
    """Build agent-state + cancellation events for graph abort scenarios.

    Consolidates the repeated pattern of saving agent state then yielding
    a cancellation event with appropriate state_delta keys.

    Args:
      ctx: Invocation context
      agent_state: Execution tracking state (saved before cancel)
      current_node_name: Node where cancellation occurred
      state: Current graph state
      state_key: Key for the cancellation flag (e.g. "graph_cancelled",
        "graph_task_cancelled")
      message: Human-readable cancellation message
      iteration: Current iteration (included in state_delta when set)
      partial_output: Partial node output (included when set)

    Returns:
      List of two events: [agent_state_event, cancellation_event]
    """
    ctx.set_agent_state(self.name, agent_state=agent_state)
    state_event = self._create_agent_state_event(ctx)

    state_delta: Dict[str, Any] = {
        state_key: True,
        "graph_cancelled_at_node": current_node_name,
        "graph_data": state.data,
        "graph_can_resume": True,
    }
    if iteration is not None:
      # NOTE(review): key here is singular "graph_iteration" while
      # _GRAPH_INTERNAL_KEYS lists plural "graph_iterations" — confirm
      # which spelling downstream consumers expect.
      state_delta["graph_iteration"] = iteration
    if partial_output is not None:
      state_delta["graph_partial_output"] = partial_output

    cancel_event = Event(
        author=self.name,
        content=types.Content(
            parts=[types.Part(text=f"\u26a0\ufe0f {message}")]
        ),
        actions=EventActions(
            escalate=False,
            state_delta=state_delta,
        ),
    )
    return [state_event, cancel_event]

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Core graph execution logic.

    Executes nodes in graph order, following conditional edges,
    supporting loops and cyclic execution.

    Args:
      ctx: Invocation context

    Yields:
      Events from graph execution

    Raises:
      ValueError: If start node not set or graph structure invalid
    """
    if not self.start_node:
      raise ValueError("Start node not set. 
Call set_start() first.") + + # Get effective telemetry config for nested graph inheritance + effective_config = self._get_effective_telemetry_config(ctx) + + with tracer.start_as_current_span( + f"graph_agent_execution {self.name}" + ) as span: + span.set_attribute("graph_agent.name", self.name) + span.set_attribute("graph_agent.start_node", self.start_node) + span.set_attribute("graph_agent.max_iterations", self.max_iterations) + try: + # Load execution tracking state (BaseAgentState pattern) + agent_state = ( + self._load_agent_state(ctx, GraphAgentState) or GraphAgentState() + ) + + # Store telemetry config for nested graph inheritance + if effective_config: + agent_state.telemetry_config_dict = effective_config.model_dump() + + # Initialize domain data from session state or user input. + # Exclude graph-internal keys to prevent circular references + # (state.data["graph_data"] → state.data) and keep domain data clean. + domain_data = { + k: v + for k, v in ctx.session.state.items() + if not k.startswith("_") and k not in _GRAPH_INTERNAL_KEYS + } + if domain_data: + state = GraphState(data=domain_data) + else: + # Extract text from Content object + user_text = "" + if ( + hasattr(ctx, "user_content") + and ctx.user_content + and ctx.user_content.parts + ): + user_text = ( + ctx.user_content.parts[0].text + if ctx.user_content.parts[0].text + else "" + ) + state = GraphState(data={"input": user_text}) + + # ADK resumability: resume from saved node or start fresh. + # + # Design note: SequentialAgent ONLY emits state events when + # ctx.is_resumable is True, because its state events serve only + # resumability. GraphAgent's state events serve multiple consumers + # (rewind, interrupts, telemetry) that are orthogonal to + # resumability. 
Therefore: + # - Per-iteration state events: always emitted (multi-consumer) + # - Resume skip: first iteration skipped when resuming (already + # persisted before pause, avoids duplicate) + # - end_of_agent: guarded by is_resumable (purely a resumability + # lifecycle signal, has no other consumers) + # - Interrupt/cancellation state saves: always emitted (they + # serve interrupt functionality, not just resumability) + current_node_name, iteration, resuming = self._get_resume_state( + agent_state + ) + pause_invocation = False + + while current_node_name and iteration < self.max_iterations: + iteration += 1 + current_node = self.nodes[current_node_name] + + # Track execution path in agent_state + agent_state.path.append(current_node_name) + agent_state.iteration = iteration + agent_state.current_node = current_node_name + agent_state.node_invocations.setdefault(current_node_name, []).append( + ctx.invocation_id + ) + + # ADK resumability: reset sub-agent states on cycle revisit + # (mirrors LoopAgent pattern at loop_agent.py:114) + # O(1) lookup via node_invocations instead of O(N) path.count() + if ( + len(agent_state.node_invocations.get(current_node_name, [])) > 1 + and current_node.agent + ): + ctx.reset_sub_agent_states(current_node.agent.name) + + # Track agent path for nested graph support + if self.name not in agent_state.agent_path: + agent_state.agent_path.append(self.name) + + # Persist execution tracking via agent_state event. + # These events are consumed by rewind, interrupts, and telemetry + # (not just resumability), so they're always emitted. + # Skip only on first iteration when resuming (already persisted). 
+ if not resuming: + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + else: + resuming = False # Only skip first iteration after resume + + # Invoke before_node_callback (custom observability) + if self.before_node_callback: + event = await self._execute_callback( + self.before_node_callback, + "before_node", + current_node, + current_node_name, + state, + iteration, + ctx, + agent_state, + effective_config, + ) + if event: + yield event + + # Execute node with output tracking + output_holder: Dict[str, Any] = {"output": ""} + try: + async for event in self._execute_node( + current_node, + state, + ctx, + effective_config, + output_holder=output_holder, + iteration=iteration, + ): + yield event + except asyncio.CancelledError: + # Task cancelled externally (e.g., timeout, user abort) + logger.info( + f"GraphAgent task cancelled during node '{current_node_name}'" + f" for session {ctx.session.id}" + ) + for _ce in self._build_cancellation_events( + ctx, + agent_state, + current_node_name, + state, + state_key="graph_task_cancelled", + message=f"Task cancelled during node '{current_node_name}'", + partial_output=output_holder["output"], + ): + yield _ce + raise + + # ADK resumability: check if node execution was paused + if output_holder.get("pause"): + pause_invocation = True + return + + # Sync session state + apply output_mapper/reducer + output = output_holder["output"] + state = self._sync_state_and_reduce( + current_node, + current_node_name, + state, + ctx, + output, + effective_config, + ) + + # Emit output_mapper changes as state_delta so domain data + # flows through ADK's event pipeline to session.state. + # This enables downstream LlmAgent nodes to read + # output_mapper results via dynamic instructions. 
+ if output: + delta = {} + for _k, _v in state.data.items(): + if ( + not _k.startswith("_") + and _k not in _GRAPH_INTERNAL_KEYS + and ctx.session.state.get(_k) != _v + ): + delta[_k] = _v + if delta: + yield Event( + author=self.name, + actions=EventActions(state_delta=delta), + ) + + # Invoke after_node_callback (custom observability) + if self.after_node_callback: + event = await self._execute_callback( + self.after_node_callback, + "after_node", + current_node, + current_node_name, + state, + iteration, + ctx, + agent_state, + effective_config, + output=output, + ) + if event: + yield event + + # Emit graph metadata event for evaluation framework + # This will be captured in Invocation.intermediate_data by EvaluationGenerator + # Set partial=True so is_final_response() returns False (making it an intermediate event) + graph_metadata = { + "graph_node": current_node_name, + "graph_iteration": iteration, + "graph_path": list(agent_state.path), + "node_invocations": { + name: len(invocs) + for name, invocs in agent_state.node_invocations.items() + }, + "graph_state": dict(state.data), + } + yield Event( + author=f"{self.name}#metadata", + content=types.Content( + parts=[types.Part(text=f"[GraphMetadata] {graph_metadata}")] + ), + partial=True, # Mark as intermediate event + ) + + # Checkpointing - yield event with state_delta to persist checkpoint + if self.checkpointing: + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Checkpoint: {current_node_name}")] + ), + actions=EventActions( + state_delta={ + "graph_data": state.data, + "graph_checkpoint": { + "node": current_node_name, + "iteration": iteration, + }, + } + ), + ) + + # Inject transient execution data for edge conditions + state.data["_graph_iteration"] = agent_state.iteration + state.data["_graph_path"] = list(agent_state.path) + state.data["_conditions"] = 
dict(agent_state.conditions) + + # Get next node via conditional routing + next_node_name = self._get_next_node_with_telemetry( + current_node, state, effective_config=effective_config + ) + + # Clean up transient keys + for _tk in ("_graph_iteration", "_graph_path", "_conditions"): + state.data.pop(_tk, None) + if next_node_name is None: + # No more edges - check if we're at an end node + if current_node_name in self.end_nodes: + break + else: + # Not at an end node and no edges - error + raise ValueError( + f"Node {current_node_name} has no outgoing edges and is not" + " an end node" + ) + + current_node_name = next_node_name + + # Record iteration metrics (check sampling) + if self._should_sample(effective_config=effective_config): + graph_tracing.record_graph_iteration( + agent_name=self.name, + iteration=iteration, + path_length=len(agent_state.path), + ) + + # ADK resumability: skip final response + end_of_agent when paused + if not pause_invocation: + # Final response - yield event with graph metadata + # Include last node's output ONLY if: + # 1. explicit final_output is set, OR + # 2. 
last node was a function (doesn't yield events, so we need to show output) + # Don't include output for agent nodes (they already yielded their output) + final_output = state.data.get("final_output", "") + + # If no explicit final_output, check if last node was a function + if not final_output and current_node_name: + last_node = self.nodes.get(current_node_name) + if last_node and last_node.function: + # Function node - include its output + final_output = state.data.get(current_node_name, "") + + response_text = f"{final_output}" + + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=response_text)]), + actions=EventActions( + state_delta={ + "graph_data": state.data, + "graph_iterations": iteration, + "graph_path": list(agent_state.path), + } + ), + ) + # end_of_agent is guarded by is_resumable because it is purely a + # resumability lifecycle signal (tells the runner "this agent is + # done, don't re-run it on resume"). Unlike per-iteration state + # events which serve rewind/interrupts/telemetry, end_of_agent + # has no other consumers. + if ctx.is_resumable: + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + + finally: + span.set_attribute("graph_agent.completed", True) + + @override + @classmethod + def _parse_config( + cls, + config: Any, # GraphAgentConfig + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parse GraphAgentConfig and return kwargs for GraphAgent constructor. 
+ + Args: + config: GraphAgentConfig instance + config_abs_path: Absolute path to config file + kwargs: Base kwargs from BaseAgent + + Returns: + Updated kwargs with graph-specific configuration + """ + from .graph_agent_config import GraphAgentConfig + + if not isinstance(config, GraphAgentConfig): + return kwargs + + # Basic graph config + if hasattr(config, "start_node") and config.start_node: + kwargs["start_node"] = config.start_node + + if hasattr(config, "max_iterations") and config.max_iterations: + kwargs["max_iterations"] = config.max_iterations + + if hasattr(config, "checkpointing"): + kwargs["checkpointing"] = config.checkpointing + + # Callbacks + from ..config_agent_utils import resolve_code_reference + + if ( + hasattr(config, "before_node_callback_ref") + and config.before_node_callback_ref + ): + kwargs["before_node_callback"] = resolve_code_reference( + config.before_node_callback_ref + ) + + if ( + hasattr(config, "after_node_callback_ref") + and config.after_node_callback_ref + ): + kwargs["after_node_callback"] = resolve_code_reference( + config.after_node_callback_ref + ) + + if ( + hasattr(config, "on_edge_condition_callback_ref") + and config.on_edge_condition_callback_ref + ): + kwargs["on_edge_condition_callback"] = resolve_code_reference( + config.on_edge_condition_callback_ref + ) + + return kwargs + + @override + @classmethod + def from_config( + cls, + config: Any, # GraphAgentConfig + config_abs_path: str, + ) -> "GraphAgent": + """Creates a GraphAgent from a YAML config. + + This method performs post-construction setup to add nodes, edges, + end nodes, and parallel groups from the config. 
  @override
  @classmethod
  def from_config(
      cls,
      config: Any,  # GraphAgentConfig
      config_abs_path: str,
  ) -> "GraphAgent":
    """Creates a GraphAgent from a YAML config.

    This method performs post-construction setup to add nodes, edges,
    and end nodes from the config.

    NOTE(review): the original docstring also promises parallel groups,
    but config.parallel_groups is never consumed here — confirm where (or
    whether) parallel groups are wired up.

    Args:
      config: GraphAgentConfig instance
      config_abs_path: Absolute path to config file

    Returns:
      Configured GraphAgent instance
    """
    from ..config_agent_utils import resolve_agent_reference
    from ..config_agent_utils import resolve_code_reference
    from .graph_agent_config import GraphAgentConfig

    # Create base agent instance
    graph_instance = super().from_config(config, config_abs_path)

    # Type assertion: we know this is a GraphAgent because cls is GraphAgent
    assert isinstance(
        graph_instance, cls
    ), "Expected GraphAgent instance from super().from_config()"
    graph: GraphAgent = graph_instance  # type: ignore[assignment]

    if not isinstance(config, GraphAgentConfig):
      return graph

    # Add nodes
    if hasattr(config, "nodes") and config.nodes:
      for node_config in config.nodes:
        # Resolve sub-agents for this node
        sub_agents = []
        if node_config.sub_agents:
          for agent_ref in node_config.sub_agents:
            agent = resolve_agent_reference(agent_ref, config_abs_path)
            sub_agents.append(agent)

        # Resolve function ref
        function = None
        if node_config.function_ref:
          function = resolve_code_reference(node_config.function_ref)

        # Create GraphNode.
        # NOTE(review): only sub_agents[0] is attached; extra entries are
        # silently discarded. GraphNodeConfig's input_mapper_ref /
        # output_mapper_ref / reducer / custom_reducer_ref are also not
        # applied here — confirm they're handled elsewhere or wire them up.
        node = GraphNode(
            name=node_config.name,
            agent=sub_agents[0] if sub_agents else None,
            function=function,
        )
        graph.add_node(node)

    # Add edges
    if hasattr(config, "edges") and config.edges:
      from .graph_edge import EdgeCondition

      for edge_config in config.edges:
        condition = None
        if edge_config.condition:
          # Parse string condition to callable
          condition = _parse_condition_string(edge_config.condition)

        # Create EdgeCondition with priority and weight support
        edge = EdgeCondition(
            target_node=edge_config.target_node,
            condition=condition,
            priority=edge_config.priority,
            weight=edge_config.weight,
        )

        # Add edge directly to the node's edges list.
        # NOTE(review): reaches into GraphNode's private
        # _sorted_edges_cache to invalidate it — consider a public
        # add_edge() API on GraphNode instead.
        if edge_config.source_node in graph.nodes:
          graph.nodes[edge_config.source_node].edges.append(edge)
          graph.nodes[edge_config.source_node]._sorted_edges_cache = None
        else:
          raise ValueError(
              f"Source node {edge_config.source_node} not found in graph"
          )

    # Set start node
    if hasattr(config, "start_node") and config.start_node:
      graph.set_start(config.start_node)

    # Set end nodes
    if hasattr(config, "end_nodes") and config.end_nodes:
      for end_node in config.end_nodes:
        graph.set_end(end_node)

    return graph
+ """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Node name") + + # Node can reference an agent (sub_agents) OR a function + sub_agents: Optional[List[AgentRefConfig]] = Field( + default=None, + description="Sub-agents for this node", + ) + + function_ref: Optional[str] = Field( + default=None, + description="Reference to a function (e.g., 'module.function_name')", + ) + + input_mapper_ref: Optional[str] = Field( + default=None, + description="Reference to custom input mapper function", + ) + + output_mapper_ref: Optional[str] = Field( + default=None, + description="Reference to custom output mapper function", + ) + + reducer: str = Field( + default="overwrite", + description="State reducer strategy: overwrite|append|sum|custom", + ) + + custom_reducer_ref: Optional[str] = Field( + default=None, + description="Reference to custom reducer function (if reducer=custom)", + ) + + +@experimental +class GraphEdgeConfig(BaseModel): # type: ignore[misc] + """Configuration for an edge between nodes. + + Edges can have optional conditions for conditional routing. + """ + + model_config = ConfigDict(extra="forbid") + + source_node: str = Field(description="Source node name") + + target_node: str = Field(description="Target node name") + + condition: Optional[str] = Field( + default=None, + description=( + "AST-safe condition expression evaluated against graph state." + " Allowed names: state, data, metadata, True, False, None." + " Allowed methods: .get(), .get_parsed(), .get_str()," + " .get_dict(). Supports comparisons, boolean ops, 'in'." 
+ " Example: \"data.get('approved') is True\"" + ), + ) + + priority: int = Field( + default=1, + description="Edge priority for routing (higher = evaluated first)", + ) + + weight: float = Field( + default=1.0, + description="Edge weight for weighted random routing", + ) + + +@experimental +class InterruptConfigYaml(BaseModel): # type: ignore[misc] + """Configuration for interrupt handling.""" + + model_config = ConfigDict(extra="forbid") + + mode: Optional[Literal["before", "after", "both"]] = Field( + default=None, + description="Interrupt mode (None = disabled, before|after|both)", + ) + + interrupt_service: Optional[CodeConfig] = Field( + default=None, + description="Interrupt service configuration (CodeConfig)", + ) + + +@experimental +class ParallelGroupConfig(BaseModel): # type: ignore[misc] + """Configuration for parallel node execution.""" + + model_config = ConfigDict(extra="forbid") + + nodes: List[str] = Field( + description="List of node names to execute in parallel" + ) + + join_strategy: str = Field( + default="all", + description="Join strategy: all|any|n", + ) + + error_policy: str = Field( + default="fail_fast", + description="Error policy: fail_fast|continue|collect", + ) + + wait_n: int = Field( + default=1, + description="Number of nodes to wait for (when join_strategy=n)", + ) + + +@experimental +class TelemetryConfig(BaseModel): # type: ignore[misc] + """Configuration for GraphAgent telemetry. + + Controls OpenTelemetry instrumentation for graph workflow execution. 
+ """ + + model_config = ConfigDict(extra="forbid") + + enabled: bool = Field( + default=True, + description="Enable/disable all telemetry collection", + ) + + trace_nodes: bool = Field( + default=True, + description="Enable tracing for node executions", + ) + + trace_edges: bool = Field( + default=True, + description="Enable tracing for edge condition evaluations", + ) + + trace_iterations: bool = Field( + default=True, + description="Enable metrics for graph iterations", + ) + + trace_parallel_groups: bool = Field( + default=True, + description="Enable tracing for parallel group executions", + ) + + trace_callbacks: bool = Field( + default=True, + description="Enable tracing for callback executions", + ) + + trace_interrupts: bool = Field( + default=True, + description="Enable tracing for interrupt checks", + ) + + sampling_rate: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Sampling rate for telemetry (0.0-1.0, 1.0 = 100%)", + ) + + additional_attributes: Optional[Dict[str, Any]] = Field( + default=None, + description="Additional custom attributes to add to all telemetry", + ) + + +@experimental +class GraphAgentConfig(BaseAgentConfig): # type: ignore[misc] + """The config for the YAML schema of a GraphAgent. + + This config supports defining graph structure, nodes, edges, and + advanced features like interrupts and parallel execution. + + Example YAML: + ```yaml + agent_class: GraphAgent + name: my_graph + description: My graph workflow + start_node: start + end_nodes: + - end + max_iterations: 10 + checkpointing: true + nodes: + - name: start + sub_agents: + - agent1 + - name: middle + sub_agents: + - agent2 + - name: end + sub_agents: + - agent3 + edges: + - source_node: start + target_node: middle + - source_node: middle + target_node: end + ``` + """ + + model_config = ConfigDict(extra="forbid") + + agent_class: str = Field( + default="GraphAgent", + description=( + "The value is used to uniquely identify the GraphAgent class." 
+ ), + ) + + start_node: str = Field(description="Name of the starting node") + + end_nodes: List[str] = Field( + default_factory=list, + description="List of end node names", + ) + + max_iterations: int = Field( + default=20, + description="Maximum iterations for cyclic graphs", + ) + + checkpointing: bool = Field( + default=False, + description="Enable automatic checkpointing", + ) + + checkpoint_service: Optional[CodeConfig] = Field( + default=None, + description="Checkpoint service configuration (CodeConfig)", + ) + + # Graph structure + nodes: List[GraphNodeConfig] = Field( + default_factory=list, + description="List of node configurations", + ) + + edges: List[GraphEdgeConfig] = Field( + default_factory=list, + description="List of edge configurations", + ) + + # Advanced features + interrupt_config: Optional[InterruptConfigYaml] = Field( + default=None, + description="Interrupt configuration", + ) + + telemetry_config: Optional[TelemetryConfig] = Field( + default=None, + description="Telemetry configuration for OpenTelemetry instrumentation", + ) + + parallel_groups: List[ParallelGroupConfig] = Field( + default_factory=list, + description="List of parallel execution group configurations", + ) + + # Callbacks (following ADK CodeConfig pattern) + before_node_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed before each node", + ) + + after_node_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed after each node", + ) + + on_edge_condition_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed when evaluating edge conditions", + ) diff --git a/src/google/adk/agents/graph/graph_agent_state.py b/src/google/adk/agents/graph/graph_agent_state.py new file mode 100644 index 0000000000..94bbce963f --- /dev/null +++ b/src/google/adk/agents/graph/graph_agent_state.py @@ -0,0 +1,46 @@ +"""Execution tracking state for GraphAgent. 
+ +Follows ADK's BaseAgentState pattern: persisted via +ctx.agent_states / Event.actions.agent_state. + +Domain data (node outputs) remains in GraphState via state_delta. +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from pydantic import Field + +from ...utils.feature_decorator import experimental +from ..base_agent import BaseAgentState + + +@experimental +class GraphAgentState(BaseAgentState): # type: ignore[misc] + """Execution tracking state for GraphAgent. + + Serialized via model_dump(mode='json'), restored via model_validate(). + """ + + current_node: str = "" + prev_node: str = "" + iteration: int = 0 + execution_start: float = 0.0 + + path: List[str] = Field(default_factory=list) + node_invocations: Dict[str, List[str]] = Field(default_factory=dict) + conditions: Dict[str, Any] = Field(default_factory=dict) + rerun_guidance: str = "" + + interrupt_history: List[Dict[str, Any]] = Field(default_factory=list) + interrupt_todos: List[Dict[str, Any]] = Field(default_factory=list) + last_interrupt_decision: Optional[Dict[str, Any]] = None + + telemetry_config_dict: Optional[Dict[str, Any]] = None + + agent_path: List[str] = Field(default_factory=list) + executed_parallel_groups: List[str] = Field(default_factory=list) diff --git a/src/google/adk/agents/graph/graph_edge.py b/src/google/adk/agents/graph/graph_edge.py new file mode 100644 index 0000000000..79517ba1e9 --- /dev/null +++ b/src/google/adk/agents/graph/graph_edge.py @@ -0,0 +1,97 @@ +"""Conditional edges for graph routing.""" + +from __future__ import annotations + +from typing import Callable +from typing import Optional + +from .graph_state import GraphState + + +class EdgeCondition: + """Conditional edge that routes based on state. + + Edges connect nodes in the graph and can have optional conditions, + priorities, and weights for advanced routing strategies. 
+ + Example: + ```python + # Unconditional edge (always taken) + edge = EdgeCondition(target_node="next_node") + + # Conditional edge (taken if score > 0.8) + edge = EdgeCondition( + target_node="high_score_handler", + condition=lambda state: state.data.get("score", 0) > 0.8 + ) + + # Priority-based routing (higher priority evaluated first) + edge1 = EdgeCondition( + target_node="critical_path", + condition=lambda state: state.data.get("is_critical", False), + priority=10 # High priority + ) + edge2 = EdgeCondition( + target_node="normal_path", + priority=5 # Lower priority + ) + + # Weighted random selection (among matching edges) + edge1 = EdgeCondition(target_node="path_a", weight=0.7) # 70% chance + edge2 = EdgeCondition(target_node="path_b", weight=0.3) # 30% chance + + # Fallback edge (priority=0 always matches if no other edge matched) + edge_fallback = EdgeCondition( + target_node="default_handler", + priority=0 # Fallback priority + ) + ``` + """ + + def __init__( + self, + target_node: str, + condition: Optional[Callable[[GraphState], bool]] = None, + priority: int = 1, + weight: float = 1.0, + ): + """Initialize edge condition. + + Args: + target_node: Name of the target node + condition: Function that returns True if this edge should be taken. + If None, edge is always taken (unconditional). + priority: Priority for edge evaluation (higher = evaluated first). + Priority 0 is special: treated as fallback (always matches if reached). + Default is 1 (normal priority). + weight: Weight for weighted random selection among matching edges. + Only used when multiple edges match. Higher weight = higher probability. + Default is 1.0. + """ + self.target_node = target_node + self.has_condition = condition is not None + self.condition = condition or (lambda _: True) + self.priority = priority + self.weight = weight + + def should_route(self, state: GraphState) -> bool: + """Check if this edge should be taken given the current state. 
class GraphEventType(str, Enum):
  """Types of graph execution events."""

  NODE_START = "node_start"  # Node execution starting
  NODE_END = "node_end"  # Node execution completed
  EDGE_TRAVERSAL = "edge_traversal"  # Edge being traversed
  CHECKPOINT = "checkpoint"  # Checkpoint created
  INTERRUPT = "interrupt"  # Human-in-the-loop interrupt
  ERROR = "error"  # Execution error
  COMPLETE = "complete"  # Graph execution complete


class GraphEvent(BaseModel):  # type: ignore[misc]
  """Typed event for graph execution streaming.

  GraphEvent provides structured events for monitoring and debugging
  graph execution. These events can be streamed to track execution
  progress, state changes, and control flow. Only the field group
  relevant to ``event_type`` is expected to be populated; all other
  fields default to None.

  Example:
    ```python
    async for event in graph.run_async_with_events(ctx):
      if event.event_type == GraphEventType.NODE_START:
        print(f"Starting node: {event.node_name}")
      elif event.event_type == GraphEventType.CHECKPOINT:
        print(f"Checkpoint at iteration {event.iteration}")
    ```
  """

  event_type: GraphEventType
  # Free-form string; the description promises ISO format but nothing
  # enforces it — producers must supply an ISO-8601 timestamp.
  timestamp: str = Field(description="ISO timestamp of event")

  # Node information
  node_name: Optional[str] = None
  iteration: Optional[int] = None

  # State information
  graph_state: Optional[Dict[str, Any]] = None
  state_delta: Optional[Dict[str, Any]] = None

  # Edge information
  source_node: Optional[str] = None
  target_node: Optional[str] = None
  edge_condition_result: Optional[bool] = None

  # Interrupt information
  interrupt_mode: Optional[str] = None
  interrupt_message: Optional[str] = None

  # Error information
  error_message: Optional[str] = None
  error_type: Optional[str] = None

  # Checkpoint information
  checkpoint_id: Optional[str] = None

  # Additional metadata
  metadata: Dict[str, Any] = Field(default_factory=dict)


class GraphStreamMode(str, Enum):
  """Stream modes for graph execution.

  Different stream modes provide different levels of detail:
  - VALUES: Only final node outputs
  - UPDATES: State updates after each node
  - MESSAGES: All agent messages and events
  - DEBUG: All events including edges, checkpoints, interrupts
  """

  VALUES = "values"  # Stream final values only
  UPDATES = "updates"  # Stream state updates
  MESSAGES = "messages"  # Stream all messages
  DEBUG = "debug"  # Stream all debug events
Separated from GraphAgent for +single-responsibility. +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .graph_agent import GraphAgent + + +def export_graph_structure(graph: GraphAgent) -> Dict[str, Any]: + """Export graph structure in D3-compatible JSON format. + + Args: + graph: GraphAgent instance to export + + Returns: + Dictionary with nodes, links, and metadata suitable for + D3.js or other graph visualization tools. + """ + nodes = [] + links = [] + + for node_name, node in graph.nodes.items(): + node_data = { + "id": node_name, + "type": "agent" if node.agent else "function", + "name": node.name, + } + nodes.append(node_data) + + for node_name, node in graph.nodes.items(): + for edge in node.edges: + link_data = { + "source": node_name, + "target": edge.target_node, + "conditional": edge.has_condition, + } + links.append(link_data) + + metadata = { + "start_node": graph.start_node, + "end_nodes": graph.end_nodes, + "checkpointing": graph.checkpointing, + "max_iterations": graph.max_iterations, + } + + return { + "nodes": nodes, + "links": links, + "metadata": metadata, + "directed": True, + } + + +def export_graph_with_execution( + graph: GraphAgent, + execution_history: Optional[List[Dict[str, Any]]] = None, + state_history: Optional[List[Dict[str, Any]]] = None, + interrupt_markers: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Export graph with execution history, state evolution, and interrupts. + + Enhanced D3-compatible format including runtime information for + interactive visualization and replay. 
+ + Args: + graph: GraphAgent instance to export + execution_history: List of executed nodes with timestamps + state_history: List of state snapshots after each node + interrupt_markers: List of interrupt events + + Returns: + Enhanced dictionary with execution data overlaid on graph structure. + """ + base_structure = export_graph_structure(graph) + nodes = base_structure["nodes"] + links = base_structure["links"] + + if execution_history: + node_executions: Dict[str, List[Dict[str, Any]]] = {} + for exec_record in execution_history: + node_name = exec_record["node"] + if node_name not in node_executions: + node_executions[node_name] = [] + node_executions[node_name].append(exec_record) + + for node in nodes: + node_id = node["id"] + if node_id in node_executions: + node["executions"] = node_executions[node_id] + node["execution_count"] = len(node_executions[node_id]) + statuses = [ + e.get("status", "unknown") for e in node_executions[node_id] + ] + node["status_summary"] = { + "success": statuses.count("success"), + "error": statuses.count("error"), + "interrupted": statuses.count("interrupted"), + } + else: + node["executions"] = [] + node["execution_count"] = 0 + + if execution_history: + link_traversals: Dict[Tuple[str, str], int] = {} + for i in range(len(execution_history) - 1): + source = execution_history[i]["node"] + target = execution_history[i + 1]["node"] + link_key = (source, target) + link_traversals[link_key] = link_traversals.get(link_key, 0) + 1 + + for link in links: + link_key = (link["source"], link["target"]) + link["traversals"] = link_traversals.get(link_key, 0) + + if interrupt_markers: + node_interrupts: Dict[str, List[Dict[str, Any]]] = {} + for interrupt in interrupt_markers: + node_name = interrupt.get("node") + if node_name: + if node_name not in node_interrupts: + node_interrupts[node_name] = [] + node_interrupts[node_name].append(interrupt) + + for node in nodes: + node_id = node["id"] + if node_id in node_interrupts: + 
node["interrupt_markers"] = node_interrupts[node_id] + node["interrupt_count"] = len(node_interrupts[node_id]) + + return { + "nodes": nodes, + "links": links, + "metadata": base_structure["metadata"], + "execution_history": execution_history or [], + "state_history": state_history or [], + "interrupt_markers": interrupt_markers or [], + "directed": True, + "enhanced": True, + } + + +def export_execution_timeline( + execution_history: List[Dict[str, Any]], + state_history: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Export execution timeline for temporal visualization. + + Creates a timeline view of graph execution suitable for + Gantt charts, timeline visualizations, or replay UIs. + + Args: + execution_history: List of executed nodes with timestamps + state_history: Optional state snapshots at each step + + Returns: + Dictionary with timeline, total_duration, total_steps, iterations. + """ + if not execution_history: + return { + "timeline": [], + "total_duration": 0, + "total_steps": 0, + "iterations": 0, + } + + timeline = [] + for i, exec_record in enumerate(execution_history): + step_data = { + "step": i, + "node": exec_record["node"], + "timestamp": exec_record.get("timestamp", 0), + "iteration": exec_record.get("iteration", 1), + "status": exec_record.get("status", "unknown"), + } + + if i < len(execution_history) - 1: + next_timestamp = execution_history[i + 1].get("timestamp", 0) + step_data["duration"] = next_timestamp - step_data["timestamp"] + else: + step_data["duration"] = 0 + + if state_history and i < len(state_history): + step_data["state"] = state_history[i].get("state", {}) + + timeline.append(step_data) + + total_duration = 0 + if len(timeline) > 1: + total_duration = timeline[-1]["timestamp"] - timeline[0]["timestamp"] + + max_iteration = max((step["iteration"] for step in timeline), default=0) + + return { + "timeline": timeline, + "total_duration": total_duration, + "total_steps": len(timeline), + "iterations": 
class GraphNode:
  """A node in the graph that can wrap ANY ADK agent or custom function.

  Supports all BaseAgent types:
  - LLMAgent: Single agent with tools
  - SequentialAgent: Chain of agents executed in sequence
  - ParallelAgent: Multiple agents executed concurrently
  - LoopAgent: Iterative agent execution
  - GraphAgent: Recursive graph workflows (graphs within graphs!)
  - Custom agents: Any subclass of BaseAgent

  This enables powerful composition patterns like:
  - Nested workflows (GraphAgent containing other GraphAgents)
  - Validation pipelines (SequentialAgent as a node)
  - Parallel analysis (ParallelAgent as a node)
  """

  def __init__(
      self,
      name: str,
      agent: Optional[BaseAgent] = None,
      function: Optional[Callable[..., Any]] = None,
      input_mapper: Optional[Callable[[GraphState], str]] = None,
      output_mapper: Optional[Callable[[str, GraphState], GraphState]] = None,
      reducer: StateReducer = StateReducer.OVERWRITE,
      custom_reducer: Optional[Callable[..., Any]] = None,
  ):
    """Initialize graph node.

    Args:
      name: Node name
      agent: ANY ADK BaseAgent subclass (LLMAgent, SequentialAgent,
        ParallelAgent, LoopAgent, GraphAgent, or custom agents)
      function: Custom async function to execute (alternative to agent)
      input_mapper: Maps GraphState to agent input
      output_mapper: Maps agent output back to GraphState
      reducer: Strategy for merging output into state
      custom_reducer: Custom reduction function (used with
        StateReducer.CUSTOM; if omitted, CUSTOM silently stores nothing)

    Raises:
      ValueError: If neither agent nor function is provided.
    """
    if agent is None and function is None:
      raise ValueError("Either agent or function must be provided")

    self.name = name
    self.agent = agent
    self.function = function
    self.input_mapper = input_mapper or self._default_input_mapper
    self.output_mapper = output_mapper or self._default_output_mapper
    self.reducer = reducer
    self.custom_reducer = custom_reducer
    self.edges: List[EdgeCondition] = []
    # Cache of edges sorted by (-priority, insertion index); invalidated by
    # add_edge so routing stays consistent after topology changes.
    self._sorted_edges_cache: Optional[List[tuple[int, EdgeCondition]]] = None

    # Auto-default output_key for LlmAgent with output_schema.
    # Local import avoids a circular dependency at module import time.
    from ..llm_agent import LlmAgent

    if isinstance(agent, LlmAgent):
      if agent.output_schema and not agent.output_key:
        logger.info(
            f"Node '{name}': Auto-defaulting output_key to agent name"
            f" '{agent.name}' (LlmAgent has output_schema but no explicit"
            " output_key)"
        )
        # Use model_copy to avoid mutating the caller's agent instance
        self.agent = agent.model_copy(update={"output_key": agent.name})

  def add_edge(
      self, target_node: str, condition: Optional[Callable[..., Any]] = None
  ) -> None:
    """Add an edge to another node.

    Args:
      target_node: Name of the target node
      condition: Optional condition function for conditional routing
    """
    self.edges.append(EdgeCondition(target_node, condition))
    self._sorted_edges_cache = None  # Invalidate cache

  def _default_input_mapper(self, state: GraphState) -> str:
    """Default input mapper: extract 'input' or 'messages' from state.

    Args:
      state: Current graph state

    Returns:
      Input string for the node
    """
    return str(state.data.get("input", state.data.get("messages", "")))

  def _default_output_mapper(
      self, output: str, state: GraphState
  ) -> GraphState:
    """Default output mapper: store output in state with node name as key.

    Applies the configured reducer strategy to merge the output into state.

    Args:
      output: Node execution output
      state: Current graph state

    Returns:
      New graph state with output merged

    Raises:
      TypeError: For StateReducer.SUM when the existing value and output
        cannot be added (or no additive identity can be inferred).
    """
    # Deep-copy so callers holding the previous state are not mutated.
    new_state = GraphState(data=deepcopy(state.data))

    if self.reducer == StateReducer.OVERWRITE:
      new_state.data[self.name] = output
    elif self.reducer == StateReducer.APPEND:
      if self.name not in new_state.data:
        new_state.data[self.name] = []
      new_state.data[self.name].append(output)
    elif self.reducer == StateReducer.SUM:
      existing = new_state.data.get(self.name)
      # Both the zero-value inference and the addition can raise TypeError;
      # keep them inside one try so callers always get the helpful message
      # (previously type(output)() sat outside the try and could raise a
      # bare TypeError for types without a zero-arg constructor).
      try:
        if existing is None:
          # Infer an additive identity from the output type:
          # "" for str, 0 for int/float, [] for list.
          existing = type(output)()
        new_state.data[self.name] = existing + output
      except TypeError as exc:
        raise TypeError(
            f"StateReducer.SUM: cannot add {type(existing).__name__} + "
            f"{type(output).__name__} for node '{self.name}'. "
            "Ensure consistent types or use a custom output_mapper."
        ) from exc
    elif self.reducer == StateReducer.CUSTOM and self.custom_reducer:
      new_state.data[self.name] = self.custom_reducer(
          new_state.data.get(self.name), output
      )

    return new_state

  def get_next_node(self, state: GraphState) -> Optional[str]:
    """Determine next node based on conditional edges with priority and weight.

    Enhanced routing logic:
    1. Sort edges by priority (highest first), preserving insertion order
    2. Evaluate conditions in priority order
    3. Within same priority, return first match (insertion order)
    4. If edges have different weights, use weighted random selection
    5. Priority 0 edges are fallbacks (always match if no higher priority
       matched)

    Args:
      state: Current graph state

    Returns:
      Name of next node, or None if no edge matches
    """
    if not self.edges:
      return None

    # Use cached sorted edges (sort once, invalidated by add_edge)
    if self._sorted_edges_cache is None:
      indexed_edges = [(i, e) for i, e in enumerate(self.edges)]
      self._sorted_edges_cache = sorted(
          indexed_edges, key=lambda x: (-x[1].priority, x[0])
      )
    sorted_edges = self._sorted_edges_cache

    # Group edges by priority
    current_priority = sorted_edges[0][1].priority
    matching_edges: list[tuple[int, EdgeCondition]] = []

    for idx, edge in sorted_edges:
      # If we've moved to a lower priority and already have matches, stop
      if edge.priority < current_priority and matching_edges:
        break

      # Update current priority
      current_priority = edge.priority

      # Check if edge matches
      if edge.should_route(state):
        matching_edges.append((idx, edge))

    # No matching edges
    if not matching_edges:
      return None

    # Single matching edge - return it
    if len(matching_edges) == 1:
      return matching_edges[0][1].target_node

    # Multiple matching edges at same priority
    # Check if weights are all the same (default behavior: first match)
    weights = [e.weight for _, e in matching_edges]
    all_same_weight = len(set(weights)) == 1

    if all_same_weight:
      # All weights equal - return first match in insertion order
      return matching_edges[0][1].target_node

    # Different weights - use weighted random selection
    total_weight = sum(weights)
    if total_weight == 0:
      # All weights are 0 — guard against ZeroDivisionError in weighted
      # random selection below. Fall back to first match in insertion order.
      return matching_edges[0][1].target_node

    # Weighted random choice
    rand_value = random.random() * total_weight
    cumulative = 0.0
    for idx, edge in matching_edges:
      cumulative += edge.weight
      if rand_value <= cumulative:
        return edge.target_node

    # Fallback (shouldn't reach here, but safety)
    return matching_edges[-1][1].target_node
async def rewind_to_node(
    graph: GraphAgent,
    session_service: Any,
    app_name: str,
    user_id: str,
    session_id: str,
    node_name: str,
    invocation_index: int = -1,
) -> None:
  """Rewind graph execution to before a specific node execution.

  Enables temporal navigation within graph workflows:
  - Retry failed nodes with different inputs
  - Explore alternative execution paths
  - Debug workflow issues
  - Select specific iteration in loops

  Args:
    graph: GraphAgent instance
    session_service: Session service instance
    app_name: Application name
    user_id: User ID
    session_id: Session ID
    node_name: Node to rewind to
    invocation_index: Which invocation (-1 for most recent)

  Raises:
    ValueError: If the session does not exist, the node has not been
      executed yet, or invocation_index is out of range.
  """
  session = await session_service.get_session(
      app_name=app_name, user_id=user_id, session_id=session_id
  )
  if not session:
    raise ValueError(f"Session not found: {session_id}")

  # Walk events newest-first and take node_invocations from the most
  # recent agent_state snapshot.
  all_node_invocations: dict[str, list[str]] = {}
  for event in reversed(session.events):
    agent_state = event.actions.agent_state if event.actions else None
    if agent_state and "node_invocations" in agent_state:
      all_node_invocations = agent_state["node_invocations"]
      break

  node_invocations = all_node_invocations.get(node_name, [])
  if not node_invocations:
    raise ValueError(
        f"Node '{node_name}' has not been executed yet."
        f" Available nodes: {list(all_node_invocations.keys())}"
    )

  count = len(node_invocations)
  if not -count <= invocation_index < count:
    raise ValueError(
        f"Invocation index {invocation_index} out of range. "
        f"Node '{node_name}' has"
        f" {count} invocations."
    )

  # Imported lazily to avoid a circular import at module load time.
  from ...runners import Runner

  runner = Runner(
      app_name=app_name,
      agent=graph,
      session_service=session_service,
  )
  await runner.rewind_async(
      user_id=user_id,
      session_id=session_id,
      rewind_before_invocation_id=node_invocations[invocation_index],
  )
class GraphState(BaseModel):  # type: ignore[misc]
  """Domain data container for graph execution.

  GraphState holds node outputs and intermediate results as the graph
  executes. Execution tracking (iteration, path, etc.) is handled
  separately by GraphAgentState.

  Example:
    ```python
    state = GraphState(
        data={"input": "user query", "result": "agent response"},
    )
    ```
  """

  # arbitrary_types_allowed: node outputs may be arbitrary Python objects;
  # extra="forbid": unknown constructor kwargs are rejected early.
  model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

  data: Dict[str, Any] = Field(
      default_factory=dict, description="Node outputs and intermediate results"
  )

  def data_to_json(self, indent: int = 2) -> str:
    """Convert state.data to JSON string (automatically handles Pydantic models).

    Args:
      indent: JSON indentation level (default: 2)

    Returns:
      JSON string of state.data
    """
    # Local import: json is only needed for this serialization helper.
    import json

    return json.dumps(self.data, cls=PydanticJSONEncoder, indent=indent)

  def get_parsed(
      self, key: str, schema: Type[T], default: Optional[T] = None
  ) -> Optional[T]:
    """Get state value with automatic JSON-string parsing.

    Handles both dict and JSON-string representations transparently.

    Args:
      key: State data key (usually agent output_key)
      schema: Pydantic model to parse into
      default: Value to return if key missing or parse fails

    Returns:
      Parsed Pydantic model instance or default
    """
    # Delegates to state_utils so non-graph agents can reuse the parsing.
    return parse_state_value(self.data.get(key), schema, default)

  def get_str(self, key: str, default: str = "") -> str:
    """Get state value as string (for non-schema agent outputs).

    Args:
      key: State data key
      default: Value to return if key missing

    Returns:
      String value or default
    """
    return state_value_as_str(self.data.get(key), default)

  def get_dict(
      self, key: str, default: Optional[Dict[str, Any]] = None
  ) -> Dict[str, Any]:
    """Get state value as dict with JSON-string fallback.

    Args:
      key: State data key
      default: Value to return if key missing or parse fails

    Returns:
      Dict value or default
    """
    return state_value_as_dict(self.data.get(key), default)
+ """ + config = effective_config or self.telemetry_config + if not config: + return True # Default to 100% sampling if no config + return bool(random.random() < config.sampling_rate) + + def _get_telemetry_attributes( + self, + base_attributes: Dict[str, Any], + effective_config: Optional[TelemetryConfig] = None, + ) -> Dict[str, Any]: + """Get telemetry attributes including custom attributes. + + Args: + base_attributes: Base attributes for the telemetry event + effective_config: Effective config. If None, uses self.telemetry_config. + + Returns: + Combined attributes with additional custom attributes + """ + config = effective_config or self.telemetry_config + if not config or not config.additional_attributes: + return base_attributes + + combined = dict(config.additional_attributes) + combined.update(base_attributes) + return combined + + def _get_parent_telemetry_config( + self, ctx: InvocationContext + ) -> Optional[Dict[str, Any]]: + """Get parent telemetry config from agent_states. + + Used for nested agents to inherit telemetry settings from parent. + + Args: + ctx: Invocation context with agent_states + """ + if not ctx.agent_states: + return None + for agent_name, state_dict in ctx.agent_states.items(): + if agent_name != self.name and isinstance(state_dict, dict): + config = state_dict.get("telemetry_config_dict") + if config is not None: + return dict(config) if isinstance(config, dict) else None + return None + + def _get_effective_telemetry_config( + self, ctx: InvocationContext + ) -> Optional[TelemetryConfig]: + """Get effective telemetry config by merging parent and own config. + + Own config takes precedence over parent config. 
+ + Args: + ctx: Invocation context with session state + """ + parent_config_dict = self._get_parent_telemetry_config(ctx) + + if not parent_config_dict: + return self.telemetry_config # type: ignore[no-any-return] + + if not self.telemetry_config: + from .graph_agent_config import TelemetryConfig + + return TelemetryConfig(**parent_config_dict) + + # Merge: own config takes precedence + merged_dict = parent_config_dict.copy() + own_dict = self.telemetry_config.model_dump() + for key, value in own_dict.items(): + if key == "additional_attributes" and value is not None: + parent_attrs = merged_dict.get("additional_attributes") or {} + own_attrs = value or {} + merged_dict["additional_attributes"] = {**parent_attrs, **own_attrs} + elif value is not None: + merged_dict[key] = value + + from .graph_agent_config import TelemetryConfig + + return TelemetryConfig(**merged_dict) + + +class GraphTelemetryMixin(AgentTelemetryMixin): + """Graph-specific telemetry toggles. + + Extends AgentTelemetryMixin with granular trace controls for + graph execution components (nodes, edges, iterations, etc.). 
+ """ + + def _should_trace_nodes(self) -> bool: + """Check if node execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_nodes) + + def _should_trace_edges(self) -> bool: + """Check if edge evaluation tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_edges) + + def _should_trace_iterations(self) -> bool: + """Check if graph iteration metrics are enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_iterations) + + def _should_trace_parallel_groups(self) -> bool: + """Check if parallel group execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_parallel_groups) + + def _should_trace_callbacks(self) -> bool: + """Check if callback execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_callbacks) + + def _should_trace_interrupts(self) -> bool: + """Check if interrupt check tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_interrupts) diff --git a/src/google/adk/agents/graph/state_utils.py b/src/google/adk/agents/graph/state_utils.py new file mode 100644 index 0000000000..99b2d4bf62 --- /dev/null +++ b/src/google/adk/agents/graph/state_utils.py @@ -0,0 +1,127 @@ +"""Reusable state parsing utilities. + +Generic functions for parsing state values — usable by any agent type, +not just GraphAgent. GraphState.get_parsed/get_str/get_dict delegate here. 
class PydanticJSONEncoder(json.JSONEncoder):
  """JSON encoder that automatically handles Pydantic models.

  This encoder allows json.dumps() to work transparently with Pydantic models
  without requiring special serialization methods.

  Example:
    ```python
    import json
    state_json = json.dumps(data, cls=PydanticJSONEncoder, indent=2)
    ```
  """

  def default(self, obj: Any) -> Any:
    """Convert Pydantic models to dicts automatically."""
    if isinstance(obj, BaseModel):
      return obj.model_dump()
    return super().default(obj)


def parse_state_value(
    raw: Any, schema: Type[T], default: Optional[T] = None
) -> Optional[T]:
  """Parse a raw state value into a Pydantic model.

  Handles both dict and JSON-string representations transparently.

  Args:
    raw: Raw value from state (dict, JSON string, or None)
    schema: Pydantic model class to parse into
    default: Value to return if raw is None or parse fails

  Returns:
    Parsed Pydantic model instance or default
  """
  if raw is None:
    return default

  if isinstance(raw, dict):
    # Deliberately broad: pydantic raises ValidationError, but a custom
    # schema may raise other exceptions; any failure means "no value".
    try:
      return cast(T, schema.model_validate(raw))
    except Exception:
      return default

  if isinstance(raw, str):
    try:
      return cast(T, schema.model_validate_json(raw))
    except Exception:
      return default

  return default


def state_value_as_str(raw: Any, default: str = "") -> str:
  """Convert a raw state value to string.

  Args:
    raw: Raw value from state
    default: Value to return if raw is None

  Returns:
    String representation or default
  """
  if raw is None:
    return default
  return str(raw)


def state_value_as_dict(
    raw: Any, default: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
  """Convert a raw state value to dict, with JSON-string fallback.

  Args:
    raw: Raw value from state (dict, JSON string, or other)
    default: Value to return if conversion fails

  Returns:
    Dict value or default
  """
  _default = default or {}

  if raw is None:
    return _default

  if isinstance(raw, dict):
    return cast(Dict[str, Any], raw)

  if isinstance(raw, str):
    # json.loads on a str only raises JSONDecodeError (a ValueError
    # subclass); catching bare Exception here would mask unrelated bugs.
    try:
      result = json.loads(raw)
    except ValueError:
      return _default
    if isinstance(result, dict):
      return cast(Dict[str, Any], result)
    return _default

  return _default
+ """ + if node.agent is not None: + return node.agent.name + return node.name + + async def build_graph( graph: graphviz.Digraph, agent: BaseAgent, @@ -69,6 +81,8 @@ def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]): return tool_or_agent.name + ' (Loop Agent)' elif isinstance(tool_or_agent, ParallelAgent): return tool_or_agent.name + ' (Parallel Agent)' + elif isinstance(tool_or_agent, GraphAgent): + return tool_or_agent.name + ' (Graph Agent)' else: return tool_or_agent.name elif isinstance(tool_or_agent, BaseTool): @@ -126,6 +140,8 @@ def should_build_agent_cluster(tool_or_agent: Union[BaseAgent, BaseTool]): return True elif isinstance(tool_or_agent, ParallelAgent): return True + elif isinstance(tool_or_agent, GraphAgent): + return True else: return False elif retrieval_tool_module_loaded and isinstance( @@ -188,6 +204,36 @@ async def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str): await build_graph(child, sub_agent, highlight_pairs) if parent_agent: draw_edge(parent_agent.name, sub_agent.name) + elif isinstance(agent, GraphAgent): + # Render all graph nodes inside the cluster + for node_name, node in agent.nodes.items(): + if node.agent is not None: + # Agent node — render inside cluster + await build_graph(child, node.agent, highlight_pairs) + else: + # Function-only node + child.node( + node_name, + node_name, + shape='box', + style='rounded', + color=light_gray, + fontcolor=light_gray, + ) + # Draw edges from graph topology + for node_name, node in agent.nodes.items(): + src = _graph_node_id(node) + for edge in node.edges: + tgt_node = agent.nodes.get(edge.target_node) + if tgt_node: + dst = _graph_node_id(tgt_node) + draw_edge(src, dst) + # Connect parent to start node + if parent_agent and agent.start_node: + sn = agent.nodes.get(agent.start_node) + if sn: + start_id = _graph_node_id(sn) + draw_edge(parent_agent.name, start_id) else: for sub_agent in agent.sub_agents: await build_graph(child, sub_agent, highlight_pairs) diff 
"""OpenTelemetry instrumentation for GraphAgent workflow execution.

This module provides tracing, logging, and metrics for graph orchestration
following OpenTelemetry semantic conventions.
"""

from __future__ import annotations

import logging
# NOTE(review): `time` is not used in this module-level section — presumably
# used by helpers further down the file; confirm before removing.
import time
from typing import TYPE_CHECKING

# NOTE(review): opentelemetry._logs is a private/experimental API surface;
# confirm the pinned opentelemetry version keeps exposing it.
from opentelemetry import _logs
from opentelemetry import metrics
from opentelemetry import trace
from opentelemetry.semconv.schemas import Schemas

from .. import version

if TYPE_CHECKING:
  from ..agents.graph.graph_state import GraphState
  from ..agents.invocation_context import InvocationContext

# OpenTelemetry tracer for graph execution
tracer = trace.get_tracer(
    instrumenting_module_name="gcp.vertex.agent.graph",
    instrumenting_library_version=version.__version__,
    schema_url=Schemas.V1_36_0.value,
)

# OpenTelemetry logger for structured logs
otel_logger = _logs.get_logger(
    instrumenting_module_name="gcp.vertex.agent.graph",
    instrumenting_library_version=version.__version__,
    schema_url=Schemas.V1_36_0.value,
)

# Python logger for standard logging ("google_adk." prefix matches the
# ADK-wide logger naming convention)
logger = logging.getLogger("google_adk." + __name__)

# OpenTelemetry meter for metrics
meter = metrics.get_meter(
    name="gcp.vertex.agent.graph",
    version=version.__version__,
    schema_url=Schemas.V1_36_0.value,
)

# Metrics - Node Execution
node_execution_counter = meter.create_counter(
    name="graph.node.executions",
    description="Number of node executions",
    unit="1",
)

node_execution_latency = meter.create_histogram(
    name="graph.node.latency",
    description="Node execution latency in milliseconds",
    unit="ms",
)

# Metrics - Edge Evaluation
edge_evaluation_counter = meter.create_counter(
    name="graph.edge.evaluations",
    description="Number of edge condition evaluations",
    unit="1",
)

edge_evaluation_latency = meter.create_histogram(
    name="graph.edge.latency",
    description="Edge condition evaluation latency in milliseconds",
    unit="ms",
)

# Metrics - Graph Iterations
graph_iteration_counter = meter.create_counter(
    name="graph.iterations",
    description="Graph execution iterations",
    unit="1",
)

# Metrics - Parallel Groups
parallel_group_counter = meter.create_counter(
    name="graph.parallel.executions",
    description="Parallel group executions",
    unit="1",
)

parallel_group_latency = meter.create_histogram(
    name="graph.parallel.latency",
    description="Parallel group execution latency in milliseconds",
    unit="ms",
)

# Metrics - Callbacks
callback_execution_counter = meter.create_counter(
    name="graph.callback.executions",
    description="Callback executions",
    unit="1",
)

callback_execution_latency = meter.create_histogram(
    name="graph.callback.latency",
    description="Callback execution latency in milliseconds",
    unit="ms",
)

# Metrics - Interrupts
interrupt_check_counter = meter.create_counter(
    name="graph.interrupt.checks",
    description="Interrupt check operations",
    unit="1",
)

# Metrics - State Reducers
state_reducer_counter = meter.create_counter(
    name="graph.state.reducer.applications",
    description="State reducer applications",
    unit="1",
)

state_reducer_latency = meter.create_histogram(
    name="graph.state.reducer.latency",
    description="State reducer application latency in milliseconds",
    unit="ms",
)

# Metrics - Mappers
mapper_counter = meter.create_counter(
    name="graph.mapper.applications",
    description="Mapper function applications",
    unit="1",
)

mapper_latency = meter.create_histogram(
    name="graph.mapper.latency",
    description="Mapper function execution latency in milliseconds",
    unit="ms",
)

# Semantic Conventions - Graph Attributes
GRAPH_AGENT_NAME = "graph.agent.name"
GRAPH_NODE_NAME = "graph.node.name"
GRAPH_NODE_TYPE = "graph.node.type"  # agent|function
GRAPH_NODE_ITERATION = "graph.node.iteration"
GRAPH_EDGE_SOURCE = "graph.edge.source"
GRAPH_EDGE_TARGET = "graph.edge.target"
GRAPH_EDGE_CONDITION_RESULT = "graph.edge.condition.result"
GRAPH_EDGE_PRIORITY = "graph.edge.priority"
GRAPH_ITERATION = "graph.iteration"
GRAPH_PATH = "graph.path"
GRAPH_PARALLEL_NODE_COUNT = "graph.parallel.node_count"
GRAPH_PARALLEL_STRATEGY = "graph.parallel.strategy"
GRAPH_PARALLEL_WAIT_N = "graph.parallel.wait_n"
GRAPH_CALLBACK_TYPE = "graph.callback.type"  # before_node|after_node|on_edge
GRAPH_INTERRUPT_MODE = "graph.interrupt.mode"  # before|after|both
GRAPH_SESSION_ID = "graph.session.id"
GRAPH_STATE_REDUCER_TYPE = (  # OVERWRITE|APPEND|SUM|CUSTOM
    "graph.state.reducer.type"
)
GRAPH_STATE_KEY = "graph.state.key"  # Key being modified in state
GRAPH_MAPPER_TYPE = "graph.mapper.type"  # input|output
GRAPH_MAPPER_IS_DEFAULT = (  # Whether using default mapper
    "graph.mapper.is_default"
)
+ + Args: + node_name: Name of the executed node + node_type: Type of node (agent or function) + agent_name: Name of the GraphAgent + latency_ms: Execution latency in milliseconds + success: Whether execution succeeded + """ + attributes = { + GRAPH_NODE_NAME: node_name, + GRAPH_NODE_TYPE: node_type, + GRAPH_AGENT_NAME: agent_name, + "success": success, + } + + node_execution_counter.add(1, attributes=attributes) + node_execution_latency.record(latency_ms, attributes=attributes) + + +def record_edge_evaluation( + source_node: str, + target_node: str, + agent_name: str, + condition_result: bool, + latency_ms: float, + priority: int = 0, +) -> None: + """Record edge condition evaluation metrics. + + Args: + source_node: Source node name + target_node: Target node name + agent_name: Name of the GraphAgent + condition_result: Result of condition evaluation + latency_ms: Evaluation latency in milliseconds + priority: Edge priority + """ + attributes = { + GRAPH_EDGE_SOURCE: source_node, + GRAPH_EDGE_TARGET: target_node, + GRAPH_AGENT_NAME: agent_name, + GRAPH_EDGE_CONDITION_RESULT: str(condition_result), + GRAPH_EDGE_PRIORITY: priority, + } + + edge_evaluation_counter.add(1, attributes=attributes) + edge_evaluation_latency.record(latency_ms, attributes=attributes) + + +def record_graph_iteration( + agent_name: str, + iteration: int, + path_length: int, +) -> None: + """Record graph iteration metrics. + + Args: + agent_name: Name of the GraphAgent + iteration: Current iteration number + path_length: Length of execution path so far + """ + attributes = { + GRAPH_AGENT_NAME: agent_name, + GRAPH_ITERATION: iteration, + "path_length": path_length, + } + + graph_iteration_counter.add(1, attributes=attributes) + + +def record_parallel_group_execution( + agent_name: str, + node_count: int, + strategy: str, + latency_ms: float, + completed_count: int, +) -> None: + """Record parallel group execution metrics. 
+ + Args: + agent_name: Name of the GraphAgent + node_count: Number of nodes in parallel group + strategy: Join strategy (all, any, n) + latency_ms: Total execution latency in milliseconds + completed_count: Number of nodes that completed successfully + """ + attributes = { + GRAPH_AGENT_NAME: agent_name, + GRAPH_PARALLEL_NODE_COUNT: node_count, + GRAPH_PARALLEL_STRATEGY: strategy, + "completed_count": completed_count, + } + + parallel_group_counter.add(1, attributes=attributes) + parallel_group_latency.record(latency_ms, attributes=attributes) + + +def record_callback_execution( + callback_type: str, + agent_name: str, + latency_ms: float, + success: bool = True, +) -> None: + """Record callback execution metrics. + + Args: + callback_type: Type of callback (before_node, after_node, on_edge) + agent_name: Name of the GraphAgent + latency_ms: Execution latency in milliseconds + success: Whether callback succeeded + """ + attributes = { + GRAPH_CALLBACK_TYPE: callback_type, + GRAPH_AGENT_NAME: agent_name, + "success": success, + } + + callback_execution_counter.add(1, attributes=attributes) + callback_execution_latency.record(latency_ms, attributes=attributes) + + +def record_interrupt_check( + mode: str, + agent_name: str, + session_id: str, +) -> None: + """Record interrupt check metrics. + + Args: + mode: Interrupt mode (before, after, both) + agent_name: Name of the GraphAgent + session_id: Session identifier + """ + attributes = { + GRAPH_INTERRUPT_MODE: mode, + GRAPH_AGENT_NAME: agent_name, + GRAPH_SESSION_ID: session_id, + } + + interrupt_check_counter.add(1, attributes=attributes) + + +def record_state_reducer( + node_name: str, + reducer_type: str, + state_key: str, + agent_name: str, + latency_ms: float, + had_previous_value: bool, +) -> None: + """Record state reducer application metrics. 
+ + Args: + node_name: Name of the node applying the reducer + reducer_type: Type of reducer (OVERWRITE, APPEND, SUM, CUSTOM) + state_key: Key being modified in state.data + agent_name: Name of the GraphAgent + latency_ms: Reducer application latency in milliseconds + had_previous_value: Whether the key existed in state before reduction + """ + attributes = { + GRAPH_NODE_NAME: node_name, + GRAPH_STATE_REDUCER_TYPE: reducer_type, + GRAPH_STATE_KEY: state_key, + GRAPH_AGENT_NAME: agent_name, + "had_previous_value": had_previous_value, + } + + state_reducer_counter.add(1, attributes=attributes) + state_reducer_latency.record(latency_ms, attributes=attributes) + + +def record_mapper( + node_name: str, + mapper_type: str, + agent_name: str, + latency_ms: float, + is_default: bool, +) -> None: + """Record mapper transformation metrics. + + Args: + node_name: Name of the node using the mapper + mapper_type: Type of mapper (input or output) + agent_name: Name of the GraphAgent + latency_ms: Mapper execution latency in milliseconds + is_default: Whether using default mapper implementation + """ + attributes = { + GRAPH_NODE_NAME: node_name, + GRAPH_MAPPER_TYPE: mapper_type, + GRAPH_AGENT_NAME: agent_name, + GRAPH_MAPPER_IS_DEFAULT: is_default, + } + + mapper_counter.add(1, attributes=attributes) + mapper_latency.record(latency_ms, attributes=attributes) diff --git a/tests/unittests/agents/test_graph_agent.py b/tests/unittests/agents/test_graph_agent.py new file mode 100644 index 0000000000..1ccf1070c0 --- /dev/null +++ b/tests/unittests/agents/test_graph_agent.py @@ -0,0 +1,2628 @@ +"""Comprehensive test suite for GraphAgent implementation. 
+ +Tests all features with 100% coverage: +- Graph-based workflows with nodes and edges +- AgentNode for wrapping LLM agents +- Cyclic support for loops and iterative reasoning (ReAct pattern) +- Conditional routing based on state +- State management with reducers (overwrite, append, sum, custom) +- Checkpointing with persistent state (memory, SQLite) +- Human-in-the-loop with interrupt capabilities +""" + +import asyncio +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import Dict +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.agents import LlmAgent +from google.adk.agents import ParallelAgent +from google.adk.agents import SequentialAgent +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import export_execution_timeline +from google.adk.agents.graph import export_graph_structure +from google.adk.agents.graph import export_graph_with_execution +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import rewind_to_node +from google.adk.agents.graph import StateReducer +from google.adk.agents.graph.graph_agent_state import GraphAgentState +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.run_config import RunConfig +from google.adk.apps import ResumabilityConfig +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +# ============================================================================ +# Mock Agents for Testing +# 
============================================================================ + + +class SimpleTestAgent(BaseAgent): + """Real test agent that extends BaseAgent per ADK guidelines. + + This replaces MockAgent to comply with ADK testing guidelines: + - Extends BaseAgent (not a mock) + - Implements _run_async_impl (proper agent pattern) + - Uses real agent infrastructure + + Uses private attributes to store test data to avoid Pydantic validation. + """ + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, responses: list[str], delay: float = 0.0): + super().__init__(name=name) + # Use object.__setattr__ to bypass Pydantic for extra attributes + object.__setattr__(self, "_responses", responses) + object.__setattr__(self, "_call_count", 0) + object.__setattr__(self, "_delay", delay) + + async def _run_async_impl(self, ctx): + """Real agent implementation that yields predetermined responses.""" + delay = object.__getattribute__(self, "_delay") + await asyncio.sleep(delay) # Simulate processing time + + call_count = object.__getattribute__(self, "_call_count") + responses = object.__getattribute__(self, "_responses") + + response = responses[min(call_count, len(responses) - 1)] + object.__setattr__(self, "_call_count", call_count + 1) + + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=response)]), + ) + + @property + def call_count(self): + """Get number of times agent was called.""" + return object.__getattribute__(self, "_call_count") + + +class MockLlmAgent(LlmAgent): + """Mock LLM agent that doesn't call real LLM.""" + + # Use model_config to allow extra attributes + model_config = {"arbitrary_types_allowed": True, "extra": "allow"} + + def __init__(self, name: str, response: str = "mock response", **kwargs): + super().__init__( + name=name, model="gemini-2.0-flash-exp", instruction="mock", **kwargs + ) + # Store as model extra fields + object.__setattr__(self, "_mock_response", 
response) + object.__setattr__(self, "_mock_call_count", 0) + + async def _run_async_impl(self, ctx): + """Mock implementation.""" + count = object.__getattribute__(self, "_mock_call_count") + object.__setattr__(self, "_mock_call_count", count + 1) + + response = object.__getattribute__(self, "_mock_response") + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=response)]), + ) + + @property + def call_count(self): + """Get call count.""" + return object.__getattribute__(self, "_mock_call_count") + + +# ============================================================================ +# Test: Basic Graph Structure +# ============================================================================ + + +class TestGraphStructure: + """Test basic graph construction and structure.""" + + def test_create_empty_graph(self): + """Test creating empty graph.""" + graph = GraphAgent(name="test_graph", description="Test graph") + assert graph.name == "test_graph" + assert graph.description == "Test graph" + assert len(graph.nodes) == 0 + assert graph.start_node is None + assert len(graph.end_nodes) == 0 + + def test_add_nodes(self): + """Test adding nodes to graph.""" + graph = GraphAgent(name="test") + + node1 = GraphNode(name="node1", agent=MockLlmAgent("agent1")) + node2 = GraphNode(name="node2", agent=MockLlmAgent("agent2")) + + graph.add_node(node1) + graph.add_node(node2) + + assert len(graph.nodes) == 2 + assert "node1" in graph.nodes + assert "node2" in graph.nodes + + def test_add_edges(self): + """Test adding edges between nodes.""" + graph = GraphAgent(name="test") + + graph.add_node(GraphNode(name="node1", agent=MockLlmAgent("agent1"))) + graph.add_node(GraphNode(name="node2", agent=MockLlmAgent("agent2"))) + + graph.add_edge("node1", "node2") + + assert len(graph.nodes["node1"].edges) == 1 + assert graph.nodes["node1"].edges[0].target_node == "node2" + + def test_set_start_end(self): + """Test setting start and end nodes.""" + graph = 
GraphAgent(name="test") + graph.add_node(GraphNode(name="start", agent=MockLlmAgent("agent1"))) + graph.add_node(GraphNode(name="end", agent=MockLlmAgent("agent2"))) + + graph.set_start("start") + graph.set_end("end") + + assert graph.start_node == "start" + assert "end" in graph.end_nodes + + def test_invalid_edge_raises_error(self): + """Test that invalid edges raise errors.""" + graph = GraphAgent(name="test") + graph.add_node(GraphNode(name="node1", agent=MockLlmAgent("agent1"))) + + with pytest.raises(ValueError, match="Target node node2 not found"): + graph.add_edge("node1", "node2") + + def test_invalid_start_raises_error(self): + """Test that invalid start node raises error.""" + graph = GraphAgent(name="test") + + with pytest.raises(ValueError, match="Node invalid not found"): + graph.set_start("invalid") + + +# ============================================================================ +# Test: Cyclic Support and ReAct Pattern +# ============================================================================ + + +@pytest.mark.asyncio +class TestCyclicExecution: + """Test cyclic graph execution (loops, ReAct pattern).""" + + async def test_simple_loop(self): + """Test graph with loop executes multiple iterations.""" + graph = GraphAgent(name="loop_graph", max_iterations=5) + + # Counter agent that increments + counter_responses = [str(i) for i in range(1, 10)] + counter_agent = SimpleTestAgent("counter", counter_responses) + + graph.add_node( + GraphNode( + name="counter", + agent=counter_agent, + output_mapper=lambda output, state: GraphState( + data={**state.data, "count": int(output)}, + ), + ) + ) + + # Loop back if count < 3 + graph.set_start("counter") + graph.add_edge( + "counter", "counter", condition=lambda s: s.data.get("count", 0) < 3 + ) + graph.set_end("counter") + + # Execute with Runner + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = 
runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + iterations = 0 + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="start")] + ), + ): + if event.content and event.content.parts: + iterations = ( + event.actions.state_delta.get("graph_iterations", 0) + if event.actions and event.actions.state_delta + else 0 + ) + + # Should run 3 iterations (count 1, 2, 3) + assert iterations == 3 + assert counter_agent.call_count == 3 + + async def test_max_iterations_prevents_infinite_loop(self): + """Test max_iterations prevents infinite loops.""" + graph = GraphAgent(name="infinite", max_iterations=3) + + # Agent that never ends + loop_agent = SimpleTestAgent("loop", ["continue"] * 100) + + graph.add_node(GraphNode(name="loop", agent=loop_agent)) + graph.set_start("loop") + graph.add_edge("loop", "loop") # Always loop back + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + iterations = 0 + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="start")] + ), + ): + if event.content and event.content.parts: + iterations = ( + event.actions.state_delta.get("graph_iterations", 0) + if event.actions and event.actions.state_delta + else 0 + ) + + # Should stop at max_iterations + assert iterations == 3 + + async def test_react_pattern(self): + """Test ReAct pattern (Reason -> Act -> Observe -> loop).""" + graph = GraphAgent(name="react", max_iterations=10) + + # Simulate ReAct: Complete after 2 iterations + reason_agent = SimpleTestAgent("reason", ["plan action 1", "plan 
action 2"]) + act_agent = SimpleTestAgent("act", ["result 1", "result 2"]) + observe_agent = SimpleTestAgent("observe", ["CONTINUE", "COMPLETE"]) + + graph.add_node(GraphNode(name="reason", agent=reason_agent)) + graph.add_node(GraphNode(name="act", agent=act_agent)) + graph.add_node(GraphNode(name="observe", agent=observe_agent)) + + graph.set_start("reason") + graph.add_edge("reason", "act") + graph.add_edge("act", "observe") + + # Loop back if CONTINUE, otherwise end (observe is end node) + graph.add_edge( + "observe", + "reason", + condition=lambda s: "CONTINUE" in s.data.get("observe", "").upper(), + ) + # When COMPLETE (or any other value), no edge matches, so execution stops at end node + graph.set_end("observe") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + path = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="test task")] + ), + ): + if event.content and event.content.parts: + path = ( + event.actions.state_delta.get("graph_path", []) + if event.actions and event.actions.state_delta + else [] + ) + + # Should execute: reason -> act -> observe -> reason -> act -> observe + expected = ["reason", "act", "observe", "reason", "act", "observe"] + assert path == expected + + +# ============================================================================ +# Test: Human-in-the-Loop Interrupts +# ============================================================================ + + +# ============================================================================ +# Test: Observability +# ============================================================================ +# NOTE: Tests for callback-based observability moved to 
test_graph_callbacks.py +# The old hardcoded "→ Node:" and "✓ Completed:" events were intentionally +# removed and replaced with callback-based observability per refactor requirements. + + +# ============================================================================ +# Test: Checkpointing +# ============================================================================ + + +@pytest.mark.asyncio +class TestCheckpointing: + """Test state checkpointing for resumability.""" + + async def test_checkpointing_enabled(self): + """Test that checkpointing saves state after each node.""" + graph = GraphAgent(name="test", checkpointing=True) + + agent1 = SimpleTestAgent("agent1", ["step1"]) + agent2 = SimpleTestAgent("agent2", ["step2"]) + + graph.add_node(GraphNode(name="node1", agent=agent1)) + graph.add_node(GraphNode(name="node2", agent=agent2)) + graph.set_start("node1") + graph.add_edge("node1", "node2") + graph.set_end("node2") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + checkpoints = [] + last_checkpoint = None + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + session = await runner.session_service.get_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + if "graph_checkpoint" in session.state: + current_checkpoint = session.state["graph_checkpoint"] + # Only append if it's a new checkpoint (different node or iteration) + if current_checkpoint != last_checkpoint: + checkpoints.append(current_checkpoint.copy()) + last_checkpoint = current_checkpoint + + # Should have checkpoints for both nodes + assert len(checkpoints) >= 2 + assert checkpoints[0]["node"] == "node1" + assert 
checkpoints[1]["node"] == "node2" + + async def test_checkpoint_contains_state(self): + """Test checkpoint contains graph state.""" + graph = GraphAgent(name="test", checkpointing=True) + + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="worker", agent=agent)) + graph.set_start("worker") + graph.set_end("worker") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + pass + + # Check saved state + session = await runner.session_service.get_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + assert "graph_data" in session.state + graph_data = session.state["graph_data"] + assert graph_data["worker"] == "response" + + +# ============================================================================ +# Test: Agent Type Support (LLM, Sequential, Parallel, Graph) +# ============================================================================ + + +@pytest.mark.asyncio +class TestAgentTypeSupport: + """Test support for all BaseAgent types.""" + + async def test_llm_agent_node(self): + """Test node with LLMAgent.""" + graph = GraphAgent(name="test") + + llm_agent = MockLlmAgent("llm", response="llm response") + graph.add_node(GraphNode(name="llm", agent=llm_agent)) + graph.set_start("llm") + graph.set_end("llm") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + # Collect all event texts + 
event_texts = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + if event.content and event.content.parts: + event_texts.append(event.content.parts[0].text) + + # Check that llm agent's response appears in events + assert any( + "llm response" in text for text in event_texts + ), f"Expected 'llm response' in events, got {event_texts}" + assert llm_agent.call_count == 1 + + async def test_custom_function_node(self): + """Test node with custom function instead of agent.""" + graph = GraphAgent(name="test") + + async def custom_fn(state: GraphState, ctx): + """Custom function.""" + return f"processed: {state.data.get('input', '')}" + + graph.add_node(GraphNode(name="custom", function=custom_fn)) + graph.set_start("custom") + graph.set_end("custom") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + + # Create session first + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + final_output = None + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + if event.content and event.content.parts: + final_output = ( + event.content.parts[0].text + if event.content and event.content.parts + else "" + ) + + assert "processed: test input" in str(final_output) + + def test_node_requires_agent_or_function(self): + """Test that node requires either agent or function.""" + with pytest.raises( + ValueError, match="Either agent or function must be provided" + ): + GraphNode(name="invalid", agent=None, function=None) + + +# ============================================================================ +# Test: Input/Output Mappers +# 
============================================================================ + + +class TestMappers: + """Test input and output mappers.""" + + def test_custom_input_mapper(self): + """Test custom input mapper transforms state to agent input.""" + + def input_mapper(state: GraphState) -> str: + return f"Custom: {state.data.get('value', '')}" + + node = GraphNode( + name="test", agent=MockLlmAgent("agent"), input_mapper=input_mapper + ) + + state = GraphState(data={"value": "test"}) + mapped_input = node.input_mapper(state) + + assert mapped_input == "Custom: test" + + def test_custom_output_mapper(self): + """Test custom output mapper transforms agent output to state.""" + + def output_mapper(output: str, state: GraphState) -> GraphState: + new_state = GraphState(data={**state.data, "result": output.upper()}) + return new_state + + node = GraphNode( + name="test", agent=MockLlmAgent("agent"), output_mapper=output_mapper + ) + + state = GraphState(data={}) + new_state = node.output_mapper("hello", state) + + assert new_state.data["result"] == "HELLO" + + +# ============================================================================ +# Test: Error Handling +# ============================================================================ + + +class TestErrorHandling: + """Test error handling and validation.""" + + def test_set_end_invalid_node(self): + """Test set_end raises error for non-existent node.""" + graph = GraphAgent(name="test") + with pytest.raises( + ValueError, match="Node invalid_node not found in graph" + ): + graph.set_end("invalid_node") + + def test_add_edge_invalid_source(self): + """Test add_edge raises error for non-existent source node.""" + graph = GraphAgent(name="test") + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + + with pytest.raises(ValueError, match="Source node invalid not found"): + graph.add_edge("invalid", "node1") + + @pytest.mark.asyncio + async def 
test_node_no_edges_not_end_raises_error(self): + """Test execution raises error when node has no edges and is not an end node.""" + graph = GraphAgent(name="test") + + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + graph.set_start("node1") + # Don't set as end node and don't add edges + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + with pytest.raises( + ValueError, match="has no outgoing edges and is not an end node" + ): + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="test")] + ), + ): + pass + + @pytest.mark.asyncio + async def test_start_node_not_set_raises_error(self): + """Test execution raises error when start node is not set.""" + graph = GraphAgent(name="test") + + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + # Don't set start node + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + with pytest.raises(ValueError, match="Start node not set"): + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="test")] + ), + ): + pass + + +# ============================================================================ +# Test: Function Execution +# ============================================================================ + + +@pytest.mark.asyncio +class TestFunctionExecution: + """Test synchronous and asynchronous function execution.""" + + 
async def test_sync_function_node(self): + """Test node with synchronous function.""" + graph = GraphAgent(name="test") + + # Synchronous function + def sync_fn(state: GraphState, ctx): + return f"sync: {state.data.get('input', '')}" + + graph.add_node(GraphNode(name="sync_node", function=sync_fn)) + graph.set_start("sync_node") + graph.set_end("sync_node") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + final_output = None + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + if event.content and event.content.parts: + final_output = event.content.parts[0].text + + assert "sync: test" in str(final_output) + + +# ============================================================================ +# Test: State Restoration +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateRestoration: + """Test state restoration from session.""" + + async def test_state_restoration_from_session(self): + """Test that graph can restore state from session.""" + graph = GraphAgent(name="test", checkpointing=True) + + agent = SimpleTestAgent("agent", ["response1", "response2"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + graph.set_start("node1") + graph.set_end("node1") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + # First run - create state + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", 
parts=[types.Part(text="first")] + ), + ): + pass + + # Second run - should restore state + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content( + role="user", parts=[types.Part(text="second")] + ), + ): + pass + + # Verify state was persisted + session = await session_service.get_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + assert "graph_data" in session.state + + +# ============================================================================ +# Test: ADK Conformity +# ============================================================================ + + +@pytest.mark.asyncio +class TestADKConformity: + """Test ADK conformance.""" + + async def test_event_structure_conformity(self): + """Test that GraphAgent yields proper Event objects.""" + graph = GraphAgent(name="test") + + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + graph.set_start("node1") + graph.set_end("node1") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + # Collect all events + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + events.append(event) + + # Verify all events are proper Event objects + assert any( + e.author == "agent" for e in events + ), "Expected at least one event from the inner agent" + for event in events: + # Must have author field + assert hasattr(event, "author") + assert event.author is not None + + # Should have content (some events might not) + if event.content: + assert isinstance(event.content, types.Content) + assert hasattr(event.content, "parts") + + # May have actions (EventActions) + if 
event.actions: + assert isinstance(event.actions, EventActions) + + async def test_invocation_context_conformity(self): + """Test that InvocationContext is properly structured.""" + graph = GraphAgent(name="test") + + # Custom function that verifies InvocationContext structure + def verify_ctx(state: GraphState, ctx): + # Verify required InvocationContext fields + assert hasattr(ctx, "session") + assert hasattr(ctx, "invocation_id") + assert hasattr(ctx, "agent") + assert hasattr(ctx, "session_service") + return "context valid" + + graph.add_node(GraphNode(name="node1", function=verify_ctx)) + graph.set_start("node1") + graph.set_end("node1") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + # If this doesn't raise, context is valid + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + pass + + async def test_state_delta_conformity(self): + """Test that state changes use EventActions.state_delta.""" + graph = GraphAgent(name="test", checkpointing=True) + + agent = SimpleTestAgent("agent", ["response"]) + graph.add_node(GraphNode(name="node1", agent=agent)) + graph.set_start("node1") + graph.set_end("node1") + + runner = Runner( + app_name="test_graph", + agent=graph, + session_service=InMemorySessionService(), + ) + session_service = runner.session_service + await session_service.create_session( + app_name="test_graph", user_id="test_user", session_id="test" + ) + + # Collect events with state_delta + state_delta_events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test", + new_message=types.Content(role="user", parts=[types.Part(text="test")]), + ): + if event.actions and event.actions.state_delta: + 
state_delta_events.append(event) + + # Checkpointing should produce at least one state_delta event + assert ( + len(state_delta_events) >= 1 + ), "Expected checkpoint to emit state_delta events" + + # Verify state_delta structure + for event in state_delta_events: + assert isinstance(event.actions.state_delta, dict) + + # NOTE: test_escalate_flag_conformity removed - tested hardcoded events that were removed + # Callback-based observability (the replacement) is tested in test_graph_callbacks.py + + +# ============================================================================ +# Graph Export Tests +# ============================================================================ + + +class TestGraphExport: + """Tests for D3-compatible graph structure export.""" + + def test_export_graph_structure(self): + """Test exporting graph structure in D3 format.""" + graph = GraphAgent(name="test_graph", checkpointing=True) + + # Add nodes + graph.add_node(GraphNode(name="start", function=lambda s, c: "start")) + graph.add_node(GraphNode(name="process", function=lambda s, c: "process")) + graph.add_node(GraphNode(name="end", function=lambda s, c: "end")) + + # Add edges + graph.add_edge("start", "process") + graph.add_edge("process", "end") + + # Set start and end + graph.set_start("start") + graph.set_end("end") + + # Export structure + structure = export_graph_structure(graph) + + # Verify structure + assert "nodes" in structure + assert "links" in structure + assert "metadata" in structure + assert structure["directed"] is True + + # Verify nodes + assert len(structure["nodes"]) == 3 + node_ids = [n["id"] for n in structure["nodes"]] + assert "start" in node_ids + assert "process" in node_ids + assert "end" in node_ids + + # Verify all nodes are function type + for node in structure["nodes"]: + assert node["type"] == "function" + + # Verify links + assert len(structure["links"]) == 2 + links = [(l["source"], l["target"]) for l in structure["links"]] + assert ("start", 
"process") in links + assert ("process", "end") in links + + # Verify metadata + assert structure["metadata"]["start_node"] == "start" + assert structure["metadata"]["end_nodes"] == ["end"] + assert structure["metadata"]["checkpointing"] is True + + def test_export_with_conditional_edges(self): + """Test export includes conditional edge information.""" + graph = GraphAgent(name="test") + + graph.add_node(GraphNode(name="a", function=lambda s, c: "a")) + graph.add_node(GraphNode(name="b", function=lambda s, c: "b")) + graph.add_node(GraphNode(name="c", function=lambda s, c: "c")) + + # Add conditional and unconditional edges + graph.add_edge("a", "b", condition=lambda s: s.data.get("go_b", False)) + graph.add_edge("a", "c") # No condition + + structure = export_graph_structure(graph) + + # Verify conditional flags + links = structure["links"] + assert len(links) == 2 + + # Find the links + b_link = next(l for l in links if l["target"] == "b") + c_link = next(l for l in links if l["target"] == "c") + + assert b_link["conditional"] is True + assert c_link["conditional"] is False + + def test_export_with_agent_nodes(self): + """Test export distinguishes agent vs function nodes.""" + graph = GraphAgent(name="test") + + # Add function node + graph.add_node(GraphNode(name="func", function=lambda s, c: "func")) + + # Add agent node + mock_agent = Mock(spec=BaseAgent) + mock_agent.name = "agent" + graph.add_node(GraphNode(name="agent", agent=mock_agent)) + + structure = export_graph_structure(graph) + + # Verify node types + nodes = {n["id"]: n for n in structure["nodes"]} + assert nodes["func"]["type"] == "function" + assert nodes["agent"]["type"] == "agent" + + def test_export_empty_graph(self): + """Test export of empty graph.""" + graph = GraphAgent(name="empty") + + structure = export_graph_structure(graph) + + assert structure["nodes"] == [] + assert structure["links"] == [] + assert structure["metadata"]["start_node"] is None + assert 
structure["metadata"]["end_nodes"] == [] + + def test_export_cyclic_graph(self): + """Test export of graph with cycles.""" + graph = GraphAgent(name="cyclic") + + graph.add_node(GraphNode(name="a", function=lambda s, c: "a")) + graph.add_node(GraphNode(name="b", function=lambda s, c: "b")) + graph.add_node(GraphNode(name="c", function=lambda s, c: "c")) + + # Create cycle: a -> b -> c -> a + graph.add_edge("a", "b") + graph.add_edge("b", "c") + graph.add_edge("c", "a") + + structure = export_graph_structure(graph) + + # Verify cycle is preserved + assert len(structure["links"]) == 3 + links = [(l["source"], l["target"]) for l in structure["links"]] + assert ("a", "b") in links + assert ("b", "c") in links + assert ("c", "a") in links + + +# ============================================================================ +# Run Tests +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) + + +# ============================================================================ +# Additional coverage tests (graph_agent.py lines 117-130, 364-382, 500-564, +# 612-650, 704-734, 750-810, 833-860, 868-932, 1215-1239, 1275-1338, +# 1623-1686, 1892, 1976-2065, 2078-2097, 2119-2170, 2191-2281) +# ============================================================================ + + +# --------------------------------------------------------------------------- +# _parse_condition_string — AST-safe evaluation +# --------------------------------------------------------------------------- + + +def test_parse_condition_string_safe_eval_success(): + """Safe condition string evaluates correctly.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + fn = _parse_condition_string("data.get('x') == 'yes'") + state = GraphState() + state.data["x"] = "yes" + assert fn(state) is True + + state.data["x"] = "no" + assert fn(state) is False + + +def 
test_parse_condition_string_comparison_operators(): + """Comparison operators work in conditions.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["count"] = 5 + + assert _parse_condition_string("data.get('count', 0) < 10")(state) is True + assert _parse_condition_string("data.get('count', 0) > 10")(state) is False + assert _parse_condition_string("data.get('count', 0) == 5")(state) is True + + +def test_parse_condition_string_boolean_ops(): + """Boolean and/or/not work in conditions.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["a"] = True + state.data["b"] = False + + fn_and = _parse_condition_string("data.get('a') and data.get('b')") + assert fn_and(state) is False + + fn_or = _parse_condition_string("data.get('a') or data.get('b')") + assert fn_or(state) is True + + fn_not = _parse_condition_string("not data.get('b')") + assert fn_not(state) is True + + +def test_parse_condition_string_is_none(): + """'is True', 'is None', 'is not None' work in conditions.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["val"] = None + + assert _parse_condition_string("data.get('val') is None")(state) is True + assert _parse_condition_string("data.get('val') is not None")(state) is False + + state.data["val"] = True + assert _parse_condition_string("data.get('val') is True")(state) is True + + +def test_parse_condition_string_in_operator(): + """'in' operator works in conditions.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["status"] = "CONTINUE_PROCESSING" + + fn = _parse_condition_string("'CONTINUE' in data.get('status', '')") + assert fn(state) is True + + state.data["status"] = "STOP" + assert fn(state) is False + + +def test_parse_condition_string_rejects_unsafe_names(): + """Unsafe names like __import__, 
os, etc. are rejected at parse time.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe"): + _parse_condition_string("__import__('os').system('rm -rf /')") + + with pytest.raises(ValueError, match="Unsafe"): + _parse_condition_string("os.system('ls')") + + +def test_parse_condition_string_rejects_dunder_traversal(): + """Attribute traversal attacks via __class__.__bases__ are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe"): + _parse_condition_string("state.__class__.__bases__[0].__subclasses__()") + + +def test_parse_condition_string_rejects_unsafe_calls(): + """Arbitrary function calls (not .get()) are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe"): + _parse_condition_string("print('hello')") + + with pytest.raises(ValueError, match="Unsafe method"): + _parse_condition_string("data.pop('key')") + + +def test_parse_condition_string_rejects_lambda(): + """Lambda expressions are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe"): + _parse_condition_string("(lambda: True)()") + + +# --------------------------------------------------------------------------- +# Safe builtins in condition strings +# --------------------------------------------------------------------------- + + +def test_parse_condition_string_allows_len(): + """len() is allowed in condition strings.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["items"] = [1, 2, 3] + + fn = _parse_condition_string("len(data.get('items', [])) > 0") + assert fn(state) is True + + state.data["items"] = [] + assert fn(state) is False + + +def test_parse_condition_string_allows_min_max(): + """min() and max() are 
allowed in condition strings.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["scores"] = [50, 85, 92] + + fn = _parse_condition_string("max(data.get('scores', [0])) > 80") + assert fn(state) is True + + fn_min = _parse_condition_string("min(data.get('scores', [0])) > 60") + assert fn_min(state) is False + + +def test_parse_condition_string_allows_int_conversion(): + """int() conversion is allowed in condition strings.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["count"] = "10" + + fn = _parse_condition_string("int(data.get('count', '0')) > 5") + assert fn(state) is True + + +def test_parse_condition_string_allows_isinstance(): + """isinstance() is allowed in condition strings.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + state = GraphState() + state.data["value"] = 42 + + fn = _parse_condition_string("isinstance(data.get('value'), int)") + assert fn(state) is True + + state.data["value"] = "not_int" + assert fn(state) is False + + +def test_parse_condition_string_still_rejects_dangerous(): + """Dangerous builtins like eval, exec, __import__, print are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe call"): + _parse_condition_string("eval('1+1')") + + with pytest.raises(ValueError, match="Unsafe call"): + _parse_condition_string("exec('import os')") + + with pytest.raises(ValueError, match="Unsafe call"): + _parse_condition_string("__import__('os')") + + with pytest.raises(ValueError, match="Unsafe call"): + _parse_condition_string("print('hello')") + + +# --------------------------------------------------------------------------- +# AST validation: exhaustive branch coverage for _validate_condition_ast +# --------------------------------------------------------------------------- + + +def 
test_validate_ast_expression_wrapper(): + """Line 127: ast.Expression wrapper is handled (defensive branch).""" + import ast + + from google.adk.agents.graph.graph_agent import _validate_condition_ast + + tree = ast.parse("True", mode="eval") + # Call with the Expression wrapper (not tree.body) + _validate_condition_ast(tree) # Should not raise + + +def test_validate_ast_unsafe_unary_op(): + """Line 133: Unary operators other than `not` are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe unary operator"): + _parse_condition_string("-data.get('x', 0)") + + with pytest.raises(ValueError, match="Unsafe unary operator"): + _parse_condition_string("~data.get('x', 0)") + + +def test_validate_ast_keyword_args(): + """Line 149: Keyword arguments in safe method calls are validated.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + # .get() with keyword arg — should pass validation + fn = _parse_condition_string("data.get(key='x')") + state = GraphState(data={"x": "val"}) + # Python's dict.get() doesn't accept 'key' kwarg, so it will raise at eval + # but the AST validation itself should succeed + assert fn(state) is False # eval error → returns False + + +def test_validate_ast_standalone_attribute(): + """Line 152: Standalone attribute access (not inside a call).""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + fn = _parse_condition_string("state.data") + state = GraphState(data={"x": 1}) + # state.data is truthy (non-empty dict) + assert fn(state) is True + + +def test_validate_ast_subscript(): + """Lines 154-155: Subscript access like data['key'].""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + fn = _parse_condition_string("data['x'] == 'yes'") + state = GraphState(data={"x": "yes"}) + assert fn(state) is True + + state2 = GraphState(data={"x": "no"}) + assert fn(state2) is False + + +def 
test_validate_ast_unsafe_standalone_name(): + """Line 158: Standalone unsafe name is rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with pytest.raises(ValueError, match="Unsafe name"): + _parse_condition_string("x") + + with pytest.raises(ValueError, match="Unsafe name"): + _parse_condition_string("os") + + +def test_validate_ast_unsafe_expression_type(): + """Line 162: Unsupported AST node types are rejected.""" + from google.adk.agents.graph.graph_agent import _parse_condition_string + + # Ternary expression → ast.IfExp + with pytest.raises(ValueError, match="Unsafe expression node"): + _parse_condition_string("True if True else False") + + # Dict literal → ast.Dict + with pytest.raises(ValueError, match="Unsafe expression node"): + _parse_condition_string("{'key': 'val'}") + + # List/Tuple literals are allowed (needed for builtins like len, min, max) + # Set comprehension → ast.SetComp (still rejected) + with pytest.raises(ValueError, match="Unsafe expression node"): + _parse_condition_string("{x for x in [1, 2]}") + + +def test_export_graph_with_execution_history(): + """Lines 500-564: enriches nodes/links with execution data and interrupt markers.""" + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="n1", function=lambda s, c: "x")) + graph.add_node(GraphNode(name="n2", function=lambda s, c: "y")) + graph.add_edge("n1", "n2") + graph.set_start("n1") + graph.set_end("n2") + + history = [ + {"node": "n1", "status": "success"}, + {"node": "n1", "status": "error"}, + {"node": "n2", "status": "success"}, + ] + state_hist = [{"state": {"x": 1}}, {"state": {"x": 2}}, {"state": {"x": 3}}] + markers = [{"node": "n1", "message": "manual check"}] + + result = export_graph_with_execution( + graph, + execution_history=history, + state_history=state_hist, + interrupt_markers=markers, + ) + + n1_data = next(n for n in result["nodes"] if n["id"] == "n1") + assert n1_data["execution_count"] == 2 + assert 
n1_data["status_summary"]["success"] == 1 + assert n1_data["status_summary"]["error"] == 1 + assert n1_data["interrupt_count"] == 1 + + # link traversal: n1→n2 appears once (indices 1→2) + link = next(k for k in result["links"] if k["source"] == "n1") + assert link["traversals"] == 1 + + assert result["execution_history"] == history + assert result["state_history"] == state_hist + + +# --------------------------------------------------------------------------- +# export_execution_timeline (lines 612-650) +# --------------------------------------------------------------------------- + + +def test_export_execution_timeline_with_history(): + """Lines 612-654: builds timeline from history with durations and iteration.""" + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="n1", function=lambda s, c: "x")) + graph.set_start("n1") + graph.set_end("n1") + + history = [ + {"node": "n1", "timestamp": 0.0, "iteration": 1, "status": "success"}, + {"node": "n1", "timestamp": 1.5, "iteration": 2, "status": "success"}, + ] + state_hist = [{"state": {"a": 1}}, {"state": {"a": 2}}] + + timeline = export_execution_timeline( + execution_history=history, state_history=state_hist + ) + + assert timeline["total_steps"] == 2 + assert timeline["iterations"] == 2 + assert abs(timeline["total_duration"] - 1.5) < 0.001 + assert timeline["timeline"][0]["duration"] == 1.5 + assert timeline["timeline"][1]["duration"] == 0 # last step has 0 duration + assert timeline["timeline"][0]["state"] == {"a": 1} + + +def test_export_execution_timeline_empty(): + """Line 612 early-return branch: empty history returns empty timeline.""" + graph = GraphAgent(name="g") + result = export_execution_timeline(execution_history=[]) + assert result["total_steps"] == 0 + assert result["timeline"] == [] + + +# --------------------------------------------------------------------------- +# Telemetry helpers (lines 768, 790-792, 810, 833-860, 868-932) +# 
--------------------------------------------------------------------------- + + +def test_should_sample_with_sampling_rate(): + """Line 768: returns random bool when sampling_rate < 1.0.""" + from google.adk.agents.graph.graph_agent_config import TelemetryConfig + + graph = GraphAgent( + name="g", + telemetry_config=TelemetryConfig(sampling_rate=0.5), + ) + results = {graph._should_sample() for _ in range(100)} + # With rate=0.5 and 100 samples, both True and False must appear + assert ( + True in results and False in results + ), "Expected both True and False with 100 samples at 0.5 rate" + assert isinstance(graph._should_sample(), bool) + + +def test_get_telemetry_attributes_merges_additional(): + """Lines 790-792: additional_attributes merged with base, base takes precedence.""" + from google.adk.agents.graph.graph_agent_config import TelemetryConfig + + graph = GraphAgent( + name="g", + telemetry_config=TelemetryConfig( + additional_attributes={"env": "prod", "ver": "1"} + ), + ) + result = graph._get_telemetry_attributes( + {"graph": "my_graph", "env": "override"} + ) + # Base attributes override additional_attributes + assert result["env"] == "override" + assert result["ver"] == "1" + assert result["graph"] == "my_graph" + + +def test_get_parent_telemetry_config_returns_dict_from_agent_states(): + """Returns dict when agent_states contains telemetry_config_dict.""" + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="n", function=lambda s, c: "x")) + graph.set_start("n") + graph.set_end("n") + + svc = InMemorySessionService() + session = Session(id="s", appName="app", userId="u") + + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv1", + agent=SimpleTestAgent("dummy", ["x"]), + user_content=None, + ) + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": {"enabled": True, "sampling_rate": 0.9} + } + } + + result = graph._get_parent_telemetry_config(ctx) + assert result == {"enabled": True, 
"sampling_rate": 0.9} + + +def test_get_effective_telemetry_config_uses_parent_when_no_own(): + """Lines 833-837: no own telemetry_config → build from parent dict.""" + from google.adk.agents.graph.graph_agent_config import TelemetryConfig + + graph = GraphAgent(name="g") # no telemetry_config set + svc = InMemorySessionService() + session = Session(id="s", appName="app", userId="u") + + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv1", + agent=SimpleTestAgent("dummy", ["x"]), + user_content=None, + ) + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": {"enabled": True, "sampling_rate": 0.7} + } + } + + effective = graph._get_effective_telemetry_config(ctx) + assert effective is not None + assert isinstance(effective, TelemetryConfig) + assert effective.sampling_rate == 0.7 + + +def test_get_effective_telemetry_config_merges_own_and_parent(): + """Lines 840-860: both own and parent config → merged, own takes precedence.""" + from google.adk.agents.graph.graph_agent_config import TelemetryConfig + + own = TelemetryConfig( + sampling_rate=0.3, + additional_attributes={"own_key": "own_val"}, + ) + graph = GraphAgent(name="g", telemetry_config=own) + + svc = InMemorySessionService() + session = Session(id="s", appName="app", userId="u") + + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv1", + agent=SimpleTestAgent("dummy", ["x"]), + user_content=None, + ) + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": { + "enabled": True, + "sampling_rate": 0.9, + "additional_attributes": { + "parent_key": "parent_val", + "own_key": "parent_override", + }, + } + } + } + + effective = graph._get_effective_telemetry_config(ctx) + assert effective is not None + # own sampling_rate wins + assert effective.sampling_rate == 0.3 + # additional_attributes: own_key comes from own (not overridden by parent) + assert effective.additional_attributes["own_key"] == "own_val" + # 
parent_key also included (merged) + assert effective.additional_attributes["parent_key"] == "parent_val" + + +# --------------------------------------------------------------------------- +# _should_interrupt_before/after with node filter (lines 2078-2097) +# --------------------------------------------------------------------------- + + +def test_parse_config_non_graph_config_passes_through(): + """Line 2122: non-GraphAgentConfig → kwargs unchanged.""" + + class OtherConfig: + pass + + original_kwargs = {"name": "x"} + result = GraphAgent._parse_config(OtherConfig(), "/tmp", original_kwargs) + assert result is original_kwargs + + +# --------------------------------------------------------------------------- +# from_config (lines 2191-2281) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_before_node_callback_receives_node_context(): + """Lines 1274-1342: before_node_callback is called with node name.""" + from google.adk.agents.graph.callbacks import NodeCallbackContext + + callback_nodes: list = [] + + async def my_callback(ctx: NodeCallbackContext) -> None: + callback_nodes.append(ctx.node.name) + + graph = GraphAgent(name="g", before_node_callback=my_callback) + graph.add_node(GraphNode(name="step", function=lambda s, c: "done")) + graph.set_start("step") + graph.set_end("step") + + svc = InMemorySessionService() + runner = Runner(app_name="app", agent=graph, session_service=svc) + await svc.create_session(app_name="app", user_id="u", session_id="s") + + async for _ in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content(role="user", parts=[types.Part(text="go")]), + ): + pass + + assert ( + "step" in callback_nodes + ), "before_node_callback should have been called with 'step'" + + +@pytest.mark.asyncio +async def test_after_node_callback_receives_node_context(): + """Lines 1623-1686: after_node_callback is invoked after each node.""" + from 
google.adk.agents.graph.callbacks import NodeCallbackContext + + after_calls: list = [] + + async def after_cb(ctx: NodeCallbackContext) -> None: + after_calls.append(ctx.node.name) + + graph = GraphAgent(name="g", after_node_callback=after_cb) + graph.add_node(GraphNode(name="node1", function=lambda s, c: "result")) + graph.set_start("node1") + graph.set_end("node1") + + svc = InMemorySessionService() + runner = Runner(app_name="app", agent=graph, session_service=svc) + await svc.create_session(app_name="app", user_id="u", session_id="s") + + async for _ in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content(role="user", parts=[types.Part(text="go")]), + ): + pass + + assert "node1" in after_calls + + +@pytest.mark.asyncio +def test_export_graph_with_execution_node_not_in_history(): + """Lines 529-530: node present in graph but absent from execution_history. + + When the execution_history contains entries for some nodes but not all, + the else branch assigns executions=[] and execution_count=0. 
+ """ + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="n1", function=lambda s, c: "x")) + graph.add_node(GraphNode(name="n2", function=lambda s, c: "y")) + graph.add_edge("n1", "n2") + graph.set_start("n1") + graph.set_end("n2") + + # Only n1 appears in history; n2 is absent → hits lines 529-530 + history = [{"node": "n1", "status": "success"}] + result = export_graph_with_execution(graph, execution_history=history) + + n2_data = next(n for n in result["nodes"] if n["id"] == "n2") + assert n2_data["executions"] == [] + assert n2_data["execution_count"] == 0 + + +# --------------------------------------------------------------------------- +# rewind_to_node – session not found (line 708) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rewind_to_node_session_not_found(): + """Line 708: rewind_to_node raises ValueError when session doesn't exist.""" + graph = GraphAgent(name="g") + svc = InMemorySessionService() + + with pytest.raises(ValueError, match="Session not found: no_such_session"): + await rewind_to_node( + graph, + session_service=svc, + app_name="app", + user_id="u", + session_id="no_such_session", + node_name="n1", + ) + + +# --------------------------------------------------------------------------- +# _get_effective_telemetry_config – no own config, no parent → None (line 838) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.asyncio +async def test_edge_condition_exception_propagates(): + """Lines 1132-1137: exception raised inside edge condition is re-raised. + + The span attributes are set and the exception is re-raised from inside + the telemetry wrapper. 
+ """ + + def bad_condition(state): + raise RuntimeError("edge evaluation failed") + + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="n1", function=lambda s, c: "x")) + graph.add_node(GraphNode(name="n2", function=lambda s, c: "y")) + graph.add_edge("n1", "n2", condition=bad_condition) + graph.set_start("n1") + graph.set_end("n2") + + runner = Runner( + app_name="app", agent=graph, session_service=InMemorySessionService() + ) + svc = runner.session_service + await svc.create_session(app_name="app", user_id="u", session_id="s") + + with pytest.raises(RuntimeError, match="edge evaluation failed"): + async for _ in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content(role="user", parts=[types.Part(text="go")]), + ): + pass + + +# --------------------------------------------------------------------------- +# _run_async_impl – effective_config stored in session.state (line 1170) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_async_stores_telemetry_config_in_agent_state(): + """When telemetry_config is set, _run_async_impl stores it in agent_state.""" + from google.adk.agents.graph.graph_agent_config import TelemetryConfig + + telemetry = TelemetryConfig(enabled=True, trace_nodes=True) + graph = GraphAgent(name="g", telemetry_config=telemetry) + graph.add_node(GraphNode(name="n", function=lambda s, c: "x")) + graph.set_start("n") + graph.set_end("n") + + svc = InMemorySessionService() + session = Session(id="s", appName="app", userId="u") + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv", + agent=SimpleTestAgent("a", ["x"]), + user_content=types.Content(role="user", parts=[types.Part(text="go")]), + ) + + # Run the graph and collect agent_state events + # (end_of_agent=True clears ctx.agent_states, so inspect events instead) + agent_state_dict = {} + async for event in graph._run_async_impl(ctx): + if 
event.actions and event.actions.agent_state: + agent_state_dict = event.actions.agent_state + + assert "telemetry_config_dict" in agent_state_dict + assert agent_state_dict["telemetry_config_dict"]["enabled"] is True + + +# --------------------------------------------------------------------------- +# _execute_interrupt_action – pause action returns "pause" (line 2033) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# BEFORE interrupt: go_back tuple (lines 1363-1364, 1391-1394) +# BEFORE interrupt: rerun → continue (line 1396) +# BEFORE interrupt: skip + no next node → break (line 1403) +# BEFORE interrupt: pause + wait_if_paused cancelled (lines 1407-1416) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +async def test_cancelled_error_during_node_execution(): + """Lines 1583-1614: asyncio.CancelledError during node execution yields + a cancel event with state_delta and re-raises.""" + + class CancellingAgent(BaseAgent): + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + async def _run_async_impl(self, ctx): + raise asyncio.CancelledError() + yield # noqa: unreachable – needed to make this an async generator + + graph = GraphAgent(name="g", max_iterations=3) + graph.add_node( + GraphNode(name="cancel_node", agent=CancellingAgent(name="cancel_agent")) + ) + graph.set_start("cancel_node") + graph.set_end("cancel_node") + + svc = InMemorySessionService() + runner = Runner(app_name="app", agent=graph, session_service=svc) + await svc.create_session(app_name="app", user_id="u", session_id="s") + + events = [] + with pytest.raises((asyncio.CancelledError, Exception)): + async for event 
in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content(role="user", parts=[types.Part(text="go")]), + ): + events.append(event) + + # The cancel event should have been yielded before re-raising + cancel_events = [ + e + for e in events + if e.content + and e.content.parts + and "cancelled" in (e.content.parts[0].text or "").lower() + ] + assert len(cancel_events) >= 1 + # state_delta should include graph_task_cancelled flag + assert any( + e.actions + and e.actions.state_delta + and e.actions.state_delta.get("graph_task_cancelled") + for e in cancel_events + ) + + +# --------------------------------------------------------------------------- +# AFTER interrupt: go_back tuple (lines 1738-1739, 1786-1789) +# AFTER interrupt: pause + cancelled (lines 1795-1806) +# AFTER interrupt: pause + timeout (lines 1807-1810) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +def test_parse_config_callback_refs_resolved(): + """Lines 2150, 2158, 2166: _parse_config resolves before_node_callback_ref, + after_node_callback_ref and on_edge_condition_callback_ref via resolve_code_reference. + + GraphAgentConfig doesn't define these fields, so we create a subclass that + does — isinstance(config, GraphAgentConfig) still passes. 
+ """ + from typing import Optional + + from google.adk.agents.graph.graph_agent_config import GraphAgentConfig + + class ExtendedConfig(GraphAgentConfig): + model_config = {"extra": "allow"} + + before_node_callback_ref: Optional[str] = None + after_node_callback_ref: Optional[str] = None + on_edge_condition_callback_ref: Optional[str] = None + + async def my_callback(ctx): + pass + + config = ExtendedConfig( + name="ext", + start_node="placeholder", + end_nodes=["placeholder"], + before_node_callback_ref="some.module.before", + after_node_callback_ref="some.module.after", + on_edge_condition_callback_ref="some.module.on_edge", + ) + + with patch( + "google.adk.agents.config_agent_utils.resolve_code_reference", + return_value=my_callback, + ): + kwargs = GraphAgent._parse_config(config, "/tmp", {}) + + assert kwargs["before_node_callback"] is my_callback + assert kwargs["after_node_callback"] is my_callback + assert kwargs["on_edge_condition_callback"] is my_callback + + +# --------------------------------------------------------------------------- +# from_config – non-GraphAgentConfig early return (line 2204) +# --------------------------------------------------------------------------- + + +def test_from_config_non_graph_agent_config_returns_graph_agent(): + """Line 2204: when config is NOT a GraphAgentConfig, from_config returns + the base graph instance without additional graph setup.""" + from google.adk.agents.base_agent_config import BaseAgentConfig + + base_config = BaseAgentConfig(name="base_graph") + result = GraphAgent.from_config(base_config, "/tmp") + + assert isinstance(result, GraphAgent) + # No nodes/edges added + assert len(result.nodes) == 0 + + +# --------------------------------------------------------------------------- +# from_config – sub_agents in node config (lines 2212-2214) +# --------------------------------------------------------------------------- + + +def test_from_config_node_with_sub_agents(): + """Lines 2212-2214: 
node_config.sub_agents triggers resolve_agent_reference.""" + from google.adk.agents.common_configs import AgentRefConfig + from google.adk.agents.graph.graph_agent_config import GraphAgentConfig + from google.adk.agents.graph.graph_agent_config import GraphNodeConfig + + async def _dummy_fn(state, ctx): + return "ok" + + # Create node config with a sub_agent reference (code-based) + config = GraphAgentConfig( + name="with_sub", + nodes=[ + GraphNodeConfig( + name="n1", + sub_agents=[AgentRefConfig(code="my.module.my_agent")], + ) + ], + start_node="n1", + end_nodes=["n1"], + ) + + mock_agent = SimpleTestAgent("resolved_agent", ["ok"]) + + with patch( + "google.adk.agents.config_agent_utils.resolve_agent_reference", + return_value=mock_agent, + ): + graph = GraphAgent.from_config(config, "/tmp") + + assert "n1" in graph.nodes + assert graph.nodes["n1"].agent is mock_agent + + +# --------------------------------------------------------------------------- +# from_config – edge with unknown source node → ValueError (line 2251) +# --------------------------------------------------------------------------- + + +def test_from_config_edge_unknown_source_node_raises(): + """Line 2251: edge from_node references a node not in the graph → ValueError.""" + from google.adk.agents.graph.graph_agent_config import GraphAgentConfig + from google.adk.agents.graph.graph_agent_config import GraphEdgeConfig + from google.adk.agents.graph.graph_agent_config import GraphNodeConfig + + async def _dummy_fn(state, ctx): + return "ok" + + config = GraphAgentConfig( + name="bad_edge", + nodes=[GraphNodeConfig(name="n1", function_ref="dummy.n1")], + edges=[ + GraphEdgeConfig( + source_node="nonexistent_source", + target_node="n1", + ) + ], + start_node="n1", + end_nodes=["n1"], + ) + + with patch( + "google.adk.agents.config_agent_utils.resolve_code_reference", + return_value=_dummy_fn, + ): + with pytest.raises( + ValueError, match="Source node nonexistent_source not found" + ): + 
GraphAgent.from_config(config, "/tmp") + + +# ============================================================================ +# Test: Sub-Agent Registration via add_node +# ============================================================================ + + +class TestSubAgentRegistration: + """Test that GraphAgent registers node agents in sub_agents.""" + + def test_add_node_registers_agent_in_sub_agents(self): + """Agent nodes should appear in graph.sub_agents after add_node.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node(GraphNode(name="n", agent=agent)) + assert agent in graph.sub_agents + assert len(graph.sub_agents) == 1 + + def test_function_node_not_in_sub_agents(self): + """Function-only nodes should NOT add anything to sub_agents.""" + graph = GraphAgent(name="g") + + async def fn(state, ctx): + return "done" + + graph.add_node("fn_node", function=fn) + assert len(graph.sub_agents) == 0 + + def test_parent_agent_set_on_node_agent(self): + """Node agent's parent_agent should be set to the graph.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node(GraphNode(name="n", agent=agent)) + assert agent.parent_agent is graph + + def test_find_agent_finds_node_agent(self): + """graph.find_agent should find agents registered via add_node.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node(GraphNode(name="n", agent=agent)) + assert graph.find_agent("a") is agent + + def test_find_agent_finds_graph_itself(self): + """graph.find_agent(graph.name) should return graph itself.""" + graph = GraphAgent(name="g") + assert graph.find_agent("g") is graph + + def test_find_agent_returns_none_for_unknown(self): + """find_agent returns None for non-existent name.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node(GraphNode(name="n", agent=agent)) + assert 
graph.find_agent("nonexistent") is None + + def test_find_sub_agent_searches_nested(self): + """find_agent should recursively find agents inside node agent sub_agents.""" + graph = GraphAgent(name="g") + inner = SimpleTestAgent(name="inner", responses=["ok"]) + outer = SimpleTestAgent(name="outer", responses=["ok"]) + outer.sub_agents = [inner] + inner.parent_agent = outer + graph.add_node(GraphNode(name="n", agent=outer)) + assert graph.find_agent("inner") is inner + + def test_duplicate_agent_name_raises(self): + """Adding two nodes with agents of the same name should raise.""" + graph = GraphAgent(name="g") + a1 = SimpleTestAgent(name="a", responses=["ok"]) + a2 = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node(GraphNode(name="n1", agent=a1)) + with pytest.raises(ValueError, match="Duplicate sub_agent name"): + graph.add_node(GraphNode(name="n2", agent=a2)) + + def test_agent_with_parent_raises(self): + """Agent already parented to another graph should raise.""" + g1 = GraphAgent(name="g1") + g2 = GraphAgent(name="g2") + agent = SimpleTestAgent(name="a", responses=["ok"]) + g1.add_node(GraphNode(name="n1", agent=agent)) + with pytest.raises(ValueError, match="already has a parent"): + g2.add_node(GraphNode(name="n2", agent=agent)) + + def test_sub_agents_count_matches_agent_nodes(self): + """N agent nodes should produce N entries in sub_agents.""" + graph = GraphAgent(name="g") + agents = [] + for i in range(5): + a = SimpleTestAgent(name=f"a{i}", responses=["ok"]) + agents.append(a) + graph.add_node(GraphNode(name=f"n{i}", agent=a)) + assert len(graph.sub_agents) == 5 + for a in agents: + assert a in graph.sub_agents + + def test_agent_name_matches_graph_name_raises(self): + """Agent with same name as graph should raise ValueError.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="g", responses=["ok"]) + with pytest.raises(ValueError, match="collides with GraphAgent name"): + graph.add_node(GraphNode(name="n", agent=agent)) + + def 
test_same_agent_instance_two_nodes_skips_second(self): + """Same agent instance in two nodes: registered once, no error.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + node1 = GraphNode(name="n1", agent=agent) + node2 = GraphNode(name="n2", agent=agent) + graph.add_node(node1) + graph.add_node(node2) + assert len(graph.sub_agents) == 1 + assert graph.sub_agents[0] is agent + + def test_convenience_add_node_registers(self): + """Convenience add_node("name", agent=...) also registers.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["ok"]) + graph.add_node("n", agent=agent) + assert agent in graph.sub_agents + assert agent.parent_agent is graph + + +# ============================================================================ +# Test: add_node error paths (lines 387, 392, 398, 402, 405, 413) +# ============================================================================ + + +class TestAddNodeErrors: + """Cover every error branch in GraphAgent.add_node().""" + + def test_graphnode_with_extra_agent_raises(self): + """Passing GraphNode + agent= kwarg raises ValueError.""" + graph = GraphAgent(name="g") + node = GraphNode(name="n", function=lambda s, c: "x") + extra = SimpleTestAgent(name="extra", responses=["x"]) + with pytest.raises(ValueError, match="do not specify agent"): + graph.add_node(node, agent=extra) + + def test_graphnode_with_extra_function_raises(self): + """Passing GraphNode + function= kwarg raises ValueError.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["x"]) + node = GraphNode(name="n", agent=agent) + with pytest.raises(ValueError, match="do not specify agent"): + graph.add_node(node, function=lambda s, c: "x") + + def test_graphnode_with_extra_kwargs_raises(self): + """Passing GraphNode + arbitrary kwargs raises ValueError.""" + graph = GraphAgent(name="g") + node = GraphNode(name="n", function=lambda s, c: "x") + with 
pytest.raises(ValueError, match="do not specify agent"): + graph.add_node(node, reducer="bogus") + + def test_graphnode_duplicate_name_raises(self): + """Adding GraphNode with name already in graph raises ValueError.""" + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="dup", function=lambda s, c: "x")) + with pytest.raises(ValueError, match="already exists in graph"): + graph.add_node(GraphNode(name="dup", function=lambda s, c: "y")) + + def test_string_no_agent_no_function_raises(self): + """String name without agent or function raises ValueError.""" + graph = GraphAgent(name="g") + with pytest.raises(ValueError, match="must specify agent or function"): + graph.add_node("n") + + def test_string_both_agent_and_function_raises(self): + """String name with both agent and function raises ValueError.""" + graph = GraphAgent(name="g") + agent = SimpleTestAgent(name="a", responses=["x"]) + with pytest.raises(ValueError, match="Cannot specify both"): + graph.add_node("n", agent=agent, function=lambda s, c: "x") + + def test_string_duplicate_name_raises(self): + """Adding string node with name already in graph raises ValueError.""" + graph = GraphAgent(name="g") + graph.add_node("dup", function=lambda s, c: "x") + with pytest.raises(ValueError, match="already exists in graph"): + graph.add_node("dup", function=lambda s, c: "y") + + def test_invalid_type_raises(self): + """Passing non-GraphNode non-str raises TypeError.""" + graph = GraphAgent(name="g") + with pytest.raises(TypeError, match="node must be GraphNode or str"): + graph.add_node(123) + + +# ============================================================================ +# Test: find_sub_agent fallback (lines 515, 517) +# ============================================================================ + + +class TestFindSubAgentFallback: + """Cover fallback search in overridden find_sub_agent.""" + + def test_fallback_finds_unregistered_agent(self): + """Agent added to node AFTER add_node should be found 
via fallback.""" + graph = GraphAgent(name="g") + graph.add_node("fn_node", function=lambda s, c: "x") + # Manually assign an agent to the node (bypasses registration) + sneaky = SimpleTestAgent(name="sneaky", responses=["x"]) + graph.nodes["fn_node"].agent = sneaky + # Not in sub_agents + assert sneaky not in graph.sub_agents + # But found via fallback + assert graph.find_sub_agent("sneaky") is sneaky + + def test_fallback_finds_nested_in_unregistered_agent(self): + """Recursive search through unregistered agent's sub_agents.""" + graph = GraphAgent(name="g") + graph.add_node("fn_node", function=lambda s, c: "x") + deep = SimpleTestAgent(name="deep", responses=["x"]) + wrapper = SimpleTestAgent(name="wrapper", responses=["x"]) + wrapper.sub_agents = [deep] + deep.parent_agent = wrapper + graph.nodes["fn_node"].agent = wrapper + # Not in sub_agents + assert wrapper not in graph.sub_agents + # Found recursively via fallback + assert graph.find_sub_agent("deep") is deep + + +# ============================================================================ +# Test: _validate_node_configuration (lines 532-533) +# ============================================================================ + + +class TestValidateNodeConfiguration: + """Cover _validate_node_configuration auto-defaulted output_key warning.""" + + def test_llm_agent_auto_defaulted_output_key_warns(self): + """LlmAgent with output_schema and auto-defaulted output_key triggers warning.""" + from pydantic import BaseModel + + class OutputSchema(BaseModel): + result: str + + # GraphNode auto-defaults output_key to agent.name when + # output_schema is set but output_key is not. 
+ agent = MockLlmAgent(name="llm_a", output_schema=OutputSchema) + # The GraphNode constructor copies the agent with output_key set + node = GraphNode(name="n", agent=agent) + # After auto-default, node.agent.output_key == node.agent.name + assert node.agent.output_key == node.agent.name + + graph = GraphAgent(name="g") + # _validate_node_configuration should log warning + import logging + + with patch( + "google.adk.agents.graph.graph_agent.logger" + ) as mock_logger: + graph.add_node(node) + mock_logger.warning.assert_called_once() + + +# ============================================================================ +# Test: add_edge EdgeCondition pattern (lines 634-654, 668, 675-681, 686) +# ============================================================================ + + +class TestAddEdgeEdgeCondition: + """Cover add_edge with EdgeCondition objects, priority/weight, duplicates.""" + + def _make_graph(self): + graph = GraphAgent(name="g") + graph.add_node("src", function=lambda s, c: "x") + graph.add_node("tgt", function=lambda s, c: "y") + graph.add_node("tgt2", function=lambda s, c: "z") + return graph + + def test_edge_condition_with_extra_params_raises(self): + """EdgeCondition + condition/priority/weight kwargs raises ValueError.""" + graph = self._make_graph() + ec = EdgeCondition(target_node="tgt") + with pytest.raises( + ValueError, match="do not specify condition, priority, or weight" + ): + graph.add_edge("src", ec, condition=lambda s: True) + + def test_edge_condition_with_extra_priority_raises(self): + """EdgeCondition + priority kwarg raises ValueError.""" + graph = self._make_graph() + ec = EdgeCondition(target_node="tgt") + with pytest.raises( + ValueError, match="do not specify condition, priority, or weight" + ): + graph.add_edge("src", ec, priority=5) + + def test_edge_condition_target_not_found_raises(self): + """EdgeCondition with non-existent target raises ValueError.""" + graph = self._make_graph() + ec = 
EdgeCondition(target_node="nonexistent") + with pytest.raises(ValueError, match="Target node nonexistent not found"): + graph.add_edge("src", ec) + + def test_edge_condition_duplicate_raises(self): + """Adding same EdgeCondition target twice raises ValueError.""" + graph = self._make_graph() + graph.add_edge("src", EdgeCondition(target_node="tgt")) + with pytest.raises(ValueError, match="already exists"): + graph.add_edge("src", EdgeCondition(target_node="tgt")) + + def test_edge_condition_appends(self): + """EdgeCondition appended correctly to node edges.""" + graph = self._make_graph() + ec = EdgeCondition(target_node="tgt", condition=lambda s: True, priority=5) + graph.add_edge("src", ec) + assert len(graph.nodes["src"].edges) == 1 + assert graph.nodes["src"].edges[0].target_node == "tgt" + assert graph.nodes["src"].edges[0].priority == 5 + + def test_string_duplicate_edge_raises(self): + """Adding same string edge twice raises ValueError.""" + graph = self._make_graph() + graph.add_edge("src", "tgt") + with pytest.raises(ValueError, match="already exists"): + graph.add_edge("src", "tgt") + + def test_string_with_priority_creates_edge_condition(self): + """String edge with priority creates EdgeCondition internally.""" + graph = self._make_graph() + graph.add_edge("src", "tgt", priority=10, weight=0.5) + assert len(graph.nodes["src"].edges) == 1 + edge = graph.nodes["src"].edges[0] + assert edge.target_node == "tgt" + assert edge.priority == 10 + assert edge.weight == 0.5 + + def test_string_with_weight_only_creates_edge_condition(self): + """String edge with weight only creates EdgeCondition.""" + graph = self._make_graph() + graph.add_edge("src", "tgt", weight=0.7) + edge = graph.nodes["src"].edges[0] + assert edge.priority == 1 # default + assert edge.weight == 0.7 + + def test_invalid_target_type_raises(self): + """Non-str non-EdgeCondition target raises TypeError.""" + graph = self._make_graph() + with pytest.raises(TypeError, match="target_node must be str 
or EdgeCondition"): + graph.add_edge("src", 42) + + +# ============================================================================ +# Test: Callback returns Event + sampling (lines 1154-1163, 1488-1497) +# ============================================================================ + + +@pytest.mark.asyncio +class TestCallbackReturnsEvent: + """Cover before/after_node_callback returning an Event (truthy path).""" + + async def test_before_node_callback_returns_event(self): + """Async before_node_callback returning Event yields it.""" + from google.adk.agents.graph.callbacks import NodeCallbackContext + + yielded_events = [] + + async def before_cb(ctx: NodeCallbackContext): + return Event( + author="before_cb", + content=types.Content( + parts=[types.Part(text="before_event")] + ), + ) + + graph = GraphAgent(name="g", before_node_callback=before_cb) + graph.add_node("step", function=lambda s, c: "done") + graph.set_start("step") + graph.set_end("step") + + svc = InMemorySessionService() + runner = Runner(app_name="app", agent=graph, session_service=svc) + await svc.create_session(app_name="app", user_id="u", session_id="s") + + async for event in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content( + role="user", parts=[types.Part(text="go")] + ), + ): + yielded_events.append(event) + + # The before callback event should appear in the stream + before_texts = [ + e.content.parts[0].text + for e in yielded_events + if e.content and e.content.parts and e.content.parts[0].text == "before_event" + ] + assert len(before_texts) == 1 + + async def test_after_node_callback_returns_event(self): + """Async after_node_callback returning Event yields it.""" + from google.adk.agents.graph.callbacks import NodeCallbackContext + + async def after_cb(ctx: NodeCallbackContext): + return Event( + author="after_cb", + content=types.Content( + parts=[types.Part(text="after_event")] + ), + ) + + graph = GraphAgent(name="g", after_node_callback=after_cb) + 
graph.add_node("step", function=lambda s, c: "done") + graph.set_start("step") + graph.set_end("step") + + svc = InMemorySessionService() + runner = Runner(app_name="app", agent=graph, session_service=svc) + await svc.create_session(app_name="app", user_id="u", session_id="s") + + yielded_events = [] + async for event in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content( + role="user", parts=[types.Part(text="go")] + ), + ): + yielded_events.append(event) + + after_texts = [ + e.content.parts[0].text + for e in yielded_events + if e.content and e.content.parts and e.content.parts[0].text == "after_event" + ] + assert len(after_texts) == 1 + + +# ============================================================================ +# Coverage tests: edge-case code paths for 100% coverage +# ============================================================================ + + +class _CovFailingAgent(BaseAgent): + """Agent that raises on execution (for coverage tests).""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, error_msg: str = "boom"): + super().__init__(name=name) + object.__setattr__(self, "_error_msg", error_msg) + + async def _run_async_impl(self, ctx): + msg = object.__getattribute__(self, "_error_msg") + raise RuntimeError(msg) + yield # noqa: E711 + + +class _CovMultiEventAgent(BaseAgent): + """Agent yielding multiple events (for mid-execution cancellation tests).""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, event_count: int = 3): + super().__init__(name=name) + object.__setattr__(self, "_event_count", event_count) + + async def _run_async_impl(self, ctx): + n = object.__getattribute__(self, "_event_count") + for i in range(n): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=f"event_{i}")]), + ) + + +def _cov_make_ctx( + agent, + *, + resumable=False, + session_state=None, +): + svc = 
InMemorySessionService() + state = session_state or {} + session = Session( + id="test-session", appName="test", userId="test-user", state=state + ) + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv-1", + agent=agent, + user_content=types.Content( + role="user", parts=[types.Part(text="test")] + ), + ) + if resumable: + ctx.resumability_config = ResumabilityConfig(is_resumable=True) + ctx.run_config = RunConfig() + return ctx + + +async def _cov_collect(graph, ctx): + events = [] + async for event in graph._run_async_impl(ctx): + events.append(event) + return events + + +def _cov_linear_graph(name, agents, names): + graph = GraphAgent(name=name) + for nname, agent in zip(names, agents): + graph.add_node(GraphNode(name=nname, agent=agent)) + for i in range(len(names) - 1): + graph.add_edge(names[i], names[i + 1]) + graph.set_start(names[0]) + graph.set_end(names[-1]) + return graph + + +class TestGetNodeAgent: + def test_regular_node_returns_agent(self): + agent = SimpleTestAgent("a", ["ok"]) + node = GraphNode(name="n", agent=agent) + outer = GraphAgent(name="outer") + assert outer._get_node_agent(node) is agent + + +@pytest.mark.asyncio +class TestDomainDataFromSession: + async def test_session_state_populates_domain_data(self): + agent_a = SimpleTestAgent("a", ["done"]) + graph = _cov_linear_graph("g", [agent_a], ["nA"]) + ctx = _cov_make_ctx(graph, session_state={"my_key": "my_value", "another": 42}) + events = await _cov_collect(graph, ctx) + final = [ + e for e in events + if e.actions and e.actions.state_delta + and "graph_data" in (e.actions.state_delta or {}) + ] + assert len(final) == 1 + graph_data = final[0].actions.state_delta["graph_data"] + assert graph_data["my_key"] == "my_value" + assert graph_data["another"] == 42 + + async def test_internal_keys_excluded_from_domain_data(self): + agent_a = SimpleTestAgent("a", ["done"]) + graph = _cov_linear_graph("g", [agent_a], ["nA"]) + ctx = _cov_make_ctx( + graph, + 
session_state={ + "my_key": "ok", + "graph_data": {"old": "stale"}, + "graph_cancelled": True, + "_private": "hidden", + }, + ) + events = await _cov_collect(graph, ctx) + final = [ + e for e in events + if e.actions and e.actions.state_delta + and "graph_data" in (e.actions.state_delta or {}) + ] + assert len(final) == 1 + graph_data = final[0].actions.state_delta["graph_data"] + assert graph_data["my_key"] == "ok" + assert "graph_data" not in graph_data + assert "graph_cancelled" not in graph_data + assert "_private" not in graph_data + + +@pytest.mark.asyncio +@pytest.mark.asyncio +class TestBeforeNodeCallbackException: + async def test_before_callback_failure_continues_execution(self): + agent_a = SimpleTestAgent("a", ["a_out"]) + graph = _cov_linear_graph("g", [agent_a], ["nA"]) + + async def failing_callback(ctx): + raise ValueError("callback_error") + + graph.before_node_callback = failing_callback + ctx = _cov_make_ctx(graph) + events = await _cov_collect(graph, ctx) + assert agent_a.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.asyncio +@pytest.mark.asyncio +class TestOutputMapperNoneFallback: + async def test_output_mapper_returning_none_uses_prev_state(self): + agent_a = SimpleTestAgent("a", ["a_out"]) + graph = GraphAgent(name="g") + graph.add_node( + GraphNode( + name="nA", + agent=agent_a, + output_mapper=lambda output, state: state.data.update( + {"custom_key": output} + ), + ) + ) + graph.set_start("nA") + graph.set_end("nA") + ctx = _cov_make_ctx(graph) + events = await _cov_collect(graph, ctx) + assert agent_a.call_count == 1 + final = [ + e for e in events + if e.actions and e.actions.state_delta + and "graph_data" in (e.actions.state_delta or {}) + ] + assert len(final) == 1 + assert final[0].actions.state_delta["graph_data"].get("custom_key") == "a_out" + + +@pytest.mark.asyncio +class TestAfterNodeCallbackException: + async def test_after_callback_failure_continues_execution(self): + agent_a = SimpleTestAgent("a", ["a_out"]) + 
agent_b = SimpleTestAgent("b", ["b_out"]) + graph = _cov_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + + async def failing_callback(ctx): + raise ValueError("after_callback_error") + + graph.after_node_callback = failing_callback + ctx = _cov_make_ctx(graph) + events = await _cov_collect(graph, ctx) + assert agent_a.call_count == 1 + assert agent_b.call_count == 1 + + +@pytest.mark.asyncio +class TestNodeExecutionException: + async def test_node_exception_raises_and_records_metrics(self): + failing = _CovFailingAgent("fail", "node_error") + graph = _cov_linear_graph("g", [failing], ["nA"]) + ctx = _cov_make_ctx(graph) + with pytest.raises(RuntimeError, match="node_error"): + await _cov_collect(graph, ctx) + + +@pytest.mark.asyncio +class TestConditionEvalLogging: + """Condition evaluation failures must log with exc_info for debugging.""" + + async def test_condition_eval_failure_logs_with_exc_info(self): + from google.adk.agents.graph.graph_agent import _parse_condition_string + + with patch("google.adk.agents.graph.graph_agent.logger") as mock_logger: + cond_fn = _parse_condition_string("data.get('x')['missing']") + state = GraphState(data={"x": "not_a_dict"}) + result = cond_fn(state) + assert result is False + mock_logger.error.assert_called_once() + assert mock_logger.error.call_args[1].get("exc_info") is True + diff --git a/tests/unittests/agents/test_graph_agent_config.py b/tests/unittests/agents/test_graph_agent_config.py new file mode 100644 index 0000000000..188a3ffe73 --- /dev/null +++ b/tests/unittests/agents/test_graph_agent_config.py @@ -0,0 +1,438 @@ +"""Test suite for GraphAgent configuration classes. 
+ +Tests Pydantic model validation for all graph-related config classes: +- GraphNodeConfig +- GraphEdgeConfig +- InterruptConfigYaml +- ParallelGroupConfig +- GraphAgentConfig +""" + +from google.adk.agents.graph.graph_agent_config import GraphAgentConfig +from google.adk.agents.graph.graph_agent_config import GraphEdgeConfig +from google.adk.agents.graph.graph_agent_config import GraphNodeConfig +from google.adk.agents.graph.graph_agent_config import InterruptConfigYaml +from google.adk.agents.graph.graph_agent_config import ParallelGroupConfig +from pydantic import ValidationError +import pytest + + +class TestGraphNodeConfig: + """Tests for GraphNodeConfig validation.""" + + def test_minimal_node_config(self): + """Test minimal valid node configuration.""" + config = GraphNodeConfig(name="test_node") + assert config.name == "test_node" + assert config.function_ref is None + assert config.reducer == "overwrite" + + def test_node_with_function_ref(self): + """Test node with function reference.""" + config = GraphNodeConfig( + name="test_node", + function_ref="my_module.my_function", + ) + assert config.function_ref == "my_module.my_function" + + def test_node_with_mappers(self): + """Test node with input/output mappers.""" + config = GraphNodeConfig( + name="test_node", + input_mapper_ref="mappers.input_fn", + output_mapper_ref="mappers.output_fn", + ) + assert config.input_mapper_ref == "mappers.input_fn" + assert config.output_mapper_ref == "mappers.output_fn" + + def test_node_with_reducers(self): + """Test node with different reducer strategies.""" + # Overwrite (default) + config1 = GraphNodeConfig(name="node1") + assert config1.reducer == "overwrite" + + # Append + config2 = GraphNodeConfig(name="node2", reducer="append") + assert config2.reducer == "append" + + # Sum + config3 = GraphNodeConfig(name="node3", reducer="sum") + assert config3.reducer == "sum" + + # Custom + config4 = GraphNodeConfig( + name="node4", + reducer="custom", + 
custom_reducer_ref="reducers.my_reducer", + ) + assert config4.reducer == "custom" + assert config4.custom_reducer_ref == "reducers.my_reducer" + + def test_node_extra_forbid(self): + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError) as exc_info: + GraphNodeConfig(name="test", invalid_field="value") + assert "Extra inputs are not permitted" in str(exc_info.value) + + +class TestGraphEdgeConfig: + """Tests for GraphEdgeConfig validation.""" + + def test_minimal_edge_config(self): + """Test minimal valid edge configuration.""" + config = GraphEdgeConfig(source_node="start", target_node="end") + assert config.source_node == "start" + assert config.target_node == "end" + assert config.condition is None + assert config.priority == 1 + assert config.weight == 1.0 + + def test_edge_with_condition(self): + """Test edge with condition string expression.""" + config = GraphEdgeConfig( + source_node="start", + target_node="end", + condition="data.get('success') is True", + ) + assert config.condition == "data.get('success') is True" + + def test_edge_with_priority(self): + """Test edge with custom priority.""" + config = GraphEdgeConfig( + source_node="start", + target_node="end", + priority=10, + ) + assert config.priority == 10 + + def test_edge_with_weight(self): + """Test edge with custom weight for weighted routing.""" + config = GraphEdgeConfig( + source_node="start", + target_node="end", + weight=0.75, + ) + assert config.weight == 0.75 + + def test_edge_extra_forbid(self): + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError) as exc_info: + GraphEdgeConfig( + source_node="start", + target_node="end", + invalid_field="value", + ) + assert "Extra inputs are not permitted" in str(exc_info.value) + + +class TestInterruptConfigYaml: + """Tests for InterruptConfigYaml validation.""" + + def test_minimal_interrupt_config(self): + """Test minimal valid interrupt configuration.""" + config = InterruptConfigYaml() + 
assert config.mode is None # Optional[Literal] defaults to None + assert config.interrupt_service is None + + def test_interrupt_modes(self): + """Test different interrupt modes.""" + # None (default) + config1 = InterruptConfigYaml() + assert config1.mode is None + + # Before + config2 = InterruptConfigYaml(mode="before") + assert config2.mode == "before" + + # After + config3 = InterruptConfigYaml(mode="after") + assert config3.mode == "after" + + # Both + config4 = InterruptConfigYaml(mode="both") + assert config4.mode == "both" + + def test_interrupt_with_service_ref(self): + """Test interrupt config with service configuration.""" + config = InterruptConfigYaml( + mode="both", + interrupt_service={ + "name": "google.adk.agents.graph.interrupt_service.InterruptService" + }, + ) + assert config.mode == "both" + assert config.interrupt_service is not None + assert ( + config.interrupt_service.name + == "google.adk.agents.graph.interrupt_service.InterruptService" + ) + + def test_interrupt_extra_forbid(self): + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError) as exc_info: + InterruptConfigYaml(invalid_field="value") + assert "Extra inputs are not permitted" in str(exc_info.value) + + +class TestParallelGroupConfig: + """Tests for ParallelGroupConfig validation.""" + + def test_minimal_parallel_config(self): + """Test minimal valid parallel group configuration.""" + config = ParallelGroupConfig(nodes=["node1", "node2"]) + assert config.nodes == ["node1", "node2"] + assert config.join_strategy == "all" + assert config.error_policy == "fail_fast" + assert config.wait_n == 1 + + def test_parallel_join_strategies(self): + """Test different join strategies.""" + # All (default) + config1 = ParallelGroupConfig(nodes=["n1", "n2"]) + assert config1.join_strategy == "all" + + # Any + config2 = ParallelGroupConfig(nodes=["n1", "n2"], join_strategy="any") + assert config2.join_strategy == "any" + + # N + config3 = ParallelGroupConfig( + 
nodes=["n1", "n2", "n3"], + join_strategy="n", + wait_n=2, + ) + assert config3.join_strategy == "n" + assert config3.wait_n == 2 + + def test_parallel_error_policies(self): + """Test different error policies.""" + # Fail fast (default) + config1 = ParallelGroupConfig(nodes=["n1", "n2"]) + assert config1.error_policy == "fail_fast" + + # Continue + config2 = ParallelGroupConfig(nodes=["n1", "n2"], error_policy="continue") + assert config2.error_policy == "continue" + + # Collect + config3 = ParallelGroupConfig(nodes=["n1", "n2"], error_policy="collect") + assert config3.error_policy == "collect" + + def test_parallel_extra_forbid(self): + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError) as exc_info: + ParallelGroupConfig(nodes=["n1"], invalid_field="value") + assert "Extra inputs are not permitted" in str(exc_info.value) + + +class TestGraphAgentConfig: + """Tests for GraphAgentConfig validation.""" + + def test_minimal_graph_config(self): + """Test minimal valid graph configuration.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + ) + assert config.name == "test_graph" + assert config.agent_class == "GraphAgent" + assert config.start_node == "start" + assert config.end_nodes == [] + assert config.max_iterations == 20 + assert config.checkpointing is False + + def test_graph_with_end_nodes(self): + """Test graph with multiple end nodes.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + end_nodes=["end1", "end2"], + ) + assert config.end_nodes == ["end1", "end2"] + + def test_graph_with_max_iterations(self): + """Test graph with custom max_iterations.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + max_iterations=50, + ) + assert config.max_iterations == 50 + + def test_graph_with_checkpointing(self): + """Test graph with checkpointing enabled.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + checkpointing=True, + 
checkpoint_service={"name": "google.adk.checkpoints.CheckpointService"}, + ) + assert config.checkpointing is True + assert config.checkpoint_service is not None + assert ( + config.checkpoint_service.name + == "google.adk.checkpoints.CheckpointService" + ) + + def test_graph_with_nodes(self): + """Test graph with node configurations.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + nodes=[ + {"name": "start", "sub_agents": [{"code": "Agent1()"}]}, + {"name": "middle", "sub_agents": [{"code": "Agent2()"}]}, + {"name": "end", "sub_agents": [{"code": "Agent3()"}]}, + ], + ) + assert len(config.nodes) == 3 + assert config.nodes[0].name == "start" + assert len(config.nodes[1].sub_agents) == 1 + + def test_graph_with_edges(self): + """Test graph with edge configurations.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + edges=[ + {"source_node": "start", "target_node": "middle"}, + {"source_node": "middle", "target_node": "end"}, + ], + ) + assert len(config.edges) == 2 + assert config.edges[0].source_node == "start" + assert config.edges[1].target_node == "end" + + def test_graph_with_interrupt_config(self): + """Test graph with interrupt configuration.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + interrupt_config={ + "mode": "both", + "interrupt_service": { + "name": ( + "google.adk.agents.graph.interrupt_service.InterruptService" + ) + }, + }, + ) + assert config.interrupt_config is not None + assert config.interrupt_config.mode == "both" + assert config.interrupt_config.interrupt_service is not None + + def test_graph_with_parallel_groups(self): + """Test graph with parallel execution groups.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + parallel_groups=[{ + "nodes": ["parallel1", "parallel2"], + "join_strategy": "all", + "error_policy": "fail_fast", + }], + ) + assert len(config.parallel_groups) == 1 + assert config.parallel_groups[0].nodes == 
["parallel1", "parallel2"] + + def test_graph_with_callbacks(self): + """Test graph with callback references.""" + config = GraphAgentConfig( + name="test_graph", + start_node="start", + before_node_callbacks=[{"name": "callbacks.before"}], + after_node_callbacks=[{"name": "callbacks.after"}], + on_edge_condition_callbacks=[{"name": "callbacks.on_edge"}], + ) + assert len(config.before_node_callbacks) == 1 + assert config.before_node_callbacks[0].name == "callbacks.before" + assert len(config.after_node_callbacks) == 1 + assert len(config.on_edge_condition_callbacks) == 1 + + def test_graph_complete_config(self): + """Test complete graph configuration with all features.""" + config = GraphAgentConfig( + name="complete_graph", + description="A complete graph configuration", + start_node="start", + end_nodes=["end"], + max_iterations=30, + checkpointing=True, + checkpoint_service={"name": "google.adk.checkpoints.CheckpointService"}, + nodes=[ + { + "name": "start", + "sub_agents": [{"code": "Agent1()"}], + "reducer": "overwrite", + }, + { + "name": "middle", + "function_ref": "functions.process", + "input_mapper_ref": "mappers.input", + "output_mapper_ref": "mappers.output", + }, + {"name": "end", "sub_agents": [{"code": "Agent3()"}]}, + ], + edges=[ + {"source_node": "start", "target_node": "middle", "priority": 1}, + { + "source_node": "middle", + "target_node": "end", + "condition": "data.get('success', False) is True", + }, + ], + interrupt_config={ + "mode": "both", + "interrupt_service": { + "name": ( + "google.adk.agents.graph.interrupt_service.InterruptService" + ) + }, + }, + parallel_groups=[{ + "nodes": ["parallel1", "parallel2"], + "join_strategy": "all", + }], + before_node_callbacks=[{"name": "callbacks.before"}], + after_node_callbacks=[{"name": "callbacks.after"}], + ) + + # Verify all fields + assert config.name == "complete_graph" + assert config.start_node == "start" + assert config.end_nodes == ["end"] + assert config.max_iterations == 30 + assert 
config.checkpointing is True + assert len(config.nodes) == 3 + assert len(config.edges) == 2 + assert config.interrupt_config is not None + assert len(config.parallel_groups) == 1 + assert len(config.before_node_callbacks) == 1 + assert len(config.after_node_callbacks) == 1 + + def test_graph_extra_forbid(self): + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError) as exc_info: + GraphAgentConfig( + name="test", + start_node="start", + invalid_field="value", + ) + assert "Extra inputs are not permitted" in str(exc_info.value) + + def test_graph_agent_class_default(self): + """Test that agent_class defaults to GraphAgent.""" + config = GraphAgentConfig(name="test", start_node="start") + assert config.agent_class == "GraphAgent" + + def test_graph_missing_required_fields(self): + """Test validation fails when required fields are missing.""" + # Missing name + with pytest.raises(ValidationError) as exc_info: + GraphAgentConfig(start_node="start") + assert "name" in str(exc_info.value).lower() + + # Missing start_node + with pytest.raises(ValidationError) as exc_info: + GraphAgentConfig(name="test") + assert "start_node" in str(exc_info.value).lower() diff --git a/tests/unittests/agents/test_graph_agent_validation.py b/tests/unittests/agents/test_graph_agent_validation.py new file mode 100644 index 0000000000..8590de80aa --- /dev/null +++ b/tests/unittests/agents/test_graph_agent_validation.py @@ -0,0 +1,231 @@ +"""Test suite for GraphAgent validation features. 
+ +Tests: +- Duplicate node name validation +- Duplicate edge validation +- Auto-defaulting output_key for LlmAgent with output_schema +- Warning emissions for auto-defaulted output_key +""" + +import logging + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from pydantic import BaseModel +import pytest + + +# Test schema for output_schema tests +class TestOutput(BaseModel): + """Test output schema.""" + + result: str + + +def test_add_node_duplicate_name_raises_error(): + """Adding node with duplicate name raises ValueError.""" + graph = GraphAgent(name="test_graph") + agent1 = LlmAgent(name="agent1", model="gemini-2.0-flash") + agent2 = LlmAgent(name="agent2", model="gemini-2.0-flash") + + # Add first node + graph.add_node(GraphNode(name="duplicate", agent=agent1)) + + # Adding second node with same name should raise + with pytest.raises(ValueError, match="already exists"): + graph.add_node(GraphNode(name="duplicate", agent=agent2)) + + +def test_add_node_duplicate_name_convenience_api(): + """Duplicate node validation works with convenience API.""" + graph = GraphAgent(name="test_graph") + agent1 = LlmAgent(name="agent1", model="gemini-2.0-flash") + agent2 = LlmAgent(name="agent2", model="gemini-2.0-flash") + + # Add first node using convenience API + graph.add_node("duplicate", agent=agent1) + + # Adding second node with same name should raise + with pytest.raises(ValueError, match="already exists"): + graph.add_node("duplicate", agent=agent2) + + +def test_add_edge_duplicate_raises_error(): + """Adding edge with duplicate source→target raises ValueError.""" + graph = GraphAgent(name="test_graph") + + # Add nodes + graph.add_node( + "start", agent=LlmAgent(name="start", model="gemini-2.0-flash") + ) + graph.add_node("end", agent=LlmAgent(name="end", model="gemini-2.0-flash")) + + # Add first edge + graph.add_edge("start", 
"end") + + # Adding second edge with same source→target should raise + with pytest.raises(ValueError, match="already exists"): + graph.add_edge("start", "end") + + +def test_add_edge_duplicate_with_conditions(): + """Duplicate edge validation works even with different conditions.""" + graph = GraphAgent(name="test_graph") + + # Add nodes + graph.add_node( + "start", agent=LlmAgent(name="start", model="gemini-2.0-flash") + ) + graph.add_node("end", agent=LlmAgent(name="end", model="gemini-2.0-flash")) + + # Add first edge with condition + graph.add_edge("start", "end", condition=lambda s: s.data.get("foo")) + + # Adding second edge to same target should raise even with different condition + with pytest.raises(ValueError, match="already exists"): + graph.add_edge("start", "end", condition=lambda s: s.data.get("bar")) + + +def test_output_key_auto_defaults_to_agent_name(): + """GraphNode auto-defaults output_key to agent.name when output_schema is set. + + model_copy() is used so the original agent is NOT mutated; the copy stored + in node.agent has the defaulted output_key. 
+ """ + agent = LlmAgent( + name="analyzer", + model="gemini-2.0-flash", + output_schema=TestOutput, + # output_key NOT SET + ) + + # Before wrapping + assert agent.output_key is None + + # After wrapping in GraphNode + node = GraphNode(name="test_node", agent=agent) + + # node.agent is a copy with the auto-defaulted output_key + assert node.agent.output_key == "analyzer" + # Original agent is NOT mutated (model_copy creates an isolated copy) + assert agent.output_key is None + + +def test_explicit_output_key_not_overridden(): + """Explicit output_key is not overridden by auto-defaulting.""" + agent = LlmAgent( + name="analyzer", + model="gemini-2.0-flash", + output_schema=TestOutput, + output_key="custom_key", # Explicit + ) + + node = GraphNode(name="test_node", agent=agent) + + # Should keep explicit value + assert agent.output_key == "custom_key" + + +def test_no_auto_default_without_output_schema(): + """output_key is not auto-defaulted if output_schema is not set.""" + agent = LlmAgent( + name="analyzer", + model="gemini-2.0-flash", + # No output_schema + ) + + # Before wrapping + assert agent.output_key is None + + # After wrapping in GraphNode + node = GraphNode(name="test_node", agent=agent) + + # output_key should still be None (no auto-default) + assert agent.output_key is None + + +def test_warning_for_auto_defaulted_output_key(caplog): + """Warning emitted when output_key is auto-defaulted.""" + agent = LlmAgent( + name="analyzer", + model="gemini-2.0-flash", + output_schema=TestOutput, + ) + graph = GraphAgent(name="test_graph") + + with caplog.at_level(logging.WARNING): + graph.add_node(GraphNode(name="test_node", agent=agent)) + + # Verify warning about auto-defaulting + assert any("auto-defaulted" in rec.message.lower() for rec in caplog.records) + + +def test_no_warning_for_explicit_output_key(caplog): + """No warning emitted when output_key is explicitly set.""" + agent = LlmAgent( + name="analyzer", + model="gemini-2.0-flash", + 
output_schema=TestOutput, + output_key="custom_key", + ) + graph = GraphAgent(name="test_graph") + + with caplog.at_level(logging.WARNING): + graph.add_node(GraphNode(name="test_node", agent=agent)) + + # No warning should be emitted + assert not any( + "auto-defaulted" in rec.message.lower() for rec in caplog.records + ) + + +def test_add_edge_duplicate_edge_condition_raises(): + """Duplicate EdgeCondition (Pattern 1) to same target raises ValueError.""" + graph = GraphAgent(name="test_graph") + graph.add_node( + "start", agent=LlmAgent(name="start", model="gemini-2.0-flash") + ) + graph.add_node("end", agent=LlmAgent(name="end", model="gemini-2.0-flash")) + + edge = EdgeCondition(target_node="end", priority=10) + graph.add_edge("start", edge) + + # Second EdgeCondition to the same target must raise + with pytest.raises(ValueError, match="already exists"): + graph.add_edge("start", EdgeCondition(target_node="end", priority=5)) + + +def test_agent_name_collides_with_graph_name_raises(): + """Agent name matching GraphAgent name raises ValueError.""" + graph = GraphAgent(name="my_graph") + agent = LlmAgent(name="my_graph", model="gemini-2.0-flash") + + with pytest.raises(ValueError, match="collides with GraphAgent name"): + graph.add_node("node1", agent=agent) + + +def test_ast_rejects_dunder_attribute_access(): + """AST validator blocks dunder attribute access (sandbox escape prevention).""" + from google.adk.agents.graph.graph_agent import _validate_condition_ast + import ast + + # Safe attribute access should pass + tree = ast.parse("data.get('x')", mode="eval") + _validate_condition_ast(tree.body) + + # Dunder attribute access should be rejected + tree = ast.parse("data.__class__", mode="eval") + with pytest.raises(ValueError, match="Unsafe attribute access.*__class__"): + _validate_condition_ast(tree.body) + + # Nested dunder chain should be rejected (outermost attr checked first) + tree = ast.parse("data.__class__.__init__", mode="eval") + with 
pytest.raises(ValueError, match="Unsafe attribute access.*__init__"): + _validate_condition_ast(tree.body) + + # Single underscore prefix should also be rejected + tree = ast.parse("data._private", mode="eval") + with pytest.raises(ValueError, match="Unsafe attribute access.*_private"): + _validate_condition_ast(tree.body) diff --git a/tests/unittests/agents/test_graph_callbacks.py b/tests/unittests/agents/test_graph_callbacks.py new file mode 100644 index 0000000000..2bb33bb72a --- /dev/null +++ b/tests/unittests/agents/test_graph_callbacks.py @@ -0,0 +1,597 @@ +"""Test suite for GraphAgent callback infrastructure. + +Tests callback-based observability and extensibility: +- NodeCallback (before/after node execution) +- EdgeCallback (on edge condition evaluation) +- Custom observability patterns +- Nested graph hierarchy tracking +""" + +from typing import AsyncGenerator +from typing import Optional + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import create_nested_observability_callback +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import NodeCallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +import pytest +from typing_extensions import override + + +# Mock agent for testing +class MockAgent(BaseAgent): + _response: str + + def __init__(self, name: str, response: str = "mock", **kwargs): + super().__init__(name=name, **kwargs) + self._response = response + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._response)]), + ) + + +@pytest.mark.asyncio +async def 
test_before_node_callback_invoked(): + """Test that before_node_callback is invoked before node execution.""" + callback_invocations = [] + + async def before_callback(ctx: NodeCallbackContext) -> Optional[Event]: + callback_invocations.append(("before", ctx.node.name)) + return Event( + author="test", + content=types.Content( + parts=[types.Part(text=f"Before: {ctx.node.name}")] + ), + ) + + graph = GraphAgent(name="test_graph", before_node_callback=before_callback) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b", "output_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Verify callback was invoked for both nodes + assert len(callback_invocations) == 2 + assert callback_invocations[0] == ("before", "node_a") + assert callback_invocations[1] == ("before", "node_b") + + # Verify callback events were emitted + before_events = [ + e + for e in events + if e.content and "Before:" in (e.content.parts[0].text or "") + ] + assert len(before_events) == 2 + + +@pytest.mark.asyncio +async def test_after_node_callback_invoked(): + """Test that after_node_callback is invoked after node execution.""" + callback_invocations = [] + + async def after_callback(ctx: NodeCallbackContext) -> Optional[Event]: + callback_invocations.append( + ("after", ctx.node.name, ctx.metadata.get("output")) + ) + return Event( + 
author="test", + content=types.Content( + parts=[types.Part(text=f"After: {ctx.node.name}")] + ), + ) + + graph = GraphAgent(name="test_graph", after_node_callback=after_callback) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + graph.add_node(node_a) + graph.set_start("node_a").set_end("node_a") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Verify callback was invoked with output + assert len(callback_invocations) == 1 + assert callback_invocations[0][0] == "after" + assert callback_invocations[0][1] == "node_a" + assert callback_invocations[0][2] == "output_a" + + +@pytest.mark.asyncio +async def test_callback_returning_none_skips_event(): + """Test that callback returning None skips event emission.""" + callback_invocations = [] + + async def selective_callback(ctx: NodeCallbackContext) -> Optional[Event]: + callback_invocations.append(ctx.node.name) + # Only emit for node_a + if ctx.node.name == "node_a": + return Event( + author="test", + content=types.Content(parts=[types.Part(text="Event")]), + ) + return None # Skip for node_b + + graph = GraphAgent(name="test_graph", before_node_callback=selective_callback) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, 
session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Callback invoked for both + assert len(callback_invocations) == 2 + + # But only one event emitted + test_events = [e for e in events if e.author == "test"] + assert len(test_events) == 1 + + +@pytest.mark.asyncio +async def test_callback_has_full_context(): + """Test that callback receives full context including state and iteration.""" + captured_contexts = [] + + async def capture_callback(ctx: NodeCallbackContext) -> Optional[Event]: + captured_contexts.append({ + "node_name": ctx.node.name, + "iteration": ctx.iteration, + "state_data_keys": list(ctx.state.data.keys()), + "agent_path": list(ctx.metadata.get("agent_path", [])), + "path": list(ctx.metadata.get("path", [])), + }) + return None + + graph = GraphAgent(name="test_graph", before_node_callback=capture_callback) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b", "output_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + async for _ in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + pass + + # Verify contexts + assert len(captured_contexts) == 2 + + # 
First node + assert captured_contexts[0]["node_name"] == "node_a" + assert captured_contexts[0]["iteration"] == 1 + assert "input" in captured_contexts[0]["state_data_keys"] + assert captured_contexts[0]["agent_path"] == ["test_graph"] + assert captured_contexts[0]["path"] == ["node_a"] + + # Second node + assert captured_contexts[1]["node_name"] == "node_b" + assert captured_contexts[1]["iteration"] == 2 + assert captured_contexts[1]["agent_path"] == ["test_graph"] + assert captured_contexts[1]["path"] == ["node_a", "node_b"] + + +@pytest.mark.asyncio +async def test_nested_observability_callback(): + """Test create_nested_observability_callback shows hierarchy.""" + graph = GraphAgent( + name="outer_graph", + before_node_callback=create_nested_observability_callback(), + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a")) + graph.add_node(node_a) + graph.set_start("node_a").set_end("node_a") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Find observability event + obs_events = [e for e in events if e.author == "observability"] + assert len(obs_events) == 1 + + # Check hierarchy is shown + event_text = obs_events[0].content.parts[0].text + assert "outer_graph" in event_text + assert "node_a" in event_text + + +@pytest.mark.asyncio +async def test_both_callbacks_invoked_in_order(): + """Test that before and after callbacks are invoked in correct order.""" + invocation_order = [] + + async def before_callback(ctx: NodeCallbackContext) -> Optional[Event]: + 
invocation_order.append(f"before_{ctx.node.name}") + return None + + async def after_callback(ctx: NodeCallbackContext) -> Optional[Event]: + invocation_order.append(f"after_{ctx.node.name}") + return None + + graph = GraphAgent( + name="test_graph", + before_node_callback=before_callback, + after_node_callback=after_callback, + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + # Create session first + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + async for _ in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + pass + + # Verify order: before_a, after_a, before_b, after_b + assert invocation_order == [ + "before_node_a", + "after_node_a", + "before_node_b", + "after_node_b", + ] + + +@pytest.mark.asyncio +async def test_before_callback_error_is_caught_and_graph_continues(): + """Test that errors in before_node_callback are caught and graph continues.""" + callback_invocations = [] + + async def failing_before_callback( + ctx: NodeCallbackContext, + ) -> Optional[Event]: + callback_invocations.append(ctx.node.name) + raise ValueError(f"Callback error for {ctx.node.name}") + + graph = GraphAgent( + name="test_graph", before_node_callback=failing_before_callback + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b", "output_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + 
session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + # Graph should complete despite callback errors + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Callbacks were attempted for both nodes + assert len(callback_invocations) == 2 + assert callback_invocations == ["node_a", "node_b"] + + # Graph execution completed (node outputs present) + agent_events = [e for e in events if e.author in ["agent_a", "agent_b"]] + assert len(agent_events) == 2 + + +@pytest.mark.asyncio +async def test_after_callback_error_is_caught_and_graph_continues(): + """Test that errors in after_node_callback are caught and graph continues.""" + callback_invocations = [] + + async def failing_after_callback(ctx: NodeCallbackContext) -> Optional[Event]: + callback_invocations.append(ctx.node.name) + raise RuntimeError(f"After callback error for {ctx.node.name}") + + graph = GraphAgent( + name="test_graph", after_node_callback=failing_after_callback + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + node_b = GraphNode(name="node_b", agent=MockAgent("agent_b", "output_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + # Graph should complete despite callback errors + async for event in runner.run_async( + user_id="test_user", + 
session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Callbacks were attempted for both nodes + assert len(callback_invocations) == 2 + assert callback_invocations == ["node_a", "node_b"] + + # Graph execution completed + agent_events = [e for e in events if e.author in ["agent_a", "agent_b"]] + assert len(agent_events) == 2 + + +@pytest.mark.asyncio +async def test_both_callbacks_error_graph_still_completes(): + """Test that graph completes even when both callbacks raise errors.""" + before_invocations = [] + after_invocations = [] + + async def failing_before(ctx: NodeCallbackContext) -> Optional[Event]: + before_invocations.append(ctx.node.name) + raise ValueError("Before error") + + async def failing_after(ctx: NodeCallbackContext) -> Optional[Event]: + after_invocations.append(ctx.node.name) + raise ValueError("After error") + + graph = GraphAgent( + name="test_graph", + before_node_callback=failing_before, + after_node_callback=failing_after, + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a", "output_a")) + graph.add_node(node_a) + graph.set_start("node_a").set_end("node_a") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Both callbacks were attempted + assert len(before_invocations) == 1 + assert len(after_invocations) == 1 + + # Graph execution still completed + agent_events = [e for e in events if e.author == "agent_a"] + assert len(agent_events) == 1 + + +@pytest.mark.asyncio +async def 
test_callback_error_is_logged(caplog): + """Test that callback errors are logged with details.""" + import logging + + async def failing_callback(ctx: NodeCallbackContext) -> Optional[Event]: + raise ValueError("Intentional callback error") + + graph = GraphAgent(name="test_graph", before_node_callback=failing_callback) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a")) + graph.add_node(node_a) + graph.set_start("node_a").set_end("node_a") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + with caplog.at_level(logging.ERROR): + async for _ in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + pass + + # Verify error was logged + assert len(caplog.records) > 0 + error_logs = [r for r in caplog.records if r.levelname == "ERROR"] + assert len(error_logs) == 1 + assert "before_node_callback failed" in error_logs[0].message + assert "node_a" in error_logs[0].message + assert "Intentional callback error" in error_logs[0].message + + +@pytest.mark.asyncio +async def test_partial_callback_errors_do_not_affect_successful_callbacks(): + """Test that errors in one node's callback don't affect other nodes.""" + invocations = [] + + async def selective_failing_callback( + ctx: NodeCallbackContext, + ) -> Optional[Event]: + invocations.append(ctx.node.name) + if ctx.node.name == "node_a": + raise ValueError("Error for node_a") + # node_b succeeds + return Event( + author="callback", + content=types.Content( + parts=[types.Part(text=f"Success: {ctx.node.name}")] + ), + ) + + graph = GraphAgent( + name="test_graph", before_node_callback=selective_failing_callback + ) + node_a = GraphNode(name="node_a", agent=MockAgent("agent_a")) + node_b = 
GraphNode(name="node_b", agent=MockAgent("agent_b")) + + graph.add_node(node_a).add_node(node_b) + graph.add_edge("node_a", "node_b") + graph.set_start("node_a").set_end("node_b") + + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", agent=graph, session_service=session_service + ) + + await session_service.create_session( + app_name="test_app", user_id="test_user", session_id="test_session" + ) + + events = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session", + new_message=types.Content( + role="user", parts=[types.Part(text="test input")] + ), + ): + events.append(event) + + # Both callbacks were attempted + assert invocations == ["node_a", "node_b"] + + # Only node_b's callback event was emitted (node_a failed) + callback_events = [e for e in events if e.author == "callback"] + assert len(callback_events) == 1 + assert "Success: node_b" in callback_events[0].content.parts[0].text diff --git a/tests/unittests/agents/test_graph_convenience_api.py b/tests/unittests/agents/test_graph_convenience_api.py new file mode 100644 index 0000000000..e5ebbd2b29 --- /dev/null +++ b/tests/unittests/agents/test_graph_convenience_api.py @@ -0,0 +1,425 @@ +"""Tests for GraphAgent convenience API methods. + +Tests the convenience methods for add_node() and add_edge() that provide +simpler syntax alternatives to the explicit GraphNode/EdgeCondition patterns. 
+""" + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import StateReducer +from google.adk.agents.graph.graph_edge import EdgeCondition +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.genai import types +import pytest + + +class SimpleAgent(BaseAgent): + """Simple test agent.""" + + async def _run_async_impl(self, ctx: InvocationContext): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=f"{self.name} output")]), + ) + + +def simple_function(state: GraphState, ctx: InvocationContext) -> str: + """Simple test function.""" + return "function output" + + +class TestAddNodeConvenience: + """Test add_node() convenience patterns.""" + + def test_add_node_with_graphnode(self): + """Test traditional GraphNode pattern still works.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node(GraphNode(name="node1", agent=agent)) + + assert "node1" in graph.nodes + assert graph.nodes["node1"].agent == agent + + def test_add_node_convenience_with_agent(self): + """Test convenience pattern: add_node(name, agent=...)""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + + assert "node1" in graph.nodes + assert graph.nodes["node1"].agent == agent + + def test_add_node_convenience_with_function(self): + """Test convenience pattern: add_node(name, function=...)""" + graph = GraphAgent(name="test") + + graph.add_node("node1", function=simple_function) + + assert "node1" in graph.nodes + assert graph.nodes["node1"].function == simple_function + + def test_add_node_convenience_with_kwargs(self): + """Test convenience pattern with additional kwargs.""" + graph = GraphAgent(name="test") + agent = 
SimpleAgent(name="test_agent") + + graph.add_node( + "node1", + agent=agent, + reducer=StateReducer.APPEND, + ) + + assert "node1" in graph.nodes + assert graph.nodes["node1"].agent == agent + assert graph.nodes["node1"].reducer == StateReducer.APPEND + + def test_add_node_error_graphnode_with_kwargs(self): + """Test error when passing GraphNode with kwargs.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + node = GraphNode(name="node1", agent=agent) + + with pytest.raises(ValueError, match="When passing a GraphNode"): + graph.add_node(node, agent=agent) + + def test_add_node_error_string_without_agent_or_function(self): + """Test error when passing string name without agent or function.""" + graph = GraphAgent(name="test") + + with pytest.raises(ValueError, match="must specify agent or function"): + graph.add_node("node1") + + def test_add_node_error_both_agent_and_function(self): + """Test error when passing both agent and function.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + with pytest.raises(ValueError, match="Cannot specify both"): + graph.add_node("node1", agent=agent, function=simple_function) + + def test_add_node_error_invalid_type(self): + """Test error when passing invalid node type.""" + graph = GraphAgent(name="test") + + with pytest.raises(TypeError, match="node must be GraphNode or str"): + graph.add_node(123) # Invalid type + + def test_add_node_chaining(self): + """Test that add_node returns self for chaining.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + result = graph.add_node("node1", agent=agent).add_node("node2", agent=agent) + + assert result is graph + assert "node1" in graph.nodes + assert "node2" in graph.nodes + + +class TestAddEdgeConvenience: + """Test add_edge() convenience patterns.""" + + def test_add_edge_simple(self): + """Test simple unconditional edge.""" + graph = GraphAgent(name="test") + agent = 
SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + graph.add_edge("node1", "node2") + + # Edge should be added to node1 + assert len(graph.nodes["node1"].edges) == 1 + assert graph.nodes["node1"].edges[0].target_node == "node2" + + def test_add_edge_with_condition(self): + """Test conditional edge.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + + condition = lambda s: s.data.get("valid", False) + graph.add_edge("node1", "node2", condition=condition) + + assert len(graph.nodes["node1"].edges) == 1 + assert graph.nodes["node1"].edges[0].condition == condition + + def test_add_edge_with_priority(self): + """Test priority-based edge.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + graph.add_node("node3", agent=agent) + + # Add edges with priority + graph.add_edge( + "node1", + "node2", + condition=lambda s: s.data.get("score", 0) > 0.5, + priority=10, + ) + graph.add_edge("node1", "node3", priority=0) # Fallback + + # Should create EdgeCondition objects + assert hasattr(graph.nodes["node1"], "edges") + assert len(graph.nodes["node1"].edges) == 2 + assert graph.nodes["node1"].edges[0].priority == 10 + assert graph.nodes["node1"].edges[1].priority == 0 + + def test_add_edge_with_weight(self): + """Test weighted random edge.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + + graph.add_edge( + "node1", "node2", condition=lambda s: True, priority=1, weight=0.5 + ) + + assert len(graph.nodes["node1"].edges) == 1 + assert graph.nodes["node1"].edges[0].weight == 0.5 + + def test_add_edge_error_source_not_found(self): + """Test error when source node not found.""" + graph = 
GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node2", agent=agent) + + with pytest.raises(ValueError, match="Source node node1 not found"): + graph.add_edge("node1", "node2") + + def test_add_edge_error_target_not_found(self): + """Test error when target node not found.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + + with pytest.raises(ValueError, match="Target node node2 not found"): + graph.add_edge("node1", "node2") + + def test_add_edge_chaining(self): + """Test that add_edge returns self for chaining.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + graph.add_node("node3", agent=agent) + + result = graph.add_edge("node1", "node2").add_edge("node2", "node3") + + assert result is graph + + def test_add_edge_mixed_simple_and_priority(self): + """Test mixing simple edges and priority edges on same node.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", agent=agent) + graph.add_node("node3", agent=agent) + + # Add simple edge first + graph.add_edge("node1", "node2") + + # Add priority edge + graph.add_edge("node1", "node3", priority=10) + + # Both edges should be in edges list + assert len(graph.nodes["node1"].edges) == 2 + # First is simple edge (added via add_edge with no priority) + assert graph.nodes["node1"].edges[0].target_node == "node2" + # Second is priority edge + assert graph.nodes["node1"].edges[1].target_node == "node3" + assert graph.nodes["node1"].edges[1].priority == 10 + + +class TestConvenienceAPIIntegration: + """Integration tests using both convenience patterns together.""" + + def test_full_graph_with_convenience_api(self): + """Test building complete graph using convenience API.""" + graph = GraphAgent(name="validation_pipeline") 
+ + # Create agents + validator = SimpleAgent(name="validator") + processor = SimpleAgent(name="processor") + error_handler = SimpleAgent(name="error_handler") + + # Build graph using convenience API + ( + graph.add_node("validate", agent=validator) + .add_node("process", agent=processor) + .add_node("error", agent=error_handler) + .add_edge( + "validate", "process", condition=lambda s: s.data.get("valid") + ) + .add_edge( + "validate", "error", condition=lambda s: not s.data.get("valid") + ) + .set_start("validate") + .set_end("process") + .set_end("error") + ) + + assert len(graph.nodes) == 3 + assert graph.start_node == "validate" + assert set(graph.end_nodes) == {"process", "error"} + + def test_priority_routing_with_convenience_api(self): + """Test priority routing using convenience API.""" + graph = GraphAgent(name="router") + agent = SimpleAgent(name="test_agent") + + ( + graph.add_node("check", agent=agent) + .add_node("critical", agent=agent) + .add_node("warning", agent=agent) + .add_node("normal", agent=agent) + .add_edge( + "check", + "critical", + condition=lambda s: s.data.get("score", 0) > 0.9, + priority=10, + ) + .add_edge( + "check", + "warning", + condition=lambda s: s.data.get("score", 0) > 0.5, + priority=5, + ) + .add_edge("check", "normal", priority=0) # Fallback + .set_start("check") + ) + + # Verify priority edges created + assert len(graph.nodes["check"].edges) == 3 + priorities = [e.priority for e in graph.nodes["check"].edges] + assert priorities == [10, 5, 0] + + +class TestAddEdgeWithEdgeCondition: + """Test add_edge() with EdgeCondition objects (Pattern 1: Explicit).""" + + def test_add_edge_with_edge_condition(self): + """Test add_edge with EdgeCondition object.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("source", agent=agent) + graph.add_node("target", agent=agent) + + edge = EdgeCondition( + target_node="target", + condition=lambda s: s.data.get("valid"), + priority=10, + 
weight=0.5, + ) + graph.add_edge("source", edge) + + assert len(graph.nodes["source"].edges) == 1 + assert graph.nodes["source"].edges[0] is edge + assert graph.nodes["source"].edges[0].priority == 10 + assert graph.nodes["source"].edges[0].weight == 0.5 + + def test_add_edge_error_edge_condition_with_params(self): + """Test error when passing EdgeCondition with extra params.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("source", agent=agent) + graph.add_node("target", agent=agent) + + edge = EdgeCondition(target_node="target", priority=10) + + with pytest.raises(ValueError, match="do not specify condition"): + graph.add_edge("source", edge, condition=lambda s: True) + + with pytest.raises(ValueError, match="do not specify"): + graph.add_edge("source", edge, priority=5) + + with pytest.raises(ValueError, match="do not specify"): + graph.add_edge("source", edge, weight=0.5) + + def test_add_edge_edge_condition_target_not_found(self): + """Test error when EdgeCondition references non-existent target.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("source", agent=agent) + + edge = EdgeCondition(target_node="nonexistent", priority=10) + + with pytest.raises(ValueError, match="Target node nonexistent not found"): + graph.add_edge("source", edge) + + def test_add_edge_invalid_type(self): + """Test error when passing invalid target type.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("source", agent=agent) + + with pytest.raises( + TypeError, match="target_node must be str or EdgeCondition" + ): + graph.add_edge("source", 123) # Invalid type + + def test_add_edge_chaining_with_edge_condition(self): + """Test that add_edge returns self for chaining when using EdgeCondition.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("node1", agent=agent) + graph.add_node("node2", 
agent=agent) + graph.add_node("node3", agent=agent) + + result = graph.add_edge( + "node1", EdgeCondition(target_node="node2", priority=10) + ).add_edge("node2", "node3") + + assert result is graph + assert len(graph.nodes["node1"].edges) == 1 + assert len(graph.nodes["node2"].edges) == 1 + + def test_add_edge_mixed_edge_condition_and_convenience(self): + """Test mixing EdgeCondition and convenience patterns on same node.""" + graph = GraphAgent(name="test") + agent = SimpleAgent(name="test_agent") + + graph.add_node("source", agent=agent) + graph.add_node("target1", agent=agent) + graph.add_node("target2", agent=agent) + + # Add EdgeCondition first + graph.add_edge( + "source", + EdgeCondition( + target_node="target1", + condition=lambda s: s.data.get("score", 0) > 0.9, + priority=10, + ), + ) + + # Add convenience edge + graph.add_edge("source", "target2", priority=5) + + # Both edges should be in edges list + assert len(graph.nodes["source"].edges) == 2 + assert graph.nodes["source"].edges[0].target_node == "target1" + assert graph.nodes["source"].edges[0].priority == 10 + assert graph.nodes["source"].edges[1].target_node == "target2" + assert graph.nodes["source"].edges[1].priority == 5 diff --git a/tests/unittests/agents/test_graph_evaluation.py b/tests/unittests/agents/test_graph_evaluation.py new file mode 100644 index 0000000000..b5c2180f83 --- /dev/null +++ b/tests/unittests/agents/test_graph_evaluation.py @@ -0,0 +1,263 @@ +"""Tests for GraphAgent evaluation metrics.""" + +from types import SimpleNamespace + +from google.adk.agents.graph.evaluation_metrics import graph_path_match +from google.adk.agents.graph.evaluation_metrics import node_execution_count +from google.adk.agents.graph.evaluation_metrics import state_contains_keys +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalStatus +from google.genai import types +import pytest + + +@pytest.mark.asyncio +async def test_graph_path_match_exact(): 
+ """Test graph_path_match metric with exact path match.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="response")]), + ) + + # Create metric with custom attributes (SimpleNamespace for testing) + # NOTE: In production, actual_graph_path would come from intermediate_data + metric = SimpleNamespace( + metric_name="graph_path", + expected_graph_path=["n1", "n2", "n3"], + actual_graph_path=["n1", "n2", "n3"], # Exact match + ) + + # Evaluate + result = graph_path_match(metric, [invocation], None, None) + + # Should pass with perfect score + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED + assert len(result.per_invocation_results) == 1 + assert result.per_invocation_results[0].score == 1.0 + + +@pytest.mark.asyncio +async def test_graph_path_match_partial(): + """Test graph_path_match with partial path match.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="response")]), + ) + + metric = SimpleNamespace( + metric_name="graph_path", + expected_graph_path=["n1", "n3", "n4"], + actual_graph_path=["n1", "n2"], # Partial match + ) + + result = graph_path_match(metric, [invocation], None, None) + + # Should have partial score (1 match out of 3 expected) + assert result.overall_score < 1.0 + assert result.overall_score > 0.0 # At least n1 matches + + +@pytest.mark.asyncio +async def test_state_contains_keys_exact(): + """Test state_contains_keys metric with exact match.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="done")]), + ) + + metric = SimpleNamespace( + metric_name="state_check", + expected_state={"key1": "value1", "key2": 42}, + actual_state={"key1": "value1", "key2": 42}, # Exact match + ) + + result = state_contains_keys(metric, 
[invocation], None, None) + + # Should pass with perfect score + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED + + +@pytest.mark.asyncio +async def test_state_contains_keys_partial(): + """Test state_contains_keys with partial match.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="done")]), + ) + + metric = SimpleNamespace( + metric_name="state_check", + expected_state={"key1": "value1", "key2": 42}, + actual_state={"key1": "value1", "key2": 999}, # key2 wrong + ) + + result = state_contains_keys(metric, [invocation], None, None) + + # Should have partial score (1 out of 2 keys match) + assert result.overall_score == 0.5 + assert result.overall_eval_status == EvalStatus.FAILED + + +@pytest.mark.asyncio +async def test_node_execution_count_exact(): + """Test node_execution_count with exact counts.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="done")]), + ) + + metric = SimpleNamespace( + metric_name="execution_count", + expected_node_counts={"loop_node": 3}, + actual_node_counts={"loop_node": 3}, # Exact match + ) + + result = node_execution_count(metric, [invocation], None, None) + + # Should pass if count matches + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED + + +@pytest.mark.asyncio +async def test_metrics_with_no_expected_data(): + """Test metrics skip when no expected data provided.""" + invocation = Invocation( + userContent=types.Content(parts=[types.Part(text="test")]), + finalResponse=types.Content(parts=[types.Part(text="done")]), + ) + + metric = SimpleNamespace(metric_name="test") # No custom fields + + # All metrics should return NOT_EVALUATED when no expected data + result1 = graph_path_match(metric, [invocation], None, None) + assert result1.overall_eval_status == 
EvalStatus.NOT_EVALUATED + + result2 = state_contains_keys(metric, [invocation], None, None) + assert result2.overall_eval_status == EvalStatus.NOT_EVALUATED + + result3 = node_execution_count(metric, [invocation], None, None) + assert result3.overall_eval_status == EvalStatus.NOT_EVALUATED + + +# --------------------------------------------------------------------------- +# InvocationEvents-based paths: exception handlers and None-result branches +# --------------------------------------------------------------------------- + + +def _make_invocation_with_event(text: str) -> "Invocation": + """Helper: Invocation whose intermediate_data carries a single text event.""" + from google.adk.evaluation.eval_case import InvocationEvent + from google.adk.evaluation.eval_case import InvocationEvents + + evt = InvocationEvent( + author="graph", + content=types.Content(parts=[types.Part(text=text)]), + ) + return Invocation( + userContent=types.Content(parts=[types.Part(text="q")]), + finalResponse=types.Content(parts=[types.Part(text="a")]), + intermediateData=InvocationEvents(invocationEvents=[evt]), + ) + + +def test_graph_path_match_malformed_metadata_exception_handled(): + """Lines 92-93: ast.literal_eval fails on malformed metadata → continue. + + The except block swallows the error and leaves actual_path as None, + ultimately producing FAILED status (expected path set, no actual path found). 
+ """ + metric = SimpleNamespace( + metric_name="path", + expected_graph_path=["n1", "n2"], + # No actual_graph_path shortcut → will try InvocationEvents path + ) + inv = _make_invocation_with_event( + "[GraphMetadata] {this is: not valid python}" + ) + + result = graph_path_match(metric, [inv], None, None) + + # Parsing failed → actual_path stays None → FAILED + assert result.overall_eval_status == EvalStatus.FAILED + assert result.overall_score == 0.0 + + +def test_graph_path_match_actual_path_none_from_events(): + """Lines 106-107: expected_path set but actual_path is None → FAILED. + + Valid [GraphMetadata] event but the dict has no 'graph_path' key. + """ + metric = SimpleNamespace( + metric_name="path", + expected_graph_path=["n1", "n2"], + ) + # Valid Python dict in [GraphMetadata] but no 'graph_path' key + inv = _make_invocation_with_event("[GraphMetadata] {'graph_state': {'x': 1}}") + + result = graph_path_match(metric, [inv], None, None) + + assert result.overall_eval_status == EvalStatus.FAILED + assert result.overall_score == 0.0 + + +def test_state_contains_keys_actual_state_none_from_events(): + """Lines 217-218: expected_state set but actual_state is None → FAILED. + + Valid [GraphMetadata] event whose dict has no 'graph_state' key. 
+ """ + metric = SimpleNamespace( + metric_name="state", + expected_state={"key1": "v1"}, + ) + inv = _make_invocation_with_event("[GraphMetadata] {'graph_path': ['n1']}") + + result = state_contains_keys(metric, [inv], None, None) + + assert result.overall_eval_status == EvalStatus.FAILED + assert result.overall_score == 0.0 + + +def test_state_contains_keys_malformed_metadata_exception_handled(): + """Lines 205-206: malformed [GraphMetadata] in state metric → continue.""" + metric = SimpleNamespace( + metric_name="state", + expected_state={"key1": "v1"}, + ) + inv = _make_invocation_with_event("[GraphMetadata] << invalid >>") + + result = state_contains_keys(metric, [inv], None, None) + + # Parsing error → actual_state stays None → FAILED + assert result.overall_eval_status == EvalStatus.FAILED + + +def test_node_execution_count_empty_actual_counts(): + """Lines 327-328: expected_counts set but no actual counts found → FAILED.""" + metric = SimpleNamespace( + metric_name="count", + expected_node_counts={"loop_node": 3}, + ) + # [GraphMetadata] present but no 'node_invocations' key + inv = _make_invocation_with_event("[GraphMetadata] {'graph_path': ['n1']}") + + result = node_execution_count(metric, [inv], None, None) + + assert result.overall_eval_status == EvalStatus.FAILED + assert result.overall_score == 0.0 + + +def test_node_execution_count_malformed_metadata_exception_handled(): + """Lines 317-318: malformed [GraphMetadata] in count metric → continue.""" + metric = SimpleNamespace( + metric_name="count", + expected_node_counts={"loop_node": 3}, + ) + inv = _make_invocation_with_event("[GraphMetadata] *** bad ***") + + result = node_execution_count(metric, [inv], None, None) + + # Exception swallowed → actual_counts empty → FAILED + assert result.overall_eval_status == EvalStatus.FAILED diff --git a/tests/unittests/agents/test_graph_evaluation_integration.py b/tests/unittests/agents/test_graph_evaluation_integration.py new file mode 100644 index 
0000000000..c2cd37281b --- /dev/null +++ b/tests/unittests/agents/test_graph_evaluation_integration.py @@ -0,0 +1,371 @@ +"""Integration tests for GraphAgent evaluation with intermediate_data extraction.""" + +from types import SimpleNamespace + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph.evaluation_metrics import graph_path_match +from google.adk.agents.graph.evaluation_metrics import node_execution_count +from google.adk.agents.graph.evaluation_metrics import state_contains_keys +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_case import InvocationEvent +from google.adk.evaluation.eval_case import InvocationEvents +from google.adk.evaluation.eval_metrics import EvalStatus +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types +import pytest + + +class SimpleAgent(BaseAgent): + """Simple test agent.""" + + def __init__(self, name: str, output: str): + super().__init__(name=name) + self._output = output + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._output)]), + ) + + +class StatefulAgent(BaseAgent): + """Agent that produces output to be mapped to state.""" + + def __init__(self, name: str, state_updates: dict): + super().__init__(name=name) + self._state_updates = state_updates + + async def _run_async_impl(self, ctx): + # Yield state updates as JSON string in event + import json + + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=json.dumps(self._state_updates))] + ), + ) + + +@pytest.mark.asyncio +async def test_graph_path_extraction_from_intermediate_data(): + """Test that graph_path is extracted from 
intermediate_data during actual GraphAgent execution.""" + # Build simple graph + graph = GraphAgent(name="test_graph") + agent1 = SimpleAgent(name="agent1", output="a1") + agent2 = SimpleAgent(name="agent2", output="a2") + agent3 = SimpleAgent(name="agent3", output="a3") + + graph.add_node(GraphNode(name="n1", agent=agent1)) + graph.add_node(GraphNode(name="n2", agent=agent2)) + graph.add_node(GraphNode(name="n3", agent=agent3)) + + graph.add_edge("n1", "n2") + graph.add_edge("n2", "n3") + graph.set_start("n1") + graph.set_end("n3") + + # Execute graph and collect events + session_service = InMemorySessionService() + runner = Runner( + app_name="test", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + events = [] + async for event in runner.run_async( + user_id="u1", + session_id="s1", + new_message=types.Content(parts=[types.Part(text="test")]), + ): + events.append(event) + + # Extract intermediate events (non-user, non-final events with content) + # This mimics what EvaluationGenerator does + intermediate_events = [] + final_response = None + user_content = None + + for event in events: + if event.author and event.author.lower() == "user": + user_content = event.content + continue + + if event.is_final_response(): + final_response = event.content + elif event.content and event.content.parts: + # Check if event has function_call, function_response, or text + for part in event.content.parts: + if part.function_call or part.function_response or part.text: + intermediate_events.append( + InvocationEvent(author=event.author, content=event.content) + ) + break + + # Create Invocation with intermediate_data + invocation = Invocation( + userContent=user_content + or types.Content(parts=[types.Part(text="test")]), + finalResponse=final_response + or types.Content(parts=[types.Part(text="done")]), + intermediateData=InvocationEvents(invocation_events=intermediate_events), + ) + + # Create metric with expected path + metric = 
SimpleNamespace( + metric_name="graph_path", + expected_graph_path=["n1", "n2", "n3"], + # No actual_graph_path - should extract from intermediate_data + ) + + # Evaluate + result = graph_path_match(metric, [invocation], None, None) + + # Should pass with perfect score if extraction works + assert ( + result.overall_score == 1.0 + ), f"Score: {result.overall_score}, Expected: 1.0" + assert result.overall_eval_status == EvalStatus.PASSED + assert len(result.per_invocation_results) == 1 + assert result.per_invocation_results[0].score == 1.0 + + +@pytest.mark.asyncio +async def test_node_execution_count_extraction_from_intermediate_data(): + """Test that node execution counts are extracted from intermediate_data.""" + graph = GraphAgent(name="test_graph", max_iterations=5) + + agent = SimpleAgent(name="agent", output="out") + graph.add_node(GraphNode(name="n1", agent=agent)) + graph.add_node(GraphNode(name="n2", agent=agent)) + + graph.set_start("n1") + graph.add_edge("n1", "n2") + graph.add_edge( + "n2", "n1", condition=lambda s: s.data.get("_graph_iteration", 0) < 2 + ) + graph.set_end("n1") + graph.set_end("n2") # n2 can also be an end node when loop exits + + # Execute graph and collect events + session_service = InMemorySessionService() + runner = Runner( + app_name="test", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + events = [] + async for event in runner.run_async( + user_id="u1", + session_id="s1", + new_message=types.Content(parts=[types.Part(text="test")]), + ): + events.append(event) + + # Extract intermediate events + intermediate_events = [] + final_response = None + user_content = None + + for event in events: + if event.author and event.author.lower() == "user": + user_content = event.content + continue + + if event.is_final_response(): + final_response = event.content + elif event.content and event.content.parts: + for part in event.content.parts: + if part.function_call or part.function_response or part.text: 
+ intermediate_events.append( + InvocationEvent(author=event.author, content=event.content) + ) + break + + # Create Invocation + invocation = Invocation( + userContent=user_content + or types.Content(parts=[types.Part(text="test")]), + finalResponse=final_response + or types.Content(parts=[types.Part(text="done")]), + intermediateData=InvocationEvents(invocation_events=intermediate_events), + ) + + # Create metric - n1 and n2 should each execute at least once + # The exact counts depend on the loop logic + metric = SimpleNamespace( + metric_name="execution_count", + expected_node_counts={"n1": 2, "n2": 1}, + # No actual_node_counts - should extract from intermediate_data + ) + + # Evaluate + result = node_execution_count(metric, [invocation], None, None) + + # Should extract counts from intermediate_data + # Check that we got some score (extraction worked) + assert ( + result.overall_score >= 0.0 + ), "Should have extracted counts from intermediate_data" + assert len(result.per_invocation_results) == 1 + + +@pytest.mark.asyncio +async def test_graph_metadata_event_format(): + """Test that GraphAgent emits metadata events in the expected format.""" + graph = GraphAgent(name="test_graph") + agent = SimpleAgent(name="agent", output="test") + + graph.add_node(GraphNode(name="n1", agent=agent)) + graph.set_start("n1") + graph.set_end("n1") + + # Execute graph + session_service = InMemorySessionService() + runner = Runner( + app_name="test", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + metadata_events = [] + async for event in runner.run_async( + user_id="u1", + session_id="s1", + new_message=types.Content(parts=[types.Part(text="test")]), + ): + # Find metadata events + if event.author and "#metadata" in event.author: + metadata_events.append(event) + + # Should have at least one metadata event + assert len(metadata_events) > 0, "GraphAgent should emit metadata events" + + # Check metadata event format + metadata_event = 
metadata_events[0] + assert metadata_event.content is not None + assert len(metadata_event.content.parts) > 0 + + text = metadata_event.content.parts[0].text + assert ( + "[GraphMetadata]" in text + ), "Metadata event should have [GraphMetadata] marker" + + # Extract and parse metadata + import ast + + metadata_str = text.split("[GraphMetadata]", 1)[1].strip() + metadata = ast.literal_eval(metadata_str) + + # Verify expected fields + assert "graph_node" in metadata + assert "graph_iteration" in metadata + assert "graph_path" in metadata + assert "node_invocations" in metadata + assert "graph_state" in metadata + + assert isinstance(metadata["graph_path"], list) + assert isinstance(metadata["node_invocations"], dict) + assert isinstance(metadata["graph_state"], dict) + + +@pytest.mark.asyncio +async def test_state_extraction_from_intermediate_data(): + """Test that graph_state is extracted from intermediate_data for state_contains_keys metric.""" + # Build graph with stateful agents + graph = GraphAgent(name="test_graph") + agent1 = StatefulAgent( + name="agent1", state_updates={"count": 1, "status": "processing"} + ) + agent2 = StatefulAgent( + name="agent2", state_updates={"count": 2, "status": "done"} + ) + + # Define output mapper to parse JSON and update state + def state_mapper(output: str, state: GraphState) -> GraphState: + import json + + updates = json.loads(output) + return GraphState(data={**state.data, **updates}) + + graph.add_node(GraphNode(name="n1", agent=agent1, output_mapper=state_mapper)) + graph.add_node(GraphNode(name="n2", agent=agent2, output_mapper=state_mapper)) + + graph.add_edge("n1", "n2") + graph.set_start("n1") + graph.set_end("n2") + + # Execute graph and collect events + session_service = InMemorySessionService() + runner = Runner( + app_name="test", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + events = [] + async for event in runner.run_async( + user_id="u1", + session_id="s1", + 
new_message=types.Content(parts=[types.Part(text="test")]), + ): + events.append(event) + + # Extract intermediate events + intermediate_events = [] + final_response = None + user_content = None + + for event in events: + if event.author and event.author.lower() == "user": + user_content = event.content + continue + + if event.is_final_response(): + final_response = event.content + elif event.content and event.content.parts: + for part in event.content.parts: + if part.function_call or part.function_response or part.text: + intermediate_events.append( + InvocationEvent(author=event.author, content=event.content) + ) + break + + # Create Invocation + invocation = Invocation( + userContent=user_content + or types.Content(parts=[types.Part(text="test")]), + finalResponse=final_response + or types.Content(parts=[types.Part(text="done")]), + intermediateData=InvocationEvents(invocation_events=intermediate_events), + ) + + # Create metric - expect final state after both agents run + metric = SimpleNamespace( + metric_name="state_check", + expected_state={"count": 2, "status": "done"}, + # No actual_state - should extract from intermediate_data + ) + + # Evaluate + result = state_contains_keys(metric, [invocation], None, None) + + # Should pass with perfect score if extraction works + assert ( + result.overall_score == 1.0 + ), f"Score: {result.overall_score}, Expected: 1.0" + assert result.overall_eval_status == EvalStatus.PASSED + assert len(result.per_invocation_results) == 1 + assert result.per_invocation_results[0].score == 1.0 diff --git a/tests/unittests/agents/test_graph_resumability.py b/tests/unittests/agents/test_graph_resumability.py new file mode 100644 index 0000000000..c073a1b660 --- /dev/null +++ b/tests/unittests/agents/test_graph_resumability.py @@ -0,0 +1,681 @@ +"""Tests for GraphAgent ADK resumability integration. 
+ +Verifies that GraphAgent properly integrates with ADK's built-in +resumability pattern: ctx.is_resumable guards, ctx.should_pause_invocation, +resume from saved state, and end_of_agent lifecycle. +""" + +from __future__ import annotations + +from typing import AsyncGenerator +from unittest.mock import patch + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph.graph_agent import GraphAgent +from google.adk.agents.graph.graph_agent_state import GraphAgentState +from google.adk.agents.graph.graph_node import GraphNode +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.run_config import RunConfig +from google.adk.apps import ResumabilityConfig +from google.adk.events.event import Event +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + + +# ============================================================================ +# Test Agents +# ============================================================================ + + +class SimpleTestAgent(BaseAgent): + """Test agent that yields predetermined responses.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, responses: list[str]): + super().__init__(name=name) + object.__setattr__(self, "_responses", responses) + object.__setattr__(self, "_call_count", 0) + + async def _run_async_impl(self, ctx) -> AsyncGenerator[Event, None]: + call_count = object.__getattribute__(self, "_call_count") + responses = object.__getattribute__(self, "_responses") + response = responses[min(call_count, len(responses) - 1)] + object.__setattr__(self, "_call_count", call_count + 1) + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=response)]), + ) + + @property + def call_count(self): + return object.__getattribute__(self, "_call_count") + + +class 
PausingAgent(BaseAgent): + """Agent that yields a long-running tool event to trigger pause.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str): + super().__init__(name=name) + + async def _run_async_impl(self, ctx) -> AsyncGenerator[Event, None]: + # Yield an event with long_running_tool_ids to trigger pause + fc = types.FunctionCall( + id="tool_call_1", + name="long_running_tool", + args={}, + ) + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(function_call=fc)]), + long_running_tool_ids=["tool_call_1"], + ) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _build_linear_graph( + name: str, + node_agents: list[BaseAgent], + node_names: list[str] | None = None, +) -> GraphAgent: + """Build a linear graph: node0 -> node1 -> ... -> nodeN.""" + if node_names is None: + node_names = [f"n{i}" for i in range(len(node_agents))] + graph = GraphAgent(name=name) + for nname, agent in zip(node_names, node_agents): + graph.add_node(GraphNode(name=nname, agent=agent)) + for i in range(len(node_names) - 1): + graph.add_edge(node_names[i], node_names[i + 1]) + graph.set_start(node_names[0]) + graph.set_end(node_names[-1]) + return graph + + +def _make_ctx( + agent: BaseAgent, + *, + resumable: bool = False, + agent_states: dict | None = None, +) -> InvocationContext: + """Create InvocationContext for testing.""" + svc = InMemorySessionService() + session = Session(id="test-session", appName="test", userId="test-user") + ctx = InvocationContext( + session=session, + session_service=svc, + invocation_id="inv-1", + agent=agent, + user_content=types.Content( + role="user", parts=[types.Part(text="test")] + ), + ) + if resumable: + ctx.resumability_config = ResumabilityConfig(is_resumable=True) + ctx.run_config = RunConfig() + if agent_states: + ctx.agent_states = 
agent_states + return ctx + + +async def _collect_events(graph: GraphAgent, ctx: InvocationContext) -> list[Event]: + """Collect all events from graph execution.""" + events = [] + async for event in graph._run_async_impl(ctx): + events.append(event) + return events + + +# ============================================================================ +# Tests: Fix 1 — Resume from saved node +# ============================================================================ + + +@pytest.mark.asyncio +class TestResumeFromSavedNode: + """Verify graph resumes from agent_state.current_node, not start_node.""" + + async def test_resume_from_saved_node(self): + """After pause at node B, resume starts from B (not A).""" + agent_a = SimpleTestAgent("a", ["output_a"]) + agent_b = SimpleTestAgent("b", ["output_b"]) + agent_c = SimpleTestAgent("c", ["output_c"]) + graph = _build_linear_graph("g", [agent_a, agent_b, agent_c], ["nA", "nB", "nC"]) + + # Simulate resumed context: agent_state says we're at nB, iteration 1 + saved_state = GraphAgentState(current_node="nB", iteration=1, path=["nA"]) + ctx = _make_ctx( + graph, + resumable=True, + agent_states={"g": saved_state.model_dump(mode="json")}, + ) + + events = await _collect_events(graph, ctx) + + # Agent A should NOT have been called (we resumed past it) + assert agent_a.call_count == 0 + # Agents B and C should have been called + assert agent_b.call_count == 1 + assert agent_c.call_count == 1 + + async def test_resume_with_removed_node_restarts(self): + """If saved node no longer exists, restart from start_node.""" + agent_a = SimpleTestAgent("a", ["output_a"]) + agent_b = SimpleTestAgent("b", ["output_b"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + + # Saved state references a node that doesn't exist + saved_state = GraphAgentState(current_node="nX_removed", iteration=3) + ctx = _make_ctx( + graph, + resumable=True, + agent_states={"g": saved_state.model_dump(mode="json")}, + ) + + events = await 
_collect_events(graph, ctx) + + # Should restart from beginning + assert agent_a.call_count == 1 + assert agent_b.call_count == 1 + + +# ============================================================================ +# Tests: Fix 2 — Guard state events with is_resumable +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateEventGuards: + """Verify state events are only emitted when ctx.is_resumable=True.""" + + async def test_non_resumable_context_no_end_of_agent(self): + """When ctx.is_resumable=False, no end_of_agent event emitted. + + Per-iteration state events are always emitted (they serve rewind, + interrupts, telemetry — not just resumability). Only end_of_agent + is guarded by is_resumable. + """ + agent_a = SimpleTestAgent("a", ["out"]) + graph = _build_linear_graph("g", [agent_a], ["nA"]) + ctx = _make_ctx(graph, resumable=False) + + events = await _collect_events(graph, ctx) + + # end_of_agent should NOT be emitted for non-resumable + end_events = [ + e for e in events + if e.actions and e.actions.end_of_agent + ] + assert len(end_events) == 0 + + # But per-iteration state events ARE emitted (they serve other consumers) + state_events = [ + e for e in events + if e.actions and e.actions.agent_state is not None + ] + assert len(state_events) > 0 + + async def test_resumable_context_emits_state_events(self): + """When ctx.is_resumable=True, agent_state events are emitted.""" + agent_a = SimpleTestAgent("a", ["out"]) + agent_b = SimpleTestAgent("b", ["out"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + ctx = _make_ctx(graph, resumable=True) + + events = await _collect_events(graph, ctx) + + # Should have state events (at least end_of_agent) + state_events = [ + e for e in events + if e.actions and ( + e.actions.agent_state is not None or e.actions.end_of_agent + ) + ] + assert len(state_events) > 0 + + async def test_resume_skips_duplicate_state_event(self): + 
"""First iteration after resume doesn't emit duplicate state event.""" + agent_b = SimpleTestAgent("b", ["output_b"]) + agent_c = SimpleTestAgent("c", ["output_c"]) + graph = _build_linear_graph( + "g", + [SimpleTestAgent("a", ["x"]), agent_b, agent_c], + ["nA", "nB", "nC"], + ) + + # Resume at nB + saved_state = GraphAgentState(current_node="nB", iteration=1, path=["nA"]) + ctx = _make_ctx( + graph, + resumable=True, + agent_states={"g": saved_state.model_dump(mode="json")}, + ) + + events = await _collect_events(graph, ctx) + + # Count state events for graph "g" (not end_of_agent, just agent_state) + state_events = [ + e for e in events + if e.author == "g" + and e.actions + and e.actions.agent_state is not None + and not e.actions.end_of_agent + ] + # For a 2-node execution (B, C) with resume skipping first: + # Only nC should emit a state event (nB is skipped as resume iteration) + assert len(state_events) == 1 + + +# ============================================================================ +# Tests: Fix 3 — Pause on long-running tool +# ============================================================================ + + +@pytest.mark.asyncio +class TestPauseOnLongRunningTool: + """Verify should_pause_invocation triggers pause in graph execution.""" + + async def test_pause_on_long_running_tool(self): + """should_pause_invocation triggers, execution stops, state preserved.""" + agent_a = SimpleTestAgent("a", ["output_a"]) + pausing_agent = PausingAgent("pauser") + agent_c = SimpleTestAgent("c", ["output_c"]) + graph = _build_linear_graph( + "g", [agent_a, pausing_agent, agent_c], ["nA", "nB", "nC"] + ) + + ctx = _make_ctx(graph, resumable=True) + + events = await _collect_events(graph, ctx) + + # Agent A should have run + assert agent_a.call_count == 1 + # Agent C should NOT have run (paused at B) + assert agent_c.call_count == 0 + + # The pause event (long_running_tool_ids) should be in events + pause_events = [ + e for e in events if e.long_running_tool_ids + 
] + assert len(pause_events) == 1 + + # No end_of_agent should be emitted (we paused) + end_events = [ + e for e in events + if e.actions and e.actions.end_of_agent + ] + assert len(end_events) == 0 + + async def test_resume_after_pause_continues(self): + """Full pause->resume roundtrip: A runs, B pauses, resume->B runs, C runs.""" + agent_a = SimpleTestAgent("a", ["output_a"]) + pausing_agent = PausingAgent("pauser") + agent_b_resumed = SimpleTestAgent("b_resumed", ["output_b"]) + agent_c = SimpleTestAgent("c", ["output_c"]) + + # First run: A -> B(pauses) + graph1 = _build_linear_graph( + "g", [agent_a, pausing_agent, agent_c], ["nA", "nB", "nC"] + ) + ctx1 = _make_ctx(graph1, resumable=True) + events1 = await _collect_events(graph1, ctx1) + + # Verify paused at B + assert agent_a.call_count == 1 + assert agent_c.call_count == 0 + + # Build a new graph for resume where B won't pause anymore + agent_a2 = SimpleTestAgent("a2", ["output_a2"]) + agent_b2 = SimpleTestAgent("b2", ["output_b2"]) + agent_c2 = SimpleTestAgent("c2", ["output_c2"]) + graph2 = _build_linear_graph( + "g", [agent_a2, agent_b2, agent_c2], ["nA", "nB", "nC"] + ) + + # Resume from nB + saved_state = GraphAgentState(current_node="nB", iteration=1, path=["nA"]) + ctx2 = _make_ctx( + graph2, + resumable=True, + agent_states={"g": saved_state.model_dump(mode="json")}, + ) + events2 = await _collect_events(graph2, ctx2) + + # A should NOT run (resumed past it), B and C should run + assert agent_a2.call_count == 0 + assert agent_b2.call_count == 1 + assert agent_c2.call_count == 1 + + # end_of_agent should be emitted on completed run + end_events = [ + e for e in events2 + if e.actions and e.actions.end_of_agent + ] + assert len(end_events) == 1 + + +# ============================================================================ +# Tests: Fix 4 — end_of_agent guarded +# ============================================================================ + + +@pytest.mark.asyncio +class TestEndOfAgentGuard: + 
"""Verify end_of_agent lifecycle events.""" + + async def test_end_of_agent_guarded(self): + """end_of_agent only emitted when ctx.is_resumable=True.""" + agent_a = SimpleTestAgent("a", ["out"]) + graph = _build_linear_graph("g", [agent_a], ["nA"]) + + # Non-resumable: no end_of_agent + ctx_nr = _make_ctx(graph, resumable=False) + events_nr = await _collect_events(graph, ctx_nr) + end_nr = [e for e in events_nr if e.actions and e.actions.end_of_agent] + assert len(end_nr) == 0 + + # Resumable: has end_of_agent + agent_a2 = SimpleTestAgent("a2", ["out"]) + graph2 = _build_linear_graph("g", [agent_a2], ["nA"]) + ctx_r = _make_ctx(graph2, resumable=True) + events_r = await _collect_events(graph2, ctx_r) + end_r = [e for e in events_r if e.actions and e.actions.end_of_agent] + assert len(end_r) == 1 + + async def test_end_of_agent_skipped_on_pause(self): + """No end_of_agent when paused mid-graph.""" + agent_a = SimpleTestAgent("a", ["out"]) + pausing = PausingAgent("pauser") + graph = _build_linear_graph("g", [agent_a, pausing], ["nA", "nB"]) + ctx = _make_ctx(graph, resumable=True) + + events = await _collect_events(graph, ctx) + + end_events = [e for e in events if e.actions and e.actions.end_of_agent] + assert len(end_events) == 0 + + +# ============================================================================ +# Tests: Fix 5 — Cycle resets sub-agent states +# ============================================================================ + + +@pytest.mark.asyncio +class TestCycleReset: + """Verify reset_sub_agent_states on cycle revisit.""" + + async def test_cycle_resets_sub_agent_state(self): + """Back-edge to visited node calls reset_sub_agent_states.""" + # Create a graph with a cycle: nA -> nB -> nA (conditional) + agent_a = SimpleTestAgent("a", ["go", "done"]) + agent_b = SimpleTestAgent("b", ["loop_back", "final"]) + + graph = GraphAgent(name="g", max_iterations=5) + graph.add_node(GraphNode(name="nA", agent=agent_a)) + graph.add_node(GraphNode(name="nB", 
agent=agent_b)) + + graph.add_edge("nA", "nB") + # B -> A (back-edge, always taken except when max_iterations) + graph.add_edge("nB", "nA") + graph.set_start("nA") + # No end node — relies on max_iterations + + ctx = _make_ctx(graph, resumable=True) + + # Track reset calls via patch + reset_calls = [] + original_reset = ctx.reset_sub_agent_states + + def tracking_reset(agent_name: str): + reset_calls.append(agent_name) + return original_reset(agent_name) + + with patch.object( + type(ctx), "reset_sub_agent_states", side_effect=tracking_reset + ): + events = await _collect_events(graph, ctx) + + # Agent A should be visited at least twice (nA -> nB -> nA) + assert agent_a.call_count >= 2 + # reset_sub_agent_states should have been called for the agent + # when revisiting nA (the second visit) + assert len(reset_calls) > 0 + # The reset should be for agent "a" (the node agent) + assert "a" in reset_calls + + +# ============================================================================ +# Tests: _get_resume_state method +# ============================================================================ + + +class TestGetResumeState: + """Unit tests for _get_resume_state helper.""" + + def test_fresh_start(self): + """No saved state returns start_node.""" + graph = _build_linear_graph("g", [SimpleTestAgent("a", ["x"])], ["nA"]) + state = GraphAgentState() + node, iteration, resuming = graph._get_resume_state(state) + assert node == "nA" + assert iteration == 0 + assert resuming is False + + def test_resume_from_valid_node(self): + """Saved state with valid node returns that node.""" + agent_a = SimpleTestAgent("a", ["x"]) + agent_b = SimpleTestAgent("b", ["x"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + state = GraphAgentState(current_node="nB", iteration=3) + node, iteration, resuming = graph._get_resume_state(state) + assert node == "nB" + assert iteration == 3 + assert resuming is True + + def test_resume_from_invalid_node_falls_back(self): 
+ """Saved state with removed node falls back to start_node.""" + graph = _build_linear_graph("g", [SimpleTestAgent("a", ["x"])], ["nA"]) + state = GraphAgentState(current_node="nRemoved", iteration=5) + node, iteration, resuming = graph._get_resume_state(state) + assert node == "nA" + assert iteration == 0 + assert resuming is False + + +# ============================================================================ +# Tests: Integration — pause_invocation flag, state integrity, rewind compat +# ============================================================================ + + +@pytest.mark.asyncio +class TestPauseInvocationFlag: + """Verify pause_invocation=True/False is reflected in emitted events.""" + + async def test_pause_sets_flag_and_skips_final_events(self): + """When pause_invocation=True, no final response or end_of_agent.""" + agent_a = SimpleTestAgent("a", ["out"]) + pausing = PausingAgent("pauser") + agent_c = SimpleTestAgent("c", ["out"]) + graph = _build_linear_graph( + "g", [agent_a, pausing, agent_c], ["nA", "nB", "nC"] + ) + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # No final graph response (author=graph, state_delta with graph_data) + final_responses = [ + e for e in events + if e.author == "g" + and e.content + and e.actions + and e.actions.state_delta + and "graph_data" in (e.actions.state_delta or {}) + ] + assert len(final_responses) == 0 + + # No end_of_agent + end_events = [e for e in events if e.actions and e.actions.end_of_agent] + assert len(end_events) == 0 + + async def test_no_pause_emits_final_events(self): + """When pause_invocation=False (normal run), final response emitted.""" + agent_a = SimpleTestAgent("a", ["out"]) + graph = _build_linear_graph("g", [agent_a], ["nA"]) + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # Final graph response present + final_responses = [ + e for e in events + if e.author == "g" + and e.content + and e.actions + 
and e.actions.state_delta + and "graph_data" in (e.actions.state_delta or {}) + ] + assert len(final_responses) == 1 + + # end_of_agent present + end_events = [e for e in events if e.actions and e.actions.end_of_agent] + assert len(end_events) == 1 + + +@pytest.mark.asyncio +class TestStateIntegrity: + """Verify agent_state is consistent through pause/resume cycle.""" + + async def test_agent_state_tracks_current_node_on_pause(self): + """After pause, ctx.agent_states has current_node pointing to paused node.""" + agent_a = SimpleTestAgent("a", ["out"]) + pausing = PausingAgent("pauser") + graph = _build_linear_graph( + "g", [agent_a, pausing], ["nA", "nB"] + ) + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # The last agent_state event should have current_node = "nB" + state_events = [ + e for e in events + if e.author == "g" + and e.actions + and e.actions.agent_state is not None + and not e.actions.end_of_agent + ] + assert len(state_events) >= 1 + last_state = state_events[-1].actions.agent_state + assert last_state["current_node"] == "nB" + assert "nA" in last_state["path"] + + async def test_load_agent_state_roundtrip(self): + """State saved during run can be loaded back via _load_agent_state.""" + agent_a = SimpleTestAgent("a", ["out"]) + agent_b = SimpleTestAgent("b", ["out"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # Get the last state event's agent_state dict + state_events = [ + e for e in events + if e.author == "g" + and e.actions + and e.actions.agent_state is not None + and not e.actions.end_of_agent + ] + last_state_dict = state_events[-1].actions.agent_state + + # Simulate loading it back (as _load_agent_state does) + loaded = GraphAgentState.model_validate(last_state_dict) + assert loaded.current_node == "nB" + assert loaded.iteration == 2 + assert loaded.path == ["nA", "nB"] + + async def 
test_function_node_pause_is_falsy(self): + """Function nodes never set pause in output_holder — safe path.""" + def my_func(state, ctx): + return "func_output" + + graph = GraphAgent(name="g") + graph.add_node(GraphNode(name="nA", function=my_func)) + graph.set_start("nA") + graph.set_end("nA") + + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # Should complete normally with end_of_agent + end_events = [e for e in events if e.actions and e.actions.end_of_agent] + assert len(end_events) == 1 + + +@pytest.mark.asyncio +class TestRewindCompatibility: + """Verify rewind works with resumability state events.""" + + async def test_state_events_contain_node_invocations(self): + """State events include node_invocations needed by rewind_to_node.""" + agent_a = SimpleTestAgent("a", ["out"]) + agent_b = SimpleTestAgent("b", ["out"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["nA", "nB"]) + ctx = _make_ctx(graph, resumable=True) + events = await _collect_events(graph, ctx) + + # Find state events with node_invocations + state_events = [ + e for e in events + if e.author == "g" + and e.actions + and e.actions.agent_state is not None + and not e.actions.end_of_agent + ] + # Last state event should have node_invocations for both nodes + last_state = state_events[-1].actions.agent_state + assert "node_invocations" in last_state + assert "nA" in last_state["node_invocations"] + assert "nB" in last_state["node_invocations"] + + async def test_rewind_to_node_with_runner(self): + """Full integration: run graph via Runner, then rewind_to_node.""" + from google.adk.agents.graph import rewind_to_node + from google.adk.runners import Runner + + agent_a = SimpleTestAgent("step1", ["a_out"]) + agent_b = SimpleTestAgent("step2", ["b_out"]) + graph = _build_linear_graph("g", [agent_a, agent_b], ["step1", "step2"]) + + svc = InMemorySessionService() + runner = Runner(app_name="test", agent=graph, session_service=svc) + await 
svc.create_session(app_name="test", user_id="u", session_id="s") + + # Execute graph + events = [] + async for event in runner.run_async( + user_id="u", + session_id="s", + new_message=types.Content( + role="user", parts=[types.Part(text="go")] + ), + ): + events.append(event) + + # Both agents should have run + assert agent_a.call_count == 1 + assert agent_b.call_count == 1 + + # Rewind to step1 — should not raise + await rewind_to_node( + graph=graph, + session_service=svc, + app_name="test", + user_id="u", + session_id="s", + node_name="step1", + ) diff --git a/tests/unittests/agents/test_graph_rewind.py b/tests/unittests/agents/test_graph_rewind.py new file mode 100644 index 0000000000..c1b6dcc56d --- /dev/null +++ b/tests/unittests/agents/test_graph_rewind.py @@ -0,0 +1,607 @@ +"""Tests for GraphAgent + Rewind integration. + +Tests the tight coupling between GraphAgent and ADK's rewind feature, +enabling temporal navigation within graph workflows. +""" + +from typing import AsyncGenerator + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import rewind_to_node +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types +import pytest +from typing_extensions import override + +# Test Fixtures (proper BaseAgent implementations per ADK guidelines) + + +class SimpleAgent(BaseAgent): + """Simple test agent that returns predictable output.""" + + def __init__(self, name: str, output: str): + super().__init__(name=name) + self._test_output = output + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Return test output.""" + yield Event( + author=self.name, + 
content=types.Content(parts=[types.Part(text=self._test_output)]), + ) + + +class StatefulAgent(BaseAgent): + """Agent that modifies state.""" + + def __init__(self, name: str, state_key: str, state_value: str): + super().__init__(name=name) + self._state_key = state_key + self._state_value = state_value + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Modify state and return output.""" + ctx.session.state[self._state_key] = self._state_value + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part(text=f"Set {self._state_key}={self._state_value}") + ] + ), + ) + + +def _get_node_invocations_from_events(session) -> dict: + """Extract node_invocations from the latest agent_state event.""" + for event in reversed(session.events): + if ( + event.actions + and event.actions.agent_state + and "node_invocations" in (event.actions.agent_state or {}) + ): + return event.actions.agent_state["node_invocations"] + return {} + + +@pytest.fixture +def session_service(): + """Create in-memory session service.""" + return InMemorySessionService() + + +# Test 1: Basic Rewind to Node + + +@pytest.mark.asyncio +async def test_rewind_to_node_basic(session_service): + """Test rewinding to a specific node.""" + # Create graph with 3 nodes + graph = GraphAgent(name="test_graph") + graph.add_node( + GraphNode(name="step1", agent=SimpleAgent("agent1", "Output 1")) + ) + graph.add_node( + GraphNode(name="step2", agent=SimpleAgent("agent2", "Output 2")) + ) + graph.add_node( + GraphNode(name="step3", agent=SimpleAgent("agent3", "Output 3")) + ) + + graph.add_edge("step1", "step2") + graph.add_edge("step2", "step3") + graph.set_start("step1") + graph.set_end("step3") + + # Execute graph + runner = Runner( + app_name="test_app", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + events = [] + new_message = types.Content(parts=[types.Part(text="Start")]) + async for 
# Test 2: Rewind in Loop (Multiple Invocations)


@pytest.mark.asyncio
async def test_rewind_to_node_in_loop(session_service):
  """Test rewinding when node executed multiple times."""
  # counter -> validator, with a conditional back-edge that loops while the
  # graph iteration counter is below 3.
  graph = GraphAgent(name="test_graph", max_iterations=5)
  graph.add_node(
      GraphNode(
          name="counter",
          agent=StatefulAgent("counter_agent", "count", "incremented"),
      )
  )
  graph.add_node(
      GraphNode(
          name="validator",
          agent=SimpleAgent("validator_agent", "Validated"),
      )
  )
  graph.add_edge("counter", "validator")
  graph.add_edge(
      "validator",
      "counter",
      condition=lambda s: s.data.get("_graph_iteration", 0) < 3,
  )
  graph.set_start("counter")
  graph.set_end("validator")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  prompt = types.Content(parts=[types.Part(text="Start")])
  async for _ in runner.run_async(
      user_id="u1", session_id="s2", new_message=prompt
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s2"
  )

  # The loop should have produced multiple invocations of each node.
  invocations = _get_node_invocations_from_events(session)
  assert len(invocations.get("counter", [])) >= 2
  assert len(invocations.get("validator", [])) >= 2

  # Rewind to the 2nd invocation of counter - just verify rewind works with
  # a specific invocation index.
  if len(invocations["counter"]) >= 2:
    await rewind_to_node(
        graph,
        session_service,
        "test_app",
        "u1",
        "s2",
        "counter",
        invocation_index=1,
    )

  # Verify session still accessible.
  after = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s2"
  )
  assert after is not None


# Test 3: Rewind Restores State


@pytest.mark.asyncio
async def test_rewind_restores_state(session_service):
  """Test that rewind works with sequential nodes."""
  graph = GraphAgent(name="test_graph")
  for node_name, agent_name, text in (
      ("step_a", "agent_a", "Output A"),
      ("step_b", "agent_b", "Output B"),
      ("step_c", "agent_c", "Output C"),
  ):
    graph.add_node(
        GraphNode(name=node_name, agent=SimpleAgent(agent_name, text))
    )
  graph.add_edge("step_a", "step_b")
  graph.add_edge("step_b", "step_c")
  graph.set_start("step_a")
  graph.set_end("step_c")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s3",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  # Get session after full execution and verify all nodes executed.
  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s3"
  )
  invocations = _get_node_invocations_from_events(session)
  for node_name in ("step_a", "step_b", "step_c"):
    assert node_name in invocations

  # Rewind to step_b - verify it works without error.
  await rewind_to_node(graph, session_service, "test_app", "u1", "s3", "step_b")

  # Verify session still accessible after rewind.
  after = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s3"
  )
  assert after is not None
# Test 4: Rewind Invalid Node


@pytest.mark.asyncio
async def test_rewind_invalid_node(session_service):
  """Test rewind fails for non-executed node."""
  graph = GraphAgent(name="test_graph")
  for node_name, agent_name, text in (
      ("step1", "agent1", "Output 1"),
      ("step2", "agent2", "Output 2"),
  ):
    graph.add_node(
        GraphNode(name=node_name, agent=SimpleAgent(agent_name, text))
    )
  graph.add_edge("step1", "step2")
  graph.set_start("step1")
  graph.set_end("step2")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s4",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s4"
  )

  # A node name that was never executed must be rejected.
  with pytest.raises(ValueError, match="has not been executed yet"):
    await rewind_to_node(
        graph, session_service, "test_app", "u1", "s4", "nonexistent_node"
    )


# Test 5: Rewind Invalid Invocation Index


@pytest.mark.asyncio
async def test_rewind_invalid_invocation_index(session_service):
  """Test rewind fails for invalid invocation index."""
  graph = GraphAgent(name="test_graph")
  graph.add_node(
      GraphNode(name="step1", agent=SimpleAgent("agent1", "Output 1"))
  )
  graph.set_start("step1")
  graph.set_end("step1")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s5",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s5"
  )

  # step1 ran once, so index 10 is far out of range.
  with pytest.raises(ValueError, match="out of range"):
    await rewind_to_node(
        graph,
        session_service,
        "test_app",
        "u1",
        "s5",
        "step1",
        invocation_index=10,
    )


# Test 6: Rewind Preserves Path


@pytest.mark.asyncio
async def test_rewind_preserves_path(session_service):
  """Test execution path preserved correctly after rewind."""
  graph = GraphAgent(name="test_graph")
  for node_name, agent_name, text in (
      ("a", "agent_a", "A"),
      ("b", "agent_b", "B"),
      ("c", "agent_c", "C"),
  ):
    graph.add_node(
        GraphNode(name=node_name, agent=SimpleAgent(agent_name, text))
    )
  graph.add_edge("a", "b")
  graph.add_edge("b", "c")
  graph.set_start("a")
  graph.set_end("c")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s6",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s6"
  )

  # Path should be [a, b, c]
  # Note: Path tracking depends on implementation details

  # Rewind to b - verify it works.
  await rewind_to_node(graph, session_service, "test_app", "u1", "s6", "b")

  # Verify session still accessible.
  after = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s6"
  )
  assert after is not None
# Test 7: Rewind with Conditional Branching


@pytest.mark.asyncio
async def test_rewind_with_branching(session_service):
  """Test rewind works with conditional branches."""
  graph = GraphAgent(name="test_graph")
  graph.add_node(
      GraphNode(name="start", agent=SimpleAgent("start_agent", "Start"))
  )
  graph.add_node(GraphNode(name="branch_a", agent=SimpleAgent("a_agent", "A")))
  graph.add_node(GraphNode(name="branch_b", agent=SimpleAgent("b_agent", "B")))
  graph.add_node(GraphNode(name="end", agent=SimpleAgent("end_agent", "End")))

  # Simple branching - always take branch_a.
  graph.add_edge("start", "branch_a", condition=lambda s: True)
  graph.add_edge("start", "branch_b", condition=lambda s: False)
  graph.add_edge("branch_a", "end")
  graph.add_edge("branch_b", "end")
  graph.set_start("start")
  graph.set_end("end")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s7",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s7"
  )

  # Verify execution completed.
  invocations = _get_node_invocations_from_events(session)
  assert "start" in invocations

  # Rewind to start - verify rewind works with conditional branching.
  await rewind_to_node(graph, session_service, "test_app", "u1", "s7", "start")

  # Verify session still accessible.
  after = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s7"
  )
  assert after is not None


# Test 8: Rewind Negative Index


@pytest.mark.asyncio
async def test_rewind_negative_index(session_service):
  """Test rewind with negative invocation index (most recent)."""
  # Single self-looping node that repeats while the iteration counter < 2.
  graph = GraphAgent(name="test_graph", max_iterations=5)
  graph.add_node(
      GraphNode(
          name="repeater",
          agent=SimpleAgent("repeater_agent", "Repeated"),
      )
  )
  graph.add_edge(
      "repeater",
      "repeater",
      condition=lambda s: s.data.get("_graph_iteration", 0) < 2,
  )
  graph.set_start("repeater")
  graph.set_end("repeater")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s8",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s8"
  )
  invocations = _get_node_invocations_from_events(session)

  if len(invocations.get("repeater", [])) >= 2:
    # Rewind to last invocation (index -1) - test negative indexing.
    await rewind_to_node(
        graph,
        session_service,
        "test_app",
        "u1",
        "s8",
        "repeater",
        invocation_index=-1,
    )

  # Verify session still accessible.
  after = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s8"
  )
  assert after is not None


# Test 9: Rewind Integration with Session State


@pytest.mark.asyncio
async def test_rewind_session_state_integration(session_service):
  """Test rewind properly integrates with session state tracking."""
  graph = GraphAgent(name="test_graph")
  for node_name, agent_name, text in (
      ("step1", "agent1", "Output 1"),
      ("step2", "agent2", "Output 2"),
  ):
    graph.add_node(
        GraphNode(name=node_name, agent=SimpleAgent(agent_name, text))
    )
  graph.add_edge("step1", "step2")
  graph.set_start("step1")
  graph.set_end("step2")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )
  async for _ in runner.run_async(
      user_id="u1",
      session_id="s9",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    pass

  session = await session_service.get_session(
      app_name="test_app", user_id="u1", session_id="s9"
  )

  # Verify node_invocations in agent_state events.
  invocations = _get_node_invocations_from_events(session)
  assert isinstance(invocations, dict)
  assert len(invocations) > 0

  # Verify invocation IDs are tracked: each entry is a non-empty list of
  # string invocation IDs.
  for node_name, ids in invocations.items():
    assert isinstance(ids, list)
    assert len(ids) > 0
    assert all(isinstance(inv_id, str) for inv_id in ids)
auto_create_session=True, + ) + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="u1", session_id="s9", new_message=new_message + ): + pass + + session = await session_service.get_session( + app_name="test_app", user_id="u1", session_id="s9" + ) + + # Verify node_invocations in agent_state events + node_invocations = _get_node_invocations_from_events(session) + assert isinstance(node_invocations, dict) + assert len(node_invocations) > 0 + + # Verify invocation IDs are tracked + for node_name, invocations in node_invocations.items(): + assert isinstance(invocations, list) + assert len(invocations) > 0 + # Each invocation should be a string (invocation ID) + for inv_id in invocations: + assert isinstance(inv_id, str) + + +# Test 10: Rewind with Empty Graph State + + +@pytest.mark.asyncio +async def test_rewind_empty_node_invocations(session_service): + """Test rewind handles case with no invocations gracefully.""" + graph = GraphAgent(name="test_graph") + graph.add_node( + GraphNode(name="step1", agent=SimpleAgent("agent1", "Output 1")) + ) + graph.set_start("step1") + graph.set_end("step1") + + # Create session without executing + runner = Runner( + app_name="test_app", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + session = await session_service.create_session( + app_name="test_app", user_id="u1", session_id="s10" + ) + + # Try to rewind without any execution + with pytest.raises(ValueError, match="has not been executed yet"): + await rewind_to_node( + graph, session_service, "test_app", "u1", "s10", "step1" + ) diff --git a/tests/unittests/agents/test_graph_routing.py b/tests/unittests/agents/test_graph_routing.py new file mode 100644 index 0000000000..c3c7dad3e9 --- /dev/null +++ b/tests/unittests/agents/test_graph_routing.py @@ -0,0 +1,483 @@ +"""Tests for enhanced graph routing (priority, weight, fallback).""" + +from google.adk.agents.base_agent import 
BaseAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types +import pytest + + +class SimpleAgent(BaseAgent): + """Simple test agent that returns predictable output.""" + + def __init__(self, name: str, output: str): + super().__init__(name=name) + self._test_output = output + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._test_output)]), + ) + + +@pytest.fixture +async def session_service(): + """Create InMemorySessionService for tests.""" + return InMemorySessionService() + + +@pytest.mark.asyncio +async def test_priority_routing_basic(session_service): + """Test that higher priority edges are evaluated first.""" + graph = GraphAgent(name="test_graph") + + start = SimpleAgent(name="start", output="starting") + high_priority = SimpleAgent(name="high_priority", output="high_priority_path") + low_priority = SimpleAgent(name="low_priority", output="low_priority_path") + + graph.add_node(GraphNode(name="start", agent=start)) + graph.add_node(GraphNode(name="high_priority", agent=high_priority)) + graph.add_node(GraphNode(name="low_priority", agent=low_priority)) + + # Both edges have conditions that would match, but high priority should win + start_node = graph.nodes["start"] + start_node.edges = [ + EdgeCondition( + target_node="low_priority", + condition=lambda s: True, # Always matches + priority=1, # Lower priority + ), + EdgeCondition( + target_node="high_priority", + condition=lambda s: True, # Always matches + priority=10, # Higher priority - should be chosen + ), + ] + + graph.set_start("start") + graph.set_end("high_priority") + graph.set_end("low_priority") + + runner = 
@pytest.mark.asyncio
async def test_fallback_edge_priority_zero(session_service):
  """Test that priority=0 edges act as fallbacks."""
  graph = GraphAgent(name="test_graph")

  graph.add_node(
      GraphNode(name="start", agent=SimpleAgent(name="start", output="starting"))
  )
  graph.add_node(
      GraphNode(
          name="conditional",
          agent=SimpleAgent(name="conditional", output="conditional_path"),
      )
  )
  graph.add_node(
      GraphNode(
          name="fallback",
          agent=SimpleAgent(name="fallback", output="fallback_path"),
      )
  )

  # One conditional edge that won't match, plus a priority-0 fallback edge.
  graph.nodes["start"].edges = [
      EdgeCondition(
          target_node="conditional",
          condition=lambda s: s.data.get("trigger_condition", False),
          priority=5,
      ),
      EdgeCondition(
          target_node="fallback",
          priority=0,  # Fallback - always matches if no higher priority matched
      ),
  ]

  graph.set_start("start")
  graph.set_end("conditional")
  graph.set_end("fallback")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )

  emitted = [
      evt
      async for evt in runner.run_async(
          user_id="u1",
          session_id="s1",
          new_message=types.Content(parts=[types.Part(text="Start")]),
      )
  ]
  texts = [
      evt.content.parts[0].text
      for evt in emitted
      if evt.content and evt.content.parts
  ]

  # Should route to fallback (the conditional edge doesn't match).
  assert "fallback_path" in texts
  assert "conditional_path" not in texts


@pytest.mark.asyncio
async def test_weighted_routing(session_service):
  """Test weighted random selection among matching edges."""
  graph = GraphAgent(name="test_graph")

  graph.add_node(
      GraphNode(name="start", agent=SimpleAgent(name="start", output="starting"))
  )
  graph.add_node(
      GraphNode(name="path_a", agent=SimpleAgent(name="path_a", output="path_a"))
  )
  graph.add_node(
      GraphNode(name="path_b", agent=SimpleAgent(name="path_b", output="path_b"))
  )

  # Both edges match at equal priority; path_a carries most of the weight.
  graph.nodes["start"].edges = [
      EdgeCondition(
          target_node="path_a",
          condition=lambda s: True,
          priority=1,
          weight=0.9,  # 90% probability
      ),
      EdgeCondition(
          target_node="path_b",
          condition=lambda s: True,
          priority=1,  # Same priority
          weight=0.1,  # 10% probability
      ),
  ]

  graph.set_start("start")
  graph.set_end("path_a")
  graph.set_end("path_b")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )

  # Run multiple times to verify the weighted distribution.
  trials = 50
  path_a_count = 0
  path_b_count = 0
  for i in range(trials):
    emitted = [
        evt
        async for evt in runner.run_async(
            user_id="u1",
            session_id=f"s_{i}",
            new_message=types.Content(parts=[types.Part(text="Start")]),
        )
    ]
    texts = [
        evt.content.parts[0].text
        for evt in emitted
        if evt.content and evt.content.parts
    ]
    path_a_count += "path_a" in texts
    path_b_count += "path_b" in texts

  # With 0.9/0.1 weights, expect roughly 45/5 split (90%/10%).
  # Allow some variance (at least 35/50 = 70% for path_a).
  assert (
      path_a_count > trials * 0.7
  ), f"Expected path_a > 70%, got {path_a_count}/{trials}"
  assert (
      path_b_count < trials * 0.3
  ), f"Expected path_b < 30%, got {path_b_count}/{trials}"
@pytest.mark.asyncio
async def test_priority_and_condition(session_service):
  """Test that priority works correctly with conditions.

  The highest-priority edge does not match (score 0.5 < 0.8), so routing
  must fall through to the next-highest edge whose condition matches.
  """
  graph = GraphAgent(name="test_graph")

  start = SimpleAgent(name="start", output="starting")
  high_match = SimpleAgent(name="high_match", output="high_match_path")
  low_no_match = SimpleAgent(name="low_no_match", output="low_no_match_path")
  fallback = SimpleAgent(name="fallback", output="fallback_path")

  graph.add_node(GraphNode(name="start", agent=start))
  graph.add_node(GraphNode(name="high_match", agent=high_match))
  graph.add_node(GraphNode(name="low_no_match", agent=low_no_match))
  graph.add_node(GraphNode(name="fallback", agent=fallback))

  start_node = graph.nodes["start"]

  # Output mapper that sets score=0.5 in the state.
  def set_score(output, state):
    # Fix: the original called `.copy().copy()` - one copy suffices.
    new_state = GraphState(data=state.data.copy())
    new_state.data["score"] = 0.5
    return new_state

  start_node.output_mapper = set_score

  start_node.edges = [
      EdgeCondition(
          target_node="low_no_match",
          condition=lambda s: s.data.get("score", 0)
          > 0.8,  # Won't match (0.5 < 0.8)
          priority=20,  # Highest priority but won't match
      ),
      EdgeCondition(
          target_node="high_match",
          condition=lambda s: s.data.get("score", 0)
          > 0.3,  # Will match (0.5 > 0.3)
          priority=10,  # Medium priority and matches
      ),
      EdgeCondition(
          target_node="fallback",
          priority=0,  # Fallback
      ),
  ]

  graph.set_start("start")
  graph.set_end("high_match")
  graph.set_end("low_no_match")
  graph.set_end("fallback")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )

  events = []
  async for event in runner.run_async(
      user_id="u1",
      session_id="s1",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    events.append(event)

  event_texts = [
      e.content.parts[0].text for e in events if e.content and e.content.parts
  ]

  # Should route to high_match (highest priority that matches).
  assert "high_match_path" in event_texts
  assert "low_no_match_path" not in event_texts
  assert "fallback_path" not in event_texts


@pytest.mark.asyncio
async def test_backward_compatibility(session_service):
  """Test that existing code without priority/weight still works."""
  graph = GraphAgent(name="test_graph")

  start = SimpleAgent(name="start", output="starting")
  next_node = SimpleAgent(name="next", output="next_output")

  graph.add_node(GraphNode(name="start", agent=start))
  graph.add_node(GraphNode(name="next", agent=next_node))

  # Old style: just target_node and condition (no priority/weight).
  start_node = graph.nodes["start"]
  start_node.edges = [
      EdgeCondition(target_node="next", condition=lambda s: True),
  ]

  graph.set_start("start")
  graph.set_end("next")

  runner = Runner(
      app_name="test_app",
      agent=graph,
      session_service=session_service,
      auto_create_session=True,
  )

  events = []
  async for event in runner.run_async(
      user_id="u1",
      session_id="s1",
      new_message=types.Content(parts=[types.Part(text="Start")]),
  ):
    events.append(event)

  event_texts = [
      e.content.parts[0].text for e in events if e.content and e.content.parts
  ]

  # Should work exactly as before.
  assert "next_output" in event_texts


def test_edge_repr():
  """Test EdgeCondition string representation.

  Fix: this test was declared ``async`` with ``@pytest.mark.asyncio`` but
  awaits nothing; it is a plain synchronous unit test.
  """
  edge = EdgeCondition(
      target_node="test_target",
      condition=lambda s: True,
      priority=5,
      weight=0.7,
  )

  repr_str = repr(edge)
  assert "test_target" in repr_str
  assert "priority=5" in repr_str
  assert "weight=0.7" in repr_str
  assert "has_condition=True" in repr_str
# ---------------------------------------------------------------------------
# Edge-weight selection edge cases (graph_node.py lines 176, 187)
# ---------------------------------------------------------------------------


def test_weighted_edges_all_zero_weight():
  """Line 176: when matching edge weights sum to zero, pick the first one.

  (Docstring fix: the edges do NOT all have weight=0.0 - they use -0.5 and
  +0.5 so that all_same_weight is False while total_weight == 0.  The code
  then falls into the total_weight == 0 branch and picks matching_edges[0].)
  """
  from google.adk.agents.graph.graph_state import GraphState

  node = GraphNode(name="n", function=lambda s, c: "x")
  # Two edges at same priority that both match, but weights are asymmetric
  # around 0 so their sum is exactly zero.
  node.edges = [
      EdgeCondition(
          target_node="first", condition=lambda s: True, priority=1, weight=-0.5
      ),
      EdgeCondition(
          target_node="second", condition=lambda s: True, priority=1, weight=0.5
      ),
  ]

  state = GraphState()
  # With different weights but total == 0, should always pick "first".
  for _ in range(20):
    result = node.get_next_node(state)
    assert (
        result == "first"
    ), f"Expected 'first' (total_weight=0 branch), got '{result}'"


def test_weighted_edges_float_fallback():
  """Line 187: safety fallback when rand_value > cumulative after loop.

  Achieved by mocking random.random() to return 2.0 (> 1.0), making
  rand_value = 2.0 * total_weight > total_weight so no edge matches in
  the weighted-choice loop and we reach the fallback return.
  """
  from unittest.mock import patch

  from google.adk.agents.graph.graph_state import GraphState

  node = GraphNode(name="n", function=lambda s, c: "x")
  # Two edges at same priority, different weights (so weighted path is used).
  node.edges = [
      EdgeCondition(
          target_node="a", condition=lambda s: True, priority=1, weight=0.4
      ),
      EdgeCondition(
          target_node="b", condition=lambda s: True, priority=1, weight=0.6
      ),
  ]

  state = GraphState()
  # random.random() returns 2.0 -> rand_value = 2.0 * 1.0 = 2.0 > cumulative(1.0)
  # The for-loop completes without returning -> fallback at line 187 returns
  # the last edge.
  with patch("random.random", return_value=2.0):
    result = node.get_next_node(state)
  assert result == "b", f"Expected fallback to last edge 'b', got '{result}'"


class TestConditionalRouting:
  """Unit tests for EdgeCondition.should_route and GraphNode.get_next_node."""

  def test_unconditional_edge(self):
    """Edge without condition always routes."""
    edge = EdgeCondition(target_node="next")
    state = GraphState(data={})

    assert edge.should_route(state) is True

  def test_conditional_edge_true(self):
    """Conditional edge routes when condition is true."""
    edge = EdgeCondition(
        target_node="next", condition=lambda s: s.data.get("value") > 10
    )
    state = GraphState(data={"value": 15})

    assert edge.should_route(state) is True

  def test_conditional_edge_false(self):
    """Conditional edge does not route when condition is false."""
    edge = EdgeCondition(
        target_node="next", condition=lambda s: s.data.get("value") > 10
    )
    state = GraphState(data={"value": 5})

    assert edge.should_route(state) is False

  def test_node_routing_multiple_edges_picks_first_match(self):
    """Node with multiple conditional edges selects the first matching edge."""
    node = GraphNode(name="router", agent=SimpleAgent("agent", "ok"))

    node.add_edge("path1", condition=lambda s: s.data.get("type") == "A")
    node.add_edge("path2", condition=lambda s: s.data.get("type") == "B")
    node.add_edge("default", condition=lambda s: True)

    assert node.get_next_node(GraphState(data={"type": "A"})) == "path1"
    assert node.get_next_node(GraphState(data={"type": "B"})) == "path2"
    assert node.get_next_node(GraphState(data={"type": "C"})) == "default"
condition=lambda s: s.data.get("type") == "B") + node.add_edge("default", condition=lambda s: True) + + assert node.get_next_node(GraphState(data={"type": "A"})) == "path1" + assert node.get_next_node(GraphState(data={"type": "B"})) == "path2" + assert node.get_next_node(GraphState(data={"type": "C"})) == "default" + + +def test_edge_sort_cache_invalidation(): + """Edge sort cache is built once and invalidated on add_edge.""" + node = GraphNode( + name="test", + agent=SimpleAgent(name="agent", output="out"), + ) + + assert node._sorted_edges_cache is None + + node.add_edge("path1", condition=lambda s: True) + assert node._sorted_edges_cache is None + + node.get_next_node(GraphState(data={})) + assert node._sorted_edges_cache is not None + cached = node._sorted_edges_cache + + node.get_next_node(GraphState(data={})) + assert node._sorted_edges_cache is cached + + node.add_edge("path2", condition=lambda s: False) + assert node._sorted_edges_cache is None + + node.get_next_node(GraphState(data={})) + assert node._sorted_edges_cache is not None + assert node._sorted_edges_cache is not cached diff --git a/tests/unittests/agents/test_graph_state.py b/tests/unittests/agents/test_graph_state.py new file mode 100644 index 0000000000..8d29306ee4 --- /dev/null +++ b/tests/unittests/agents/test_graph_state.py @@ -0,0 +1,353 @@ +"""Tests for GraphState accessors and state_utils parsing functions.""" + +from __future__ import annotations + +import json +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from google.adk.agents.graph.graph_state import GraphState +from google.adk.agents.graph.graph_state import PydanticJSONEncoder +from google.adk.agents.graph.state_utils import parse_state_value +from google.adk.agents.graph.state_utils import state_value_as_dict +from google.adk.agents.graph.state_utils import state_value_as_str +from pydantic import BaseModel +import pytest + +# ── Test models 
────────────────────────────────────────────────────── + + +class ReviewResult(BaseModel): + decision: str + reasoning: str + + +class NestedModel(BaseModel): + name: str + tags: List[str] = [] + metadata: Optional[Dict[str, Any]] = None + + +class StrictModel(BaseModel): + value: int + label: str + + +# ── parse_state_value ──────────────────────────────────────────────── + + +class TestParseStateValue: + + def test_dict_to_model(self): + raw = {"decision": "approve", "reasoning": "looks good"} + result = parse_state_value(raw, ReviewResult) + assert result is not None + assert result.decision == "approve" + assert result.reasoning == "looks good" + + def test_json_string_to_model(self): + raw = json.dumps({"decision": "reject", "reasoning": "needs work"}) + result = parse_state_value(raw, ReviewResult) + assert result is not None + assert result.decision == "reject" + + def test_none_returns_default(self): + default = ReviewResult(decision="skip", reasoning="default") + result = parse_state_value(None, ReviewResult, default=default) + assert result is default + + def test_none_returns_none_when_no_default(self): + result = parse_state_value(None, ReviewResult) + assert result is None + + def test_invalid_dict_returns_default(self): + raw = {"wrong_key": 123} + default = ReviewResult(decision="fallback", reasoning="bad data") + result = parse_state_value(raw, ReviewResult, default=default) + assert result is default + + def test_invalid_json_string_returns_default(self): + raw = "not valid json at all" + result = parse_state_value(raw, ReviewResult) + assert result is None + + def test_unexpected_type_returns_default(self): + result = parse_state_value(42, ReviewResult) + assert result is None + + def test_unexpected_type_with_default(self): + default = ReviewResult(decision="x", reasoning="y") + result = parse_state_value(42, ReviewResult, default=default) + assert result is default + + def test_nested_model(self): + raw = {"name": "test", "tags": ["a", "b"], 
# ── state_value_as_str ───────────────────────────────────────────────


class TestStateValueAsStr:
  """state_value_as_str: stringify any value, defaulting on None."""

  def test_string_value(self):
    assert state_value_as_str("hello") == "hello"

  def test_int_value(self):
    assert state_value_as_str(42) == "42"

  def test_float_value(self):
    assert state_value_as_str(3.14) == "3.14"

  def test_none_returns_default(self):
    assert state_value_as_str(None) == ""

  def test_none_custom_default(self):
    assert state_value_as_str(None, "N/A") == "N/A"

  def test_dict_value(self):
    assert "key" in state_value_as_str({"key": "val"})

  def test_list_value(self):
    assert state_value_as_str([1, 2, 3]) == "[1, 2, 3]"

  def test_bool_value(self):
    assert state_value_as_str(True) == "True"
    assert state_value_as_str(False) == "False"

  def test_empty_string(self):
    assert state_value_as_str("") == ""


# ── state_value_as_dict ──────────────────────────────────────────────


class TestStateValueAsDict:
  """state_value_as_dict: accept dict or JSON object string, else default."""

  def test_dict_value(self):
    assert state_value_as_dict({"key": "val", "num": 1}) == {
        "key": "val",
        "num": 1,
    }

  def test_json_string(self):
    assert state_value_as_dict(json.dumps({"a": 1, "b": 2})) == {"a": 1, "b": 2}

  def test_invalid_json_returns_default(self):
    assert state_value_as_dict("not json") == {}

  def test_invalid_json_custom_default(self):
    fallback = {"mode": "auto"}
    assert state_value_as_dict("not json", default=fallback) == fallback

  def test_none_returns_default(self):
    assert state_value_as_dict(None) == {}

  def test_none_custom_default(self):
    assert state_value_as_dict(None, default={"x": 1}) == {"x": 1}

  def test_non_dict_non_string_returns_default(self):
    assert state_value_as_dict(42) == {}

  def test_json_list_string_returns_default(self):
    """JSON list is valid JSON but not a dict."""
    assert state_value_as_dict("[1, 2, 3]") == {}

  def test_empty_dict(self):
    assert state_value_as_dict({}) == {}

  def test_empty_json_object(self):
    assert state_value_as_dict("{}") == {}

  def test_nested_dict(self):
    result = state_value_as_dict({"outer": {"inner": "val"}})
    assert result["outer"]["inner"] == "val"


# ── GraphState.get_parsed ───────────────────────────────────────────


class TestGraphStateGetParsed:
  """GraphState.get_parsed: typed lookup with default fallback.

  NOTE(review): this class continues beyond the visible chunk; only the
  methods shown here are reproduced.
  """

  def test_dict_value(self):
    state = GraphState(
        data={"review": {"decision": "approve", "reasoning": "OK"}}
    )
    parsed = state.get_parsed("review", ReviewResult)
    assert parsed is not None
    assert parsed.decision == "approve"

  def test_json_string_value(self):
    state = GraphState(
        data={"review": json.dumps({"decision": "reject", "reasoning": "bad"})}
    )
    parsed = state.get_parsed("review", ReviewResult)
    assert parsed is not None
    assert parsed.decision == "reject"

  def test_missing_key(self):
    state = GraphState(data={})
    assert state.get_parsed("missing", ReviewResult) is None

  def test_missing_key_with_default(self):
    fallback = ReviewResult(decision="default", reasoning="n/a")
    state = GraphState(data={})
    assert state.get_parsed("missing", ReviewResult, default=fallback) is fallback
test_invalid_dict(self): + state = GraphState(data={"bad": {"wrong": 123}}) + result = state.get_parsed("bad", ReviewResult) + assert result is None + + def test_unexpected_type(self): + state = GraphState(data={"val": 42}) + result = state.get_parsed("val", ReviewResult) + assert result is None + + +# ── GraphState.get_str ────────────────────────────────────────────── + + +class TestGraphStateGetStr: + + def test_string_value(self): + state = GraphState(data={"text": "hello"}) + assert state.get_str("text") == "hello" + + def test_non_string_value(self): + state = GraphState(data={"num": 42}) + assert state.get_str("num") == "42" + + def test_missing_key(self): + state = GraphState(data={}) + assert state.get_str("missing") == "" + + def test_missing_key_custom_default(self): + state = GraphState(data={}) + assert state.get_str("missing", default="N/A") == "N/A" + + def test_none_value(self): + state = GraphState(data={"val": None}) + assert state.get_str("val") == "" + + +# ── GraphState.get_dict ───────────────────────────────────────────── + + +class TestGraphStateGetDict: + + def test_dict_value(self): + state = GraphState(data={"config": {"mode": "fast"}}) + assert state.get_dict("config") == {"mode": "fast"} + + def test_json_string_value(self): + state = GraphState(data={"config": '{"mode": "fast"}'}) + assert state.get_dict("config") == {"mode": "fast"} + + def test_invalid_json_string(self): + state = GraphState(data={"config": "not json"}) + assert state.get_dict("config") == {} + + def test_missing_key(self): + state = GraphState(data={}) + assert state.get_dict("missing") == {} + + def test_missing_key_custom_default(self): + state = GraphState(data={}) + assert state.get_dict("missing", default={"x": 1}) == {"x": 1} + + def test_non_dict_non_string(self): + state = GraphState(data={"val": 42}) + assert state.get_dict("val") == {} + + +# ── GraphState.data_to_json ───────────────────────────────────────── + + +class ChildModel(BaseModel): + score: 
float + label: str + + +class TestDataToJson: + + def test_plain_dict(self): + state = GraphState(data={"key": "val", "num": 1}) + result = state.data_to_json() + parsed = json.loads(result) + assert parsed == {"key": "val", "num": 1} + + def test_nested_pydantic(self): + state = GraphState(data={"result": ChildModel(score=0.95, label="good")}) + result = state.data_to_json() + parsed = json.loads(result) + assert parsed["result"]["score"] == 0.95 + assert parsed["result"]["label"] == "good" + + def test_empty_data(self): + state = GraphState() + result = state.data_to_json() + assert json.loads(result) == {} + + def test_indent(self): + state = GraphState(data={"a": 1}) + compact = state.data_to_json(indent=0) + assert "\n" in compact # indent=0 still adds newlines + no_indent = json.dumps({"a": 1}) + assert len(compact) >= len(no_indent) + + +# ── PydanticJSONEncoder ───────────────────────────────────────────── + + +class TestPydanticJSONEncoder: + + def test_encodes_model(self): + model = ReviewResult(decision="yes", reasoning="ok") + result = json.dumps(model, cls=PydanticJSONEncoder) + parsed = json.loads(result) + assert parsed["decision"] == "yes" + + def test_encodes_nested_model(self): + data = {"child": ChildModel(score=0.5, label="mid")} + result = json.dumps(data, cls=PydanticJSONEncoder) + parsed = json.loads(result) + assert parsed["child"]["score"] == 0.5 + + def test_plain_types(self): + data = {"str": "a", "int": 1, "float": 1.5, "bool": True, "null": None} + result = json.dumps(data, cls=PydanticJSONEncoder) + parsed = json.loads(result) + assert parsed == data + + def test_non_serializable_raises(self): + with pytest.raises(TypeError): + json.dumps({"obj": object()}, cls=PydanticJSONEncoder) diff --git a/tests/unittests/agents/test_graph_state_management.py b/tests/unittests/agents/test_graph_state_management.py new file mode 100644 index 0000000000..8c138a9066 --- /dev/null +++ b/tests/unittests/agents/test_graph_state_management.py @@ -0,0 
+1,563 @@ +"""Comprehensive state management tests for GraphAgent. + +Tests all state reducers and state propagation patterns: +- StateReducer.OVERWRITE +- StateReducer.APPEND +- StateReducer.SUM +- StateReducer.CUSTOM +- State propagation through graph +- State isolation in parallel execution + +These are unit tests focusing on state reducer logic, not full integration tests. +Full integration tests are in test_graph_agent.py and test_parallel_execution.py. +""" + +from typing import Any +from typing import Dict + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import StateReducer +from google.adk.events.event import Event +from google.genai import types +import pytest + +# ============================================================================ +# Test Agents (Real BaseAgent implementations per ADK guidelines) +# ============================================================================ + + +class TextAgent(BaseAgent): + """Agent that outputs text.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, text: str): + super().__init__(name=name) + object.__setattr__(self, "_text", text) + + async def _run_async_impl(self, ctx): + """Output text.""" + text = object.__getattribute__(self, "_text") + yield Event( + author=self.name, content=types.Content(parts=[types.Part(text=text)]) + ) + + +# ============================================================================ +# Test: StateReducer.OVERWRITE +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateReducerOverwrite: + """Test OVERWRITE reducer - replaces existing value.""" + + async def test_overwrite_reducer_basic(self): + """Test basic OVERWRITE behavior - new value replaces old.""" + node = GraphNode( + 
name="test_node", + agent=TextAgent("agent", "new"), + reducer=StateReducer.OVERWRITE, + ) + + # Initial state with existing value + state = GraphState(data={"test_node": "old"}) + + # Apply output with OVERWRITE reducer + new_state = node._default_output_mapper("new", state) + + # Verify value was overwritten + assert new_state.data["test_node"] == "new" + assert "old" not in str(new_state.data["test_node"]) + + async def test_overwrite_reducer_new_key(self): + """Test OVERWRITE creates key if it doesn't exist.""" + node = GraphNode( + name="new_key", + agent=TextAgent("agent", "value"), + reducer=StateReducer.OVERWRITE, + ) + + state = GraphState(data={}) + new_state = node._default_output_mapper("value", state) + + assert new_state.data["new_key"] == "value" + + async def test_overwrite_preserves_other_keys(self): + """Test OVERWRITE doesn't affect other state keys.""" + node = GraphNode( + name="key1", + agent=TextAgent("agent", "new"), + reducer=StateReducer.OVERWRITE, + ) + + state = GraphState(data={"key1": "old", "key2": "preserved"}) + new_state = node._default_output_mapper("new", state) + + assert new_state.data["key1"] == "new" + assert new_state.data["key2"] == "preserved" + + +# ============================================================================ +# Test: StateReducer.APPEND +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateReducerAppend: + """Test APPEND reducer - appends to list.""" + + async def test_append_reducer_creates_list(self): + """Test APPEND reducer creates list when key doesn't exist.""" + node = GraphNode( + name="collector", + agent=TextAgent("agent", "item"), + reducer=StateReducer.APPEND, + ) + + state = GraphState(data={}) + new_state = node._default_output_mapper("first_item", state) + + # Verify list was created with first item + assert "collector" in new_state.data + assert isinstance(new_state.data["collector"], list) + assert 
new_state.data["collector"] == ["first_item"] + + async def test_append_reducer_appends_to_existing_list(self): + """Test APPEND adds to existing list.""" + node = GraphNode( + name="collector", + agent=TextAgent("agent", "item"), + reducer=StateReducer.APPEND, + ) + + state = GraphState(data={"collector": ["item1", "item2"]}) + new_state = node._default_output_mapper("item3", state) + + assert new_state.data["collector"] == ["item1", "item2", "item3"] + + async def test_append_multiple_values(self): + """Test APPEND accumulates multiple values.""" + node = GraphNode( + name="results", + agent=TextAgent("agent", "item"), + reducer=StateReducer.APPEND, + ) + + # First append + state1 = GraphState(data={}) + state2 = node._default_output_mapper("first", state1) + + # Second append + state3 = node._default_output_mapper("second", state2) + + # Third append + state4 = node._default_output_mapper("third", state3) + + assert state4.data["results"] == ["first", "second", "third"] + + +# ============================================================================ +# Test: StateReducer.SUM +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateReducerSum: + """Test SUM reducer - accumulates values via + operator.""" + + async def test_sum_reducer_string_concatenation(self): + """Test SUM reducer concatenates string outputs.""" + node = GraphNode( + name="log", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + state = GraphState(data={}) + state1 = node._default_output_mapper("hello", state) + assert state1.data["log"] == "hello" + + state2 = node._default_output_mapper(" world", state1) + assert state2.data["log"] == "hello world" + + async def test_sum_reducer_with_existing_string(self): + """Test SUM concatenates onto existing string value.""" + node = GraphNode( + name="counter", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + state = GraphState(data={"counter": "prefix"}) + 
state1 = node._default_output_mapper("_suffix", state) + assert state1.data["counter"] == "prefix_suffix" + + async def test_sum_reducer_numeric_types(self): + """Test SUM works correctly with int and float values.""" + node = GraphNode( + name="total", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + state = GraphState(data={}) + # int + int + state1 = node._default_output_mapper(10, state) + assert state1.data["total"] == 10 + + # int + float + state2 = node._default_output_mapper(2.5, state1) + assert state2.data["total"] == 12.5 + + # float + int + state3 = node._default_output_mapper(3, state2) + assert state3.data["total"] == 15.5 + + async def test_sum_reducer_type_mismatch(self): + """Test SUM raises TypeError for incompatible types.""" + node = GraphNode( + name="counter", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + # string existing + int output → TypeError + state = GraphState(data={"counter": "not_a_number"}) + with pytest.raises(TypeError, match="cannot add"): + node._default_output_mapper(5, state) + + async def test_sum_reducer_list_concatenation(self): + """Test SUM concatenates list outputs.""" + node = GraphNode( + name="items", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + state = GraphState(data={}) + state1 = node._default_output_mapper([1, 2], state) + assert state1.data["items"] == [1, 2] + + state2 = node._default_output_mapper([3, 4], state1) + assert state2.data["items"] == [1, 2, 3, 4] + + async def test_sum_reducer_agent_string_output(self): + """Test SUM works natively with agent string outputs (no custom mapper needed).""" + node = GraphNode( + name="transcript", agent=TextAgent("agent", "x"), reducer=StateReducer.SUM + ) + + state = GraphState(data={}) + state1 = node._default_output_mapper("Agent says: hello. ", state) + state2 = node._default_output_mapper("Agent says: goodbye. ", state1) + assert state2.data["transcript"] == "Agent says: hello. Agent says: goodbye. 
" + + +# ============================================================================ +# Test: StateReducer.CUSTOM +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateReducerCustom: + """Test CUSTOM reducer - uses custom reduction function.""" + + async def test_custom_reducer_basic(self): + """Test CUSTOM reducer with simple concatenation.""" + + def concat_reducer(existing, new_value): + if existing is None: + return new_value + return f"{existing}|{new_value}" + + node = GraphNode( + name="custom", + agent=TextAgent("agent", "test"), + reducer=StateReducer.CUSTOM, + custom_reducer=concat_reducer, + ) + + # First call - no existing value + state1 = GraphState(data={}) + new_state1 = node._default_output_mapper("A", state1) + assert new_state1.data["custom"] == "A" + + # Second call - merge with existing + state2 = GraphState(data={"custom": "A"}) + new_state2 = node._default_output_mapper("B", state2) + assert new_state2.data["custom"] == "A|B" + + async def test_custom_reducer_dict_merge(self): + """Test CUSTOM reducer for merging dictionaries.""" + + def dict_merge_reducer(existing, new_value): + """Merge dict-like string representations.""" + if existing is None: + return {"data": [new_value]} + if isinstance(existing, dict): + existing["data"].append(new_value) + return existing + return {"data": [existing, new_value]} + + node = GraphNode( + name="merger", + agent=TextAgent("agent", "test"), + reducer=StateReducer.CUSTOM, + custom_reducer=dict_merge_reducer, + ) + + state1 = GraphState(data={}) + new_state1 = node._default_output_mapper("item1", state1) + assert new_state1.data["merger"] == {"data": ["item1"]} + + new_state2 = node._default_output_mapper("item2", new_state1) + assert new_state2.data["merger"] == {"data": ["item1", "item2"]} + + async def test_custom_reducer_counter(self): + """Test CUSTOM reducer for counting.""" + + def count_reducer(existing, new_value): + """Count 
occurrences.""" + if existing is None: + return 1 + return existing + 1 + + node = GraphNode( + name="counter", + agent=TextAgent("agent", "test"), + reducer=StateReducer.CUSTOM, + custom_reducer=count_reducer, + ) + + state1 = GraphState(data={}) + new_state1 = node._default_output_mapper("ignored", state1) + assert new_state1.data["counter"] == 1 + + new_state2 = node._default_output_mapper("ignored", new_state1) + assert new_state2.data["counter"] == 2 + + new_state3 = node._default_output_mapper("ignored", new_state2) + assert new_state3.data["counter"] == 3 + + +# ============================================================================ +# Test: State Propagation (Unit Tests) +# ============================================================================ + + +@pytest.mark.asyncio +class TestStatePropagation: + """Test how state flows through graph nodes (unit tests).""" + + async def test_output_mapper_preserves_existing_state(self): + """Test that output mapper preserves existing state data.""" + node = GraphNode(name="node1", agent=TextAgent("agent", "new")) + + state = GraphState(data={"existing_key": "existing_value", "meta": "data"}) + + new_state = node._default_output_mapper("new_output", state) + + # New output added + assert new_state.data["node1"] == "new_output" + + # Existing state preserved + assert new_state.data["existing_key"] == "existing_value" + assert new_state.data["meta"] == "data" + + async def test_domain_data_preserved_across_state_updates(self): + """Test that domain data is preserved during state updates.""" + node = GraphNode(name="node", agent=TextAgent("agent", "output")) + + state = GraphState( + data={"key": "value", "iteration": 1, "path": ["start", "middle"]}, + ) + + new_state = node._default_output_mapper("output", state) + + assert new_state.data["iteration"] == 1 + assert new_state.data["path"] == ["start", "middle"] + + async def test_state_isolation_between_nodes(self): + """Test that each node gets its own state 
copy.""" + node1 = GraphNode(name="node1", agent=TextAgent("agent1", "output1")) + node2 = GraphNode(name="node2", agent=TextAgent("agent2", "output2")) + + state = GraphState(data={}) + + # Node 1 processes state + state1 = node1._default_output_mapper("output1", state) + + # Node 2 processes original state (not state1) + state2 = node2._default_output_mapper("output2", state) + + # Verify isolation - state2 doesn't have node1's output + assert "node1" in state1.data + assert "node1" not in state2.data + assert "node2" in state2.data + + +# ============================================================================ +# Test: Custom Output Mappers +# ============================================================================ + + +@pytest.mark.asyncio +class TestCustomOutputMappers: + """Test custom output mapper functionality.""" + + async def test_custom_output_mapper_override(self): + """Test custom output mapper completely overrides default.""" + + def custom_mapper(output: str, state: GraphState) -> GraphState: + # Completely custom logic + new_state = GraphState( + data={"custom_key": f"CUSTOM_{output}", "custom": True} + ) + return new_state + + node = GraphNode( + name="custom", + agent=TextAgent("agent", "test"), + output_mapper=custom_mapper, + ) + + state = GraphState(data={"existing": "data"}) + new_state = node.output_mapper("output", state) + + # Custom mapper replaced everything + assert "custom_key" in new_state.data + assert new_state.data["custom_key"] == "CUSTOM_output" + assert new_state.data.get("custom") == True + # Original state data gone (custom mapper replaced it) + assert "existing" not in new_state.data + + async def test_custom_output_mapper_with_state_merge(self): + """Test custom output mapper that merges with existing state.""" + + def merging_mapper(output: str, state: GraphState) -> GraphState: + # Preserve existing state and add new data + new_state = GraphState(data=state.data.copy()) + new_state.data["processed"] = 
output.upper() + new_state.data["processed_count"] = ( + new_state.data.get("processed_count", 0) + 1 + ) + return new_state + + node = GraphNode( + name="merger", + agent=TextAgent("agent", "test"), + output_mapper=merging_mapper, + ) + + state = GraphState(data={"existing": "value", "processed_count": 5}) + new_state = node.output_mapper("hello", state) + + # Existing state preserved + assert new_state.data["existing"] == "value" + # New data added + assert new_state.data["processed"] == "HELLO" + # Processed count updated in data + assert new_state.data["processed_count"] == 6 + + +# ============================================================================ +# Test: Edge Cases +# ============================================================================ + + +@pytest.mark.asyncio +class TestStateEdgeCases: + """Test edge cases in state management.""" + + async def test_empty_state_initialization(self): + """Test graph node with empty initial state.""" + node = GraphNode(name="solo", agent=TextAgent("agent", "output")) + + state = GraphState(data={}) + new_state = node._default_output_mapper("output", state) + + assert new_state.data["solo"] == "output" + + async def test_state_copy_safety(self): + """Test that state copies don't share references for simple types.""" + state1 = GraphState(data={"key": "value", "meta": "data"}) + + # GraphNode does .copy() for data + state2 = GraphState(data=state1.data.copy()) + + # Modify state2 + state2.data["key"] = "modified" + state2.data["meta"] = "modified" + + # State1 unchanged (shallow copy works for simple types) + assert state1.data["key"] == "value" + assert state1.data["meta"] == "data" + + async def test_state_nested_dict_deep_copy_isolation(self): + """Verify _default_output_mapper uses deepcopy for nested state isolation.""" + state = GraphState(data={"nested": {"key": "value"}, "list": [1, 2, 3]}) + + node = GraphNode(name="test_node", agent=TextAgent("agent", "output")) + new_state = 
node._default_output_mapper("output", state) + + # Modify nested structure in new_state + new_state.data["nested"]["key"] = "modified" + new_state.data["list"].append(4) + + # Original state is NOT affected (deepcopy ensures isolation) + assert state.data["nested"]["key"] == "value" + assert state.data["list"] == [1, 2, 3] + + async def test_reducer_with_none_output(self): + """Test reducer behavior with None or empty output.""" + node = GraphNode( + name="test", + agent=TextAgent("agent", ""), + reducer=StateReducer.OVERWRITE, + ) + + state = GraphState(data={}) + new_state = node._default_output_mapper("", state) + + # Empty string is still stored + assert new_state.data["test"] == "" + + +# ============================================================================ +# GraphState.data_to_json and PydanticJSONEncoder tests +# ============================================================================ + + +def test_data_to_json_simple_values(): + """data_to_json serializes plain dict values to JSON string.""" + import json + + state = GraphState(data={"key": "value", "num": 42}) + result = state.data_to_json() + + parsed = json.loads(result) + assert parsed["key"] == "value" + assert parsed["num"] == 42 + + +def test_data_to_json_pydantic_model(): + """data_to_json converts Pydantic BaseModel values via model_dump.""" + import json + + from pydantic import BaseModel + + class Inner(BaseModel): + x: int + y: str + + state = GraphState(data={"model": Inner(x=1, y="hello")}) + result = state.data_to_json() + + parsed = json.loads(result) + assert parsed["model"] == {"x": 1, "y": "hello"} + + +def test_data_to_json_non_serializable_raises(): + """data_to_json raises TypeError for non-JSON-serializable, non-Pydantic objects.""" + import json + + class Unserializable: + pass + + state = GraphState(data={"bad": Unserializable()}) + with pytest.raises(TypeError): + state.data_to_json() diff --git a/tests/unittests/agents/test_graph_telemetry_config.py 
b/tests/unittests/agents/test_graph_telemetry_config.py new file mode 100644 index 0000000000..b73417b274 --- /dev/null +++ b/tests/unittests/agents/test_graph_telemetry_config.py @@ -0,0 +1,610 @@ +"""Tests for GraphAgent telemetry configuration.""" + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph.graph_agent_config import TelemetryConfig +from google.adk.agents.graph.graph_node import GraphNode +import pytest + + +@pytest.fixture +def simple_graph(): + """Create a simple graph for testing.""" + + def simple_function(state, ctx): + return "test output" + + graph = GraphAgent(name="test_graph", description="Test graph") + graph.add_node(GraphNode(name="start", function=simple_function)) + graph.add_node(GraphNode(name="end", function=simple_function)) + graph.add_edge("start", "end") + graph.set_start("start") + graph.set_end("end") + return graph + + +def test_telemetry_config_creation(): + """Test creating TelemetryConfig with defaults.""" + config = TelemetryConfig() + + assert config.enabled is True + assert config.trace_nodes is True + assert config.trace_edges is True + assert config.trace_iterations is True + assert config.trace_parallel_groups is True + assert config.trace_callbacks is True + assert config.trace_interrupts is True + assert config.sampling_rate == 1.0 + assert config.additional_attributes is None + + +def test_telemetry_config_custom_values(): + """Test creating TelemetryConfig with custom values.""" + config = TelemetryConfig( + enabled=True, + trace_nodes=True, + trace_edges=False, + trace_iterations=True, + trace_parallel_groups=False, + trace_callbacks=True, + trace_interrupts=False, + sampling_rate=0.5, + additional_attributes={"environment": "test"}, + ) + + assert config.enabled is True + assert config.trace_nodes is True + assert config.trace_edges is False + assert config.trace_iterations is True + assert config.trace_parallel_groups is False + assert config.trace_callbacks is True + assert 
config.trace_interrupts is False + assert config.sampling_rate == 0.5 + assert config.additional_attributes == {"environment": "test"} + + +def test_telemetry_config_disabled(): + """Test TelemetryConfig with telemetry disabled.""" + config = TelemetryConfig(enabled=False) + + assert config.enabled is False + # Other settings should still have their defaults + assert config.trace_nodes is True + assert config.trace_edges is True + + +def test_telemetry_config_sampling_rate_validation(): + """Test TelemetryConfig sampling_rate validation.""" + # Valid sampling rates + config = TelemetryConfig(sampling_rate=0.0) + assert config.sampling_rate == 0.0 + + config = TelemetryConfig(sampling_rate=1.0) + assert config.sampling_rate == 1.0 + + config = TelemetryConfig(sampling_rate=0.5) + assert config.sampling_rate == 0.5 + + # Invalid sampling rates should raise validation error + with pytest.raises(Exception): # Pydantic validation error + TelemetryConfig(sampling_rate=-0.1) + + with pytest.raises(Exception): # Pydantic validation error + TelemetryConfig(sampling_rate=1.1) + + +def test_graph_agent_telemetry_config_none(simple_graph): + """Test GraphAgent with no telemetry config (defaults to enabled).""" + assert simple_graph.telemetry_config is None + assert simple_graph._is_telemetry_enabled() is True + assert simple_graph._should_trace_nodes() is True + assert simple_graph._should_trace_edges() is True + assert simple_graph._should_trace_iterations() is True + assert simple_graph._should_trace_parallel_groups() is True + assert simple_graph._should_trace_callbacks() is True + assert simple_graph._should_trace_interrupts() is True + + +def test_graph_agent_telemetry_config_enabled(): + """Test GraphAgent with telemetry config enabled.""" + config = TelemetryConfig(enabled=True) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + assert graph.telemetry_config is config + assert graph._is_telemetry_enabled() is True + assert graph._should_trace_nodes() 
is True + assert graph._should_trace_edges() is True + + +def test_graph_agent_telemetry_config_disabled(): + """Test GraphAgent with telemetry config disabled.""" + config = TelemetryConfig(enabled=False) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + assert graph.telemetry_config is config + assert graph._is_telemetry_enabled() is False + assert graph._should_trace_nodes() is False + assert graph._should_trace_edges() is False + assert graph._should_trace_iterations() is False + assert graph._should_trace_parallel_groups() is False + assert graph._should_trace_callbacks() is False + assert graph._should_trace_interrupts() is False + + +def test_graph_agent_telemetry_selective_tracing(): + """Test GraphAgent with selective tracing enabled.""" + config = TelemetryConfig( + enabled=True, + trace_nodes=True, + trace_edges=False, + trace_iterations=True, + trace_parallel_groups=False, + trace_callbacks=True, + trace_interrupts=False, + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + assert graph._is_telemetry_enabled() is True + assert graph._should_trace_nodes() is True + assert graph._should_trace_edges() is False + assert graph._should_trace_iterations() is True + assert graph._should_trace_parallel_groups() is False + assert graph._should_trace_callbacks() is True + assert graph._should_trace_interrupts() is False + + +def test_graph_agent_telemetry_only_nodes(): + """Test GraphAgent with only node tracing enabled.""" + config = TelemetryConfig( + enabled=True, + trace_nodes=True, + trace_edges=False, + trace_iterations=False, + trace_parallel_groups=False, + trace_callbacks=False, + trace_interrupts=False, + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + assert graph._should_trace_nodes() is True + assert graph._should_trace_edges() is False + assert graph._should_trace_iterations() is False + assert graph._should_trace_parallel_groups() is False + assert graph._should_trace_callbacks() is False 
+ assert graph._should_trace_interrupts() is False + + +def test_graph_agent_telemetry_only_edges(): + """Test GraphAgent with only edge tracing enabled.""" + config = TelemetryConfig( + enabled=True, + trace_nodes=False, + trace_edges=True, + trace_iterations=False, + trace_parallel_groups=False, + trace_callbacks=False, + trace_interrupts=False, + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + assert graph._should_trace_nodes() is False + assert graph._should_trace_edges() is True + assert graph._should_trace_iterations() is False + + +def test_telemetry_config_additional_attributes(): + """Test TelemetryConfig with additional custom attributes.""" + config = TelemetryConfig( + additional_attributes={ + "environment": "production", + "version": "1.2.3", + "team": "ml-platform", + } + ) + + assert config.additional_attributes == { + "environment": "production", + "version": "1.2.3", + "team": "ml-platform", + } + + +def test_telemetry_config_model_serialization(): + """Test TelemetryConfig serialization/deserialization.""" + config = TelemetryConfig( + enabled=True, + trace_nodes=True, + trace_edges=False, + sampling_rate=0.75, + additional_attributes={"env": "test"}, + ) + + # Serialize to dict + config_dict = config.model_dump() + assert config_dict["enabled"] is True + assert config_dict["trace_nodes"] is True + assert config_dict["trace_edges"] is False + assert config_dict["sampling_rate"] == 0.75 + assert config_dict["additional_attributes"] == {"env": "test"} + + # Deserialize from dict + new_config = TelemetryConfig(**config_dict) + assert new_config.enabled is True + assert new_config.trace_nodes is True + assert new_config.trace_edges is False + assert new_config.sampling_rate == 0.75 + assert new_config.additional_attributes == {"env": "test"} + + +def test_telemetry_disabled_overrides_individual_settings(): + """Test that disabling telemetry overrides individual trace settings.""" + config = TelemetryConfig( + enabled=False, + 
trace_nodes=True, # Even though this is True + trace_edges=True, # Even though this is True + trace_iterations=True, # Even though this is True + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # All tracing should be disabled because enabled=False + assert graph._is_telemetry_enabled() is False + assert graph._should_trace_nodes() is False + assert graph._should_trace_edges() is False + assert graph._should_trace_iterations() is False + + +def test_graph_agent_init_with_telemetry_config(): + """Test GraphAgent __init__ accepts telemetry_config.""" + + def simple_function(state, ctx): + return "output" + + config = TelemetryConfig( + enabled=True, trace_nodes=True, trace_edges=False, sampling_rate=0.5 + ) + + graph = GraphAgent( + name="test_graph", + description="Test graph with telemetry config", + telemetry_config=config, + ) + + assert graph.telemetry_config is config + assert graph.telemetry_config.enabled is True + assert graph.telemetry_config.trace_nodes is True + assert graph.telemetry_config.trace_edges is False + assert graph.telemetry_config.sampling_rate == 0.5 + + +def test_should_sample_with_no_config(): + """Test _should_sample() with no telemetry config (defaults to 100%).""" + graph = GraphAgent(name="test_graph") + + # No config means 100% sampling + assert graph._should_sample() is True + assert graph._should_sample() is True + assert graph._should_sample() is True + + +def test_should_sample_with_100_percent(): + """Test _should_sample() with 100% sampling rate.""" + config = TelemetryConfig(sampling_rate=1.0) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # Should always return True + for _ in range(10): + assert graph._should_sample() is True + + +def test_should_sample_with_0_percent(): + """Test _should_sample() with 0% sampling rate.""" + config = TelemetryConfig(sampling_rate=0.0) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # Should always return False + for _ in 
range(10): + assert graph._should_sample() is False + + +def test_should_sample_with_50_percent(monkeypatch): + """Test _should_sample() with 50% sampling rate.""" + import random + + config = TelemetryConfig(sampling_rate=0.5) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # Mock random.random() to return controlled values + mock_values = [0.3, 0.7, 0.4, 0.8, 0.1, 0.9] + mock_values_iter = iter(mock_values) + + def mock_random(): + return next(mock_values_iter) + + monkeypatch.setattr(random, "random", mock_random) + + # 0.3 < 0.5 → sampled + assert graph._should_sample() is True + # 0.7 > 0.5 → not sampled + assert graph._should_sample() is False + # 0.4 < 0.5 → sampled + assert graph._should_sample() is True + # 0.8 > 0.5 → not sampled + assert graph._should_sample() is False + # 0.1 < 0.5 → sampled + assert graph._should_sample() is True + # 0.9 > 0.5 → not sampled + assert graph._should_sample() is False + + +def test_should_sample_with_25_percent(monkeypatch): + """Test _should_sample() with 25% sampling rate.""" + import random + + config = TelemetryConfig(sampling_rate=0.25) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # Mock random.random() to return controlled values + mock_values = [0.1, 0.3, 0.2, 0.5, 0.15, 0.9] + mock_values_iter = iter(mock_values) + + def mock_random(): + return next(mock_values_iter) + + monkeypatch.setattr(random, "random", mock_random) + + # 0.1 < 0.25 → sampled + assert graph._should_sample() is True + # 0.3 > 0.25 → not sampled + assert graph._should_sample() is False + # 0.2 < 0.25 → sampled + assert graph._should_sample() is True + # 0.5 > 0.25 → not sampled + assert graph._should_sample() is False + # 0.15 < 0.25 → sampled + assert graph._should_sample() is True + # 0.9 > 0.25 → not sampled + assert graph._should_sample() is False + + +def test_get_telemetry_attributes_no_config(): + """Test _get_telemetry_attributes() with no telemetry config.""" + graph = 
GraphAgent(name="test_graph") + + base_attrs = {"graph.node.name": "test_node", "graph.node.type": "agent"} + result = graph._get_telemetry_attributes(base_attrs) + + # Should return base attributes unchanged + assert result == base_attrs + assert result["graph.node.name"] == "test_node" + assert result["graph.node.type"] == "agent" + + +def test_get_telemetry_attributes_no_additional(): + """Test _get_telemetry_attributes() with config but no additional_attributes.""" + config = TelemetryConfig(enabled=True, sampling_rate=0.5) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + base_attrs = {"graph.node.name": "test_node"} + result = graph._get_telemetry_attributes(base_attrs) + + # Should return base attributes unchanged + assert result == base_attrs + assert result["graph.node.name"] == "test_node" + + +def test_get_telemetry_attributes_with_additional(): + """Test _get_telemetry_attributes() merges additional_attributes.""" + config = TelemetryConfig( + additional_attributes={"environment": "production", "version": "1.2.3"} + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + base_attrs = {"graph.node.name": "test_node", "graph.node.type": "agent"} + result = graph._get_telemetry_attributes(base_attrs) + + # Should have both base and additional attributes + assert result["graph.node.name"] == "test_node" + assert result["graph.node.type"] == "agent" + assert result["environment"] == "production" + assert result["version"] == "1.2.3" + assert len(result) == 4 + + +def test_get_telemetry_attributes_base_takes_precedence(): + """Test _get_telemetry_attributes() - base attributes take precedence.""" + config = TelemetryConfig( + additional_attributes={ + "environment": "dev", + "graph.node.name": "should_be_overwritten", + } + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + base_attrs = {"graph.node.name": "actual_node", "graph.node.type": "function"} + result = graph._get_telemetry_attributes(base_attrs) 
+ + # Base attributes should override additional_attributes + assert result["graph.node.name"] == "actual_node" + assert result["graph.node.type"] == "function" + assert result["environment"] == "dev" + + +def test_get_telemetry_attributes_empty_additional(): + """Test _get_telemetry_attributes() with empty additional_attributes dict.""" + config = TelemetryConfig(additional_attributes={}) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + base_attrs = {"graph.node.name": "test_node"} + result = graph._get_telemetry_attributes(base_attrs) + + # Should return base attributes unchanged + assert result == base_attrs + + +def test_get_telemetry_attributes_complex_values(): + """Test _get_telemetry_attributes() with complex attribute values.""" + config = TelemetryConfig( + additional_attributes={ + "environment": "staging", + "version": "2.0.1", + "team": "ml-platform", + "region": "us-west-2", + } + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + base_attrs = { + "graph.node.name": "complex_node", + "graph.node.type": "agent", + "graph.node.iteration": 5, + } + result = graph._get_telemetry_attributes(base_attrs) + + # Should have all attributes + assert result["graph.node.name"] == "complex_node" + assert result["graph.node.type"] == "agent" + assert result["graph.node.iteration"] == 5 + assert result["environment"] == "staging" + assert result["version"] == "2.0.1" + assert result["team"] == "ml-platform" + assert result["region"] == "us-west-2" + assert len(result) == 7 + + +def test_get_effective_telemetry_config_no_parent(): + """Test _get_effective_telemetry_config with no parent config.""" + from unittest import mock + + config = TelemetryConfig( + sampling_rate=0.5, additional_attributes={"env": "test"} + ) + graph = GraphAgent(name="test_graph", telemetry_config=config) + + # Mock context with no parent config + ctx = mock.Mock() + ctx.agent_states = {} + + effective = graph._get_effective_telemetry_config(ctx) + + # 
Should return own config + assert effective is config + assert effective.sampling_rate == 0.5 + assert effective.additional_attributes == {"env": "test"} + + +def test_get_effective_telemetry_config_parent_only(): + """Test _get_effective_telemetry_config with only parent config.""" + from unittest import mock + + # Child has no config + graph = GraphAgent(name="child_graph") + + # Mock context with parent config in agent_states + ctx = mock.Mock() + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": { + "enabled": True, + "sampling_rate": 0.3, + "additional_attributes": {"parent": "true"}, + "trace_nodes": True, + "trace_edges": True, + "trace_iterations": True, + "trace_parallel_groups": True, + "trace_callbacks": True, + "trace_interrupts": True, + } + } + } + + effective = graph._get_effective_telemetry_config(ctx) + + # Should inherit parent config + assert effective is not None + assert effective.sampling_rate == 0.3 + assert effective.additional_attributes == {"parent": "true"} + + +def test_get_effective_telemetry_config_merge(): + """Test _get_effective_telemetry_config merges parent and own.""" + from unittest import mock + + # Child has own config + child_config = TelemetryConfig( + sampling_rate=0.8, additional_attributes={"child": "true", "env": "dev"} + ) + graph = GraphAgent(name="child_graph", telemetry_config=child_config) + + # Mock context with parent config in agent_states + ctx = mock.Mock() + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": { + "enabled": True, + "sampling_rate": 0.3, + "additional_attributes": {"parent": "true", "version": "1.0"}, + "trace_nodes": True, + "trace_edges": True, + "trace_iterations": True, + "trace_parallel_groups": True, + "trace_callbacks": True, + "trace_interrupts": True, + } + } + } + + effective = graph._get_effective_telemetry_config(ctx) + + # Own config takes precedence + assert effective.sampling_rate == 0.8 + + # Additional attributes should be merged (own takes 
precedence) + assert effective.additional_attributes["child"] == "true" + assert effective.additional_attributes["parent"] == "true" + assert effective.additional_attributes["env"] == "dev" + assert effective.additional_attributes["version"] == "1.0" + + +def test_get_effective_telemetry_config_own_takes_precedence(): + """Test that own config values take precedence over parent.""" + from unittest import mock + + # Child config with specific values + child_config = TelemetryConfig( + enabled=True, + sampling_rate=1.0, + trace_nodes=False, # Override parent + additional_attributes={"env": "prod", "override": "child"}, + ) + graph = GraphAgent(name="child_graph", telemetry_config=child_config) + + # Mock context with parent config in agent_states + ctx = mock.Mock() + ctx.agent_states = { + "parent_graph": { + "telemetry_config_dict": { + "enabled": True, + "sampling_rate": 0.1, # Should be overridden + "trace_nodes": True, # Should be overridden + "trace_edges": True, + "trace_iterations": True, + "trace_parallel_groups": True, + "trace_callbacks": True, + "trace_interrupts": True, + "additional_attributes": {"env": "dev", "parent_only": "value"}, + } + } + } + + effective = graph._get_effective_telemetry_config(ctx) + + # Own values take precedence + assert effective.sampling_rate == 1.0 + assert effective.trace_nodes is False # Child overrode this + assert effective.trace_edges is True # Inherited from parent + + # Attributes merged, own takes precedence + assert effective.additional_attributes["env"] == "prod" # Child overrode + assert effective.additional_attributes["override"] == "child" # Child only + assert ( + effective.additional_attributes["parent_only"] == "value" + ) # Parent only diff --git a/tests/unittests/cli/test_agent_graph.py b/tests/unittests/cli/test_agent_graph.py new file mode 100644 index 0000000000..62c7a868f0 --- /dev/null +++ b/tests/unittests/cli/test_agent_graph.py @@ -0,0 +1,138 @@ +"""Tests for GraphAgent visualization in 
agent_graph.py. + +Asserts on graphviz.Digraph.source string — no rendering engine needed. +Only covers GraphAgent-specific cluster rendering (our code). +""" + +import graphviz +import pytest + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent, GraphNode, GraphState +from google.adk.cli.agent_graph import build_graph, get_agent_graph +from google.adk.events.event import Event +from google.genai import types + + +class SimpleTestAgent(BaseAgent): + """Minimal async agent for visualization tests (no LLM).""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text="stub")]), + ) + + +def _make_digraph(): + return graphviz.Digraph( + graph_attr={"rankdir": "LR", "bgcolor": "#333537"}, strict=True + ) + + +class TestGraphAgentVisualization: + """Test GraphAgent rendering in agent_graph build_graph.""" + + @pytest.mark.asyncio + async def test_graph_agent_linear_cluster(self): + """3 agent nodes with linear edges render as cluster with edges.""" + a1 = SimpleTestAgent(name="step1") + a2 = SimpleTestAgent(name="step2") + a3 = SimpleTestAgent(name="step3") + + g = GraphAgent(name="workflow") + g.add_node(GraphNode(name="n1", agent=a1)) + g.add_node(GraphNode(name="n2", agent=a2)) + g.add_node(GraphNode(name="n3", agent=a3)) + g.add_edge("n1", "n2") + g.add_edge("n2", "n3") + g.set_start("n1") + g.set_end("n3") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "cluster_" in src + assert "step1" in src + assert "step2" in src + assert "step3" in src + + @pytest.mark.asyncio + async def test_graph_agent_conditional_branch(self): + """1 source -> 2 targets: both edges present.""" + a1 = SimpleTestAgent(name="check") + a2 = SimpleTestAgent(name="pass_path") + a3 = SimpleTestAgent(name="fail_path") + + g = GraphAgent(name="branch") + 
g.add_node(GraphNode(name="n1", agent=a1)) + g.add_node(GraphNode(name="n2", agent=a2)) + g.add_node(GraphNode(name="n3", agent=a3)) + g.add_edge("n1", "n2", condition=lambda s: True) + g.add_edge("n1", "n3", condition=lambda s: False) + g.set_start("n1") + g.set_end("n2") + g.set_end("n3") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "check" in src + assert "pass_path" in src + assert "fail_path" in src + + @pytest.mark.asyncio + async def test_graph_agent_loop(self): + """Back-edge present in source.""" + a1 = SimpleTestAgent(name="reason") + a2 = SimpleTestAgent(name="observe") + + g = GraphAgent(name="react") + g.add_node(GraphNode(name="n1", agent=a1)) + g.add_node(GraphNode(name="n2", agent=a2)) + g.add_edge("n1", "n2") + g.add_edge("n2", "n1", condition=lambda s: True) + g.set_start("n1") + g.set_end("n2") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "reason" in src + assert "observe" in src + + @pytest.mark.asyncio + async def test_graph_agent_function_node(self): + """Function-only node rendered as box shape.""" + + async def my_func(state, ctx): + return "done" + + g = GraphAgent(name="wf") + g.add_node("fn_node", function=my_func) + g.set_start("fn_node") + g.set_end("fn_node") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "fn_node" in src + assert "box" in src + + @pytest.mark.asyncio + async def test_get_agent_graph_returns_digraph(self): + """get_agent_graph returns a graphviz.Digraph for GraphAgent.""" + a = SimpleTestAgent(name="s") + g = GraphAgent(name="wf") + g.add_node(GraphNode(name="n", agent=a)) + g.set_start("n") + g.set_end("n") + + result = await get_agent_graph(g, highlights_pairs=None) + assert isinstance(result, graphviz.Digraph) diff --git a/tests/unittests/telemetry/test_graph_tracing.py b/tests/unittests/telemetry/test_graph_tracing.py new file mode 100644 index 
0000000000..c3fec301f0 --- /dev/null +++ b/tests/unittests/telemetry/test_graph_tracing.py @@ -0,0 +1,522 @@ +"""Tests for GraphAgent telemetry instrumentation.""" + +from unittest import mock + +from google.adk.telemetry import graph_tracing +import pytest + + +def test_telemetry_module_imports(): + """Test that all telemetry exports are available.""" + assert hasattr(graph_tracing, "tracer") + assert hasattr(graph_tracing, "meter") + assert hasattr(graph_tracing, "logger") + assert hasattr(graph_tracing, "otel_logger") + + # Metrics + assert hasattr(graph_tracing, "node_execution_counter") + assert hasattr(graph_tracing, "node_execution_latency") + assert hasattr(graph_tracing, "edge_evaluation_counter") + assert hasattr(graph_tracing, "edge_evaluation_latency") + assert hasattr(graph_tracing, "graph_iteration_counter") + assert hasattr(graph_tracing, "parallel_group_counter") + assert hasattr(graph_tracing, "parallel_group_latency") + assert hasattr(graph_tracing, "callback_execution_counter") + assert hasattr(graph_tracing, "callback_execution_latency") + assert hasattr(graph_tracing, "interrupt_check_counter") + + # Semantic conventions + assert hasattr(graph_tracing, "GRAPH_AGENT_NAME") + assert hasattr(graph_tracing, "GRAPH_NODE_NAME") + assert hasattr(graph_tracing, "GRAPH_NODE_TYPE") + assert hasattr(graph_tracing, "GRAPH_EDGE_SOURCE") + assert hasattr(graph_tracing, "GRAPH_EDGE_TARGET") + + # Recording functions + assert hasattr(graph_tracing, "record_node_execution") + assert hasattr(graph_tracing, "record_edge_evaluation") + assert hasattr(graph_tracing, "record_graph_iteration") + assert hasattr(graph_tracing, "record_parallel_group_execution") + assert hasattr(graph_tracing, "record_callback_execution") + assert hasattr(graph_tracing, "record_interrupt_check") + + +def test_record_node_execution_success(): + """Test recording successful node execution.""" + with ( + mock.patch.object( + graph_tracing.node_execution_counter, "add" + ) as mock_counter, 
+ mock.patch.object( + graph_tracing.node_execution_latency, "record" + ) as mock_latency, + ): + graph_tracing.record_node_execution( + node_name="test_node", + node_type="agent", + agent_name="test_graph", + latency_ms=100.5, + success=True, + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_NAME] + == "test_node" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_TYPE] == "agent" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_AGENT_NAME] + == "test_graph" + ) + assert call_args.kwargs["attributes"]["success"] is True + + # Verify latency was recorded + assert mock_latency.call_count == 1 + latency_call_args = mock_latency.call_args + assert latency_call_args.args[0] == 100.5 + + +def test_record_node_execution_failure(): + """Test recording failed node execution.""" + with ( + mock.patch.object( + graph_tracing.node_execution_counter, "add" + ) as mock_counter, + mock.patch.object( + graph_tracing.node_execution_latency, "record" + ) as mock_latency, + ): + graph_tracing.record_node_execution( + node_name="failing_node", + node_type="function", + agent_name="test_graph", + latency_ms=50.0, + success=False, + ) + + # Verify counter was called with success=False + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert call_args.kwargs["attributes"]["success"] is False + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_TYPE] + == "function" + ) + + # Verify latency was still recorded + assert mock_latency.call_count == 1 + + +def test_record_edge_evaluation(): + """Test recording edge condition evaluation.""" + with ( + mock.patch.object( + graph_tracing.edge_evaluation_counter, "add" + ) as mock_counter, + mock.patch.object( + graph_tracing.edge_evaluation_latency, "record" + ) as mock_latency, + ): + graph_tracing.record_edge_evaluation( + 
source_node="node_a", + target_node="node_b", + agent_name="test_graph", + condition_result=True, + latency_ms=5.2, + priority=2, + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_EDGE_SOURCE] + == "node_a" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_EDGE_TARGET] + == "node_b" + ) + assert ( + call_args.kwargs["attributes"][ + graph_tracing.GRAPH_EDGE_CONDITION_RESULT + ] + == "True" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_EDGE_PRIORITY] == 2 + ) + + # Verify latency was recorded + assert mock_latency.call_count == 1 + latency_call_args = mock_latency.call_args + assert latency_call_args.args[0] == 5.2 + + +def test_record_edge_evaluation_false_condition(): + """Test recording edge evaluation with false condition result.""" + with mock.patch.object( + graph_tracing.edge_evaluation_counter, "add" + ) as mock_counter: + graph_tracing.record_edge_evaluation( + source_node="node_a", + target_node="node_c", + agent_name="test_graph", + condition_result=False, + latency_ms=3.1, + priority=1, + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][ + graph_tracing.GRAPH_EDGE_CONDITION_RESULT + ] + == "False" + ) + + +def test_record_graph_iteration(): + """Test recording graph iteration metrics.""" + with mock.patch.object( + graph_tracing.graph_iteration_counter, "add" + ) as mock_counter: + graph_tracing.record_graph_iteration( + agent_name="test_graph", iteration=5, path_length=10 + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_AGENT_NAME] + == "test_graph" + ) + assert call_args.kwargs["attributes"][graph_tracing.GRAPH_ITERATION] == 5 + assert call_args.kwargs["attributes"]["path_length"] == 10 + + +def 
test_record_parallel_group_execution(): + """Test recording parallel group execution metrics.""" + with ( + mock.patch.object( + graph_tracing.parallel_group_counter, "add" + ) as mock_counter, + mock.patch.object( + graph_tracing.parallel_group_latency, "record" + ) as mock_latency, + ): + graph_tracing.record_parallel_group_execution( + agent_name="test_graph", + node_count=3, + strategy="all", + latency_ms=250.5, + completed_count=3, + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_PARALLEL_NODE_COUNT] + == 3 + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_PARALLEL_STRATEGY] + == "all" + ) + assert call_args.kwargs["attributes"]["completed_count"] == 3 + + # Verify latency was recorded + assert mock_latency.call_count == 1 + latency_call_args = mock_latency.call_args + assert latency_call_args.args[0] == 250.5 + + +def test_record_parallel_group_partial_completion(): + """Test recording parallel group with partial completion.""" + with mock.patch.object( + graph_tracing.parallel_group_counter, "add" + ) as mock_counter: + graph_tracing.record_parallel_group_execution( + agent_name="test_graph", + node_count=5, + strategy="any", + latency_ms=100.0, + completed_count=2, + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_PARALLEL_NODE_COUNT] + == 5 + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_PARALLEL_STRATEGY] + == "any" + ) + assert call_args.kwargs["attributes"]["completed_count"] == 2 + + +def test_record_callback_execution_before_node(): + """Test recording before_node callback execution.""" + with ( + mock.patch.object( + graph_tracing.callback_execution_counter, "add" + ) as mock_counter, + mock.patch.object( + graph_tracing.callback_execution_latency, "record" + ) as mock_latency, + ): + graph_tracing.record_callback_execution( + 
callback_type="before_node", + agent_name="test_graph", + latency_ms=10.5, + success=True, + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_CALLBACK_TYPE] + == "before_node" + ) + assert call_args.kwargs["attributes"]["success"] is True + + # Verify latency was recorded + assert mock_latency.call_count == 1 + + +def test_record_callback_execution_after_node(): + """Test recording after_node callback execution.""" + with mock.patch.object( + graph_tracing.callback_execution_counter, "add" + ) as mock_counter: + graph_tracing.record_callback_execution( + callback_type="after_node", + agent_name="test_graph", + latency_ms=15.2, + success=True, + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_CALLBACK_TYPE] + == "after_node" + ) + + +def test_record_callback_execution_on_edge(): + """Test recording on_edge callback execution.""" + with mock.patch.object( + graph_tracing.callback_execution_counter, "add" + ) as mock_counter: + graph_tracing.record_callback_execution( + callback_type="on_edge", + agent_name="test_graph", + latency_ms=5.0, + success=True, + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_CALLBACK_TYPE] + == "on_edge" + ) + + +def test_record_callback_execution_failure(): + """Test recording failed callback execution.""" + with mock.patch.object( + graph_tracing.callback_execution_counter, "add" + ) as mock_counter: + graph_tracing.record_callback_execution( + callback_type="before_node", + agent_name="test_graph", + latency_ms=8.0, + success=False, + ) + + call_args = mock_counter.call_args + assert call_args.kwargs["attributes"]["success"] is False + + +def test_record_interrupt_check(): + """Test recording interrupt check metrics.""" + with mock.patch.object( + graph_tracing.interrupt_check_counter, "add" + ) as 
mock_counter: + graph_tracing.record_interrupt_check( + mode="before", agent_name="test_graph", session_id="session_123" + ) + + # Verify counter was called + assert mock_counter.call_count == 1 + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_INTERRUPT_MODE] + == "before" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_AGENT_NAME] + == "test_graph" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_SESSION_ID] + == "session_123" + ) + + +def test_record_interrupt_check_after_mode(): + """Test recording interrupt check in after mode.""" + with mock.patch.object( + graph_tracing.interrupt_check_counter, "add" + ) as mock_counter: + graph_tracing.record_interrupt_check( + mode="after", agent_name="test_graph", session_id="session_456" + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_INTERRUPT_MODE] + == "after" + ) + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_SESSION_ID] + == "session_456" + ) + + +def test_record_interrupt_check_both_mode(): + """Test recording interrupt check in both mode.""" + with mock.patch.object( + graph_tracing.interrupt_check_counter, "add" + ) as mock_counter: + graph_tracing.record_interrupt_check( + mode="both", agent_name="test_graph", session_id="session_789" + ) + + call_args = mock_counter.call_args + assert ( + call_args.kwargs["attributes"][graph_tracing.GRAPH_INTERRUPT_MODE] + == "both" + ) + + +def test_semantic_convention_values(): + """Test that semantic convention constants have correct values.""" + assert graph_tracing.GRAPH_AGENT_NAME == "graph.agent.name" + assert graph_tracing.GRAPH_NODE_NAME == "graph.node.name" + assert graph_tracing.GRAPH_NODE_TYPE == "graph.node.type" + assert graph_tracing.GRAPH_NODE_ITERATION == "graph.node.iteration" + assert graph_tracing.GRAPH_EDGE_SOURCE == "graph.edge.source" + assert graph_tracing.GRAPH_EDGE_TARGET == 
"graph.edge.target" + assert ( + graph_tracing.GRAPH_EDGE_CONDITION_RESULT == "graph.edge.condition.result" + ) + assert graph_tracing.GRAPH_EDGE_PRIORITY == "graph.edge.priority" + assert graph_tracing.GRAPH_ITERATION == "graph.iteration" + assert graph_tracing.GRAPH_PATH == "graph.path" + assert graph_tracing.GRAPH_PARALLEL_NODE_COUNT == "graph.parallel.node_count" + assert graph_tracing.GRAPH_PARALLEL_STRATEGY == "graph.parallel.strategy" + assert graph_tracing.GRAPH_PARALLEL_WAIT_N == "graph.parallel.wait_n" + assert graph_tracing.GRAPH_CALLBACK_TYPE == "graph.callback.type" + assert graph_tracing.GRAPH_INTERRUPT_MODE == "graph.interrupt.mode" + assert graph_tracing.GRAPH_SESSION_ID == "graph.session.id" + + +def test_multiple_node_executions(): + """Test recording multiple node executions with different types.""" + with mock.patch.object( + graph_tracing.node_execution_counter, "add" + ) as mock_counter: + # Record agent node + graph_tracing.record_node_execution( + node_name="agent_node", + node_type="agent", + agent_name="test_graph", + latency_ms=100.0, + success=True, + ) + + # Record function node + graph_tracing.record_node_execution( + node_name="function_node", + node_type="function", + agent_name="test_graph", + latency_ms=50.0, + success=True, + ) + + # Verify both were recorded + assert mock_counter.call_count == 2 + + # Verify first call + first_call_args = mock_counter.call_args_list[0] + assert ( + first_call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_NAME] + == "agent_node" + ) + assert ( + first_call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_TYPE] + == "agent" + ) + + # Verify second call + second_call_args = mock_counter.call_args_list[1] + assert ( + second_call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_NAME] + == "function_node" + ) + assert ( + second_call_args.kwargs["attributes"][graph_tracing.GRAPH_NODE_TYPE] + == "function" + ) + + +def test_edge_evaluations_with_different_priorities(): + """Test recording edge 
evaluations with different priorities.""" + with mock.patch.object( + graph_tracing.edge_evaluation_counter, "add" + ) as mock_counter: + # High priority edge + graph_tracing.record_edge_evaluation( + source_node="node_a", + target_node="node_b", + agent_name="test_graph", + condition_result=True, + latency_ms=5.0, + priority=10, + ) + + # Low priority edge + graph_tracing.record_edge_evaluation( + source_node="node_a", + target_node="node_c", + agent_name="test_graph", + condition_result=False, + latency_ms=3.0, + priority=1, + ) + + assert mock_counter.call_count == 2 + + # Verify priorities + assert ( + mock_counter.call_args_list[0].kwargs["attributes"][ + graph_tracing.GRAPH_EDGE_PRIORITY + ] + == 10 + ) + assert ( + mock_counter.call_args_list[1].kwargs["attributes"][ + graph_tracing.GRAPH_EDGE_PRIORITY + ] + == 1 + ) From a461d142f81e816e438e4840cc85d41b0de49e32 Mon Sep 17 00:00:00 2001 From: drahnreb <25883607+drahnreb@users.noreply.github.com> Date: Sat, 21 Feb 2026 19:29:01 +0100 Subject: [PATCH 2/2] feat: Add graph pattern nodes for dynamic dispatch and composition Add DynamicNode (runtime agent selection), NestedGraphNode (hierarchical workflow composition), and DynamicParallelGroup (variable-count concurrent execution). Extends CLI visualization with pattern-aware rendering (diamond, parallelogram, sub-cluster shapes). Includes pattern samples, node type reference, and design documentation. 
--- .../README.md | 37 + .../__init__.py | 0 .../graph_agent_pattern_dynamic_node/agent.py | 188 +++++ .../README.md | 42 + .../__init__.py | 0 .../graph_agent_pattern_nested_graph/agent.py | 201 +++++ .../README.md | 41 + .../__init__.py | 0 .../agent.py | 227 ++++++ .../graph_examples/run_all_examples.sh | 42 + src/google/adk/agents/graph/__init__.py | 9 + src/google/adk/agents/graph/graph_agent.py | 7 + src/google/adk/agents/graph/patterns.py | 327 ++++++++ src/google/adk/cli/agent_graph.py | 30 +- tests/unittests/agents/test_graph_agent.py | 96 ++- tests/unittests/agents/test_graph_patterns.py | 741 ++++++++++++++++++ tests/unittests/cli/test_agent_graph.py | 67 ++ 17 files changed, 2023 insertions(+), 32 deletions(-) create mode 100644 contributing/samples/graph_agent_pattern_dynamic_node/README.md create mode 100644 contributing/samples/graph_agent_pattern_dynamic_node/__init__.py create mode 100644 contributing/samples/graph_agent_pattern_dynamic_node/agent.py create mode 100644 contributing/samples/graph_agent_pattern_nested_graph/README.md create mode 100644 contributing/samples/graph_agent_pattern_nested_graph/__init__.py create mode 100644 contributing/samples/graph_agent_pattern_nested_graph/agent.py create mode 100644 contributing/samples/graph_agent_pattern_parallel_group/README.md create mode 100644 contributing/samples/graph_agent_pattern_parallel_group/__init__.py create mode 100644 contributing/samples/graph_agent_pattern_parallel_group/agent.py create mode 100755 contributing/samples/graph_examples/run_all_examples.sh create mode 100644 src/google/adk/agents/graph/patterns.py create mode 100644 tests/unittests/agents/test_graph_patterns.py diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/README.md b/contributing/samples/graph_agent_pattern_dynamic_node/README.md new file mode 100644 index 0000000000..5ff6a6a5fd --- /dev/null +++ b/contributing/samples/graph_agent_pattern_dynamic_node/README.md @@ -0,0 +1,37 @@ +# GraphAgent 
Pattern — DynamicNode (Mixture of Experts) + +This example implements **runtime agent selection** using `DynamicNode`. A classifier labels the +incoming task as SIMPLE or COMPLEX, then `DynamicNode` routes to a cheap fast model or a thorough +capable model accordingly — a sparse mixture-of-experts dispatch optimizing cost vs. quality. + +## When to Use This Pattern + +- Cost optimisation: route easy tasks to cheaper models, hard tasks to capable models +- Capability dispatch: pick a specialist agent based on detected task domain +- Fallback chains: try a fast agent first, escalate to a powerful agent on failure + +## How to Run + +```bash +adk run contributing/samples/graph_agent_pattern_dynamic_node +``` + +## Graph Structure + +``` +classify ──▶ respond (DynamicNode) + ├── simple_agent (when classify output contains "SIMPLE") + └── detailed_agent (otherwise) +``` + +## Key Code Walkthrough + +- **`DynamicNode(name="respond", agent_selector=select_responder)`** — the selector callable + receives `GraphState` and returns the `BaseAgent` to invoke +- **`select_responder(state)`** — reads `state.data["classify"]` and returns the matching agent +- **`fallback_agent`** parameter — used when the selector returns `None` +- **Transparent to the graph** — downstream edges see `respond`'s output regardless of which + agent was chosen +- **No graph-level changes needed** — swap agents by changing `select_responder`, not the graph + topology + diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/__init__.py b/contributing/samples/graph_agent_pattern_dynamic_node/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/agent.py b/contributing/samples/graph_agent_pattern_dynamic_node/agent.py new file mode 100644 index 0000000000..2a01cca3f2 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_dynamic_node/agent.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""DynamicNode Pattern: 
Runtime Agent Selection + +Motivation (Mixture-of-Experts) +-------------------------------- +Shazeer et al. (2017) "Outrageously Large Neural Networks: The Sparsely-Gated +Mixture-of-Experts Layer" showed that routing inputs to specialised experts +beats a single monolithic model while keeping per-token compute fixed. + +The same principle applies to agentic workflows: a *router* classifies the +complexity/type of each task, then a *gating function* selects the cheapest +adequate specialist—a fast flash-model for simple tasks, a slower pro-model +for hard tasks. + +Pattern: DynamicNode +-------------------- +DynamicNode is the first-class API for this pattern. The `agent_selector` +callable runs at runtime, reads the current GraphState, and returns the +appropriate BaseAgent. + +Compare to the function-node alternative +----------------------------------------- +Without DynamicNode you need a function node that manually dispatches: + + async def dispatch(state, ctx): + agent = complex_agent if "hard" in state.data else simple_agent + node_ctx = ctx.model_copy(update={...}) + output = "" + async for event in agent.run_async(node_ctx): + if event.content and event.content.parts: + output = event.content.parts[0].text or "" + return output + +DynamicNode gives you: + ✅ Metadata auto-tracking: which agent was selected (observability) + ✅ Built-in fallback_agent when selector returns None + ✅ Selection logic decoupled from execution boilerplate + +Architecture +------------ + classify ──► route (DynamicNode) ──► end + │ + ├─ selector returns simple_agent (flash, cheap) + └─ selector returns detailed_agent (pro, thorough) +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import DynamicNode +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + 
def select_responder(state: GraphState) -> LlmAgent:
  """Pick the responder agent that matches the classifier's verdict.

  The "classify" node wrote its raw output into state.data["classify"]
  via the default output_mapper (OVERWRITE reducer, key = node name).
  Any verdict containing the token SIMPLE routes to the cheap concise
  agent; every other verdict falls through to the thorough agent.
  """
  verdict = state.data.get("classify", "")
  return simple_agent if "SIMPLE" in verdict.upper() else detailed_agent
+ ( # COMPLEX → pro model + "Explain how transformer attention scales with sequence length " + "and what architectural changes help address this." + ), + ] + for q in questions: + print(f"\nQ: {q}") + answer = await run(q) + print(f"A: {answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_pattern_nested_graph/README.md b/contributing/samples/graph_agent_pattern_nested_graph/README.md new file mode 100644 index 0000000000..846807e2e3 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_nested_graph/README.md @@ -0,0 +1,42 @@ +# GraphAgent Pattern — NestedGraphNode (Hierarchical Composition) + +This example demonstrates **hierarchical workflow decomposition** using `NestedGraphNode`. A +coordinator planner produces a focused query, a three-step research sub-graph (search → extract → +summarise) handles it as a single reusable unit, and a synthesiser produces the final answer. +The sub-graph is entirely encapsulated and independently testable. 
+ +## When to Use This Pattern + +- Large workflows that benefit from breaking into independently developed sub-pipelines +- Reusable sub-workflows (the same research graph could be called from multiple parent graphs) +- Team boundaries: different teams own the outer orchestration and inner sub-workflows +- Recursive depth: sub-graphs can themselves contain `NestedGraphNode`s + +## How to Run + +```bash +adk run contributing/samples/graph_agent_pattern_nested_graph +``` + +## Graph Structure + +``` +Outer: plan ──▶ research (NestedGraphNode) ──▶ synthesise + +Inner (research sub-graph): + search ──▶ extract ──▶ summarise +``` + +## Key Code Walkthrough + +- **`NestedGraphNode(name="research", graph_agent=build_research_subgraph())`** — wraps an entire + `GraphAgent` as a single node in the parent graph +- **`inherit_session=True`** — the sub-graph shares the parent session's state, so outputs + written inside are visible to the parent's synthesiser +- **`build_research_subgraph()`** — factory function that constructs and returns the inner + `GraphAgent`; call it multiple times for independent instances +- **State bridging** — the sub-graph's final state is merged back; use `output_mapper` on the + `NestedGraphNode` to control which keys are exposed to the outer graph +- **Telemetry and checkpointing** — propagate automatically into the sub-graph when enabled on + the parent + diff --git a/contributing/samples/graph_agent_pattern_nested_graph/__init__.py b/contributing/samples/graph_agent_pattern_nested_graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_pattern_nested_graph/agent.py b/contributing/samples/graph_agent_pattern_nested_graph/agent.py new file mode 100644 index 0000000000..c82242e617 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_nested_graph/agent.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +"""NestedGraphNode Pattern: Hierarchical Workflow Composition + +Motivation 
(Hierarchical Planning) +------------------------------------ +Hierarchical planning research (Sutton et al. 1999, "Between MDPs and +semi-MDPs") and modern LLM orchestration papers like ORION / LLM Compiler +(Kim et al. 2023) show that breaking complex tasks into nested sub-plans +improves both reasoning quality and task-level reuse. + +In agentic terms: a *coordinator* decomposes the top-level goal into +sub-problems, each solved by a *sub-workflow* (a full GraphAgent) that can +be developed, tested, and reused independently. + +Pattern: NestedGraphNode +------------------------ +NestedGraphNode runs an entire GraphAgent as a single step inside a parent +graph. The parent graph sees only the sub-workflow's final output — internal +steps are transparent. + +Compare to the function-node alternative +----------------------------------------- +Without NestedGraphNode you must manually plumb the nested graph: + + async def run_research_step(state, ctx): + sub_ctx = ctx.model_copy(update={...}) + result = "" + async for event in research_graph._run_async_impl(sub_ctx): + if event.content and event.content.parts: + result = event.content.parts[0].text or "" + return result + +NestedGraphNode gives you: + ✅ Session-inheritance control (inherit_session=True/False) + ✅ Automatic metadata: sub-graph iteration count + execution path + ✅ No manual context plumbing + +Architecture (two-level hierarchy) +------------------------------------ +Outer graph: + plan ──► [research_step (NestedGraphNode)] ──► synthesize + +Inner (research_step) sub-graph: + search ──► extract ──► summarise +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import NestedGraphNode +from google.adk.agents.graph.callbacks import create_nested_observability_callback +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +_MODEL = 
def build_research_subgraph() -> GraphAgent:
  """Construct a fresh three-step research pipeline (search → extract → summarise).

  Each call builds an independent GraphAgent instance, so the same
  sub-workflow can be embedded in several parent graphs without sharing
  mutable graph structure.
  """
  pipeline = GraphAgent(
      name="research_pipeline",
      before_node_callback=create_nested_observability_callback(),
  )
  # Linear chain: register the nodes, then wire consecutive pairs.
  steps = [
      ("search", _searcher),
      ("extract", _extractor),
      ("summarise", _summariser),
  ]
  for node_name, node_agent in steps:
    pipeline.add_node(node_name, agent=node_agent)
  names = [node_name for node_name, _ in steps]
  for upstream, downstream in zip(names, names[1:]):
    pipeline.add_edge(upstream, downstream)
  pipeline.set_start(names[0])
  pipeline.set_end(names[-1])
  return pipeline
def build_graph() -> GraphAgent:
  """Assemble the outer plan → research → synthesise workflow.

  The middle step is a NestedGraphNode wrapping the reusable research
  pipeline; with inherit_session=True the sub-graph shares the parent
  session, so its outputs remain visible to the synthesiser.
  """
  outer = GraphAgent(
      name="hierarchical_research",
      description="Coordinator → research sub-workflow → synthesis",
      before_node_callback=create_nested_observability_callback(),
  )

  # The whole inner research pipeline executes as one step of this graph.
  research_step = NestedGraphNode(
      name="research",
      graph_agent=build_research_subgraph(),
      inherit_session=True,  # share parent session → state visible to outer
  )

  outer.add_node("plan", agent=_planner)
  outer.add_node(research_step)
  outer.add_node("synthesise", agent=_synthesiser)

  outer.add_edge("plan", "research")
  outer.add_edge("research", "synthesise")

  outer.set_start("plan")
  outer.set_end("synthesise")
  return outer
--------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main(): + q = "What are the key developments in quantum computing hardware?" + print(f"Question: {q}\n") + answer = await run(q) + print(f"Answer:\n{answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_pattern_parallel_group/README.md b/contributing/samples/graph_agent_pattern_parallel_group/README.md new file mode 100644 index 0000000000..56417049e3 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_parallel_group/README.md @@ -0,0 +1,41 @@ +# GraphAgent Pattern — DynamicParallelGroup (Tree of Thoughts) + +This example implements the **Tree of Thoughts** pattern using `DynamicParallelGroup`. N independent +reasoning paths are generated concurrently, then an evaluator scores them and a selector picks the +best. The number of parallel branches is determined at runtime from the graph state, and concurrency +is capped by `max_parallelism` to prevent resource exhaustion. 
+ +## When to Use This Pattern + +- Diverge-and-converge workflows: generate many candidates in parallel, then select the best +- Tree of Thoughts / beam search style reasoning +- Ensemble approaches: run N agents independently and aggregate their outputs +- Any case where the number of parallel branches is data-dependent (unknown at graph-build time) + +## How to Run + +```bash +adk run contributing/samples/graph_agent_pattern_parallel_group +``` + +## Graph Structure + +``` +config (function) ──▶ generate (DynamicParallelGroup) ──▶ evaluate ──▶ select + ├── thought_agent_0 + ├── thought_agent_1 + └── thought_agent_N (N from state.data["num_thoughts"]) +``` + +## Key Code Walkthrough + +- **`DynamicParallelGroup(name="generate", agent_generator=generate_thought_agents)`** — the + generator callable receives `GraphState` at runtime and returns a list of `BaseAgent` instances +- **`max_parallelism=5`** — caps concurrent agent executions via an `asyncio.Semaphore`; prevents + overloading the model API with too many simultaneous requests +- **`aggregator=aggregate_thoughts`** — combines the N results into a single string (using + `=== Thought N ===` separators) before passing to the evaluator +- **`config` function node** — parses `[num_thoughts=N]` from the user message and writes + `state.data["num_thoughts"]`; shows how function nodes can pre-process inputs +- **`select_responder`** — final selector reads evaluator scores and returns the winning thought + diff --git a/contributing/samples/graph_agent_pattern_parallel_group/__init__.py b/contributing/samples/graph_agent_pattern_parallel_group/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_pattern_parallel_group/agent.py b/contributing/samples/graph_agent_pattern_parallel_group/agent.py new file mode 100644 index 0000000000..891a27d400 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_parallel_group/agent.py @@ -0,0 +1,227 @@ +#!/usr/bin/env 
python3 +"""DynamicParallelGroup Pattern: Tree of Thoughts / Self-Consistency + +Motivation (Tree of Thoughts & Self-Consistency) +-------------------------------------------------- +Wang et al. (2022) "Self-Consistency Improves Chain of Thought Reasoning": +sampling multiple independent reasoning paths and taking the majority answer +outperforms single chain-of-thought on arithmetic and commonsense benchmarks. + +Yao et al. (2023) "Tree of Thoughts: Deliberate Problem Solving with LLMs": +exploring multiple partial-solution branches in parallel, then scoring and +selecting the best one, significantly improves performance on complex tasks. + +Both techniques require spawning N concurrent agents whose count is determined +at runtime — exactly what DynamicParallelGroup provides. + +Pattern: DynamicParallelGroup +------------------------------ +DynamicParallelGroup generates the agent list at runtime from a callable, runs +them concurrently (with optional max_parallelism throttle), then feeds all +results to an aggregator function. + +Compare to the static-parallelism alternative (ParallelNodeGroup) +------------------------------------------------------------------ +ParallelNodeGroup is compiled at graph-construction time: + + group = ParallelNodeGroup( + name="parallel", + nodes=[node1, node2, node3], # fixed list — must know N upfront + join_strategy=JoinStrategy.WAIT_ALL, + ) + +DynamicParallelGroup gives you: + ✅ Runtime-determined N (read from state.data — e.g. user param) + ✅ State-driven generation (e.g. 
one agent per item in a list) + ✅ Custom aggregation logic (majority vote, best-score selection, …) + ✅ Concurrency cap via max_parallelism (back-pressure / rate limiting) + +Architecture (Tree of Thoughts) +--------------------------------- + generate_thoughts (DynamicParallelGroup) + │ N independent "thought" agents run concurrently + ▼ + evaluate ← LlmAgent scores all thoughts + │ + ▼ + select ← LlmAgent picks the winning thought +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import DynamicParallelGroup +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph.callbacks import create_nested_observability_callback +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + + +# --------------------------------------------------------------------------- +# Thought generator — each instance produces one independent solution path +# --------------------------------------------------------------------------- +def make_thought_agent(idx: int) -> LlmAgent: + return LlmAgent( + name=f"thought_{idx}", + model=_MODEL, + instruction=f""" +You are creative problem-solver #{idx + 1}. Given a problem, propose one +original, concrete solution approach. Be specific (≤3 sentences). +Do NOT repeat approaches from other agents; bring a fresh angle. 
def parse_config(state: GraphState, ctx) -> str:
  """Strip an optional "[num_thoughts=N]" header off the incoming message.

  When the header is present, N is stored (as a string) in
  state.data["num_thoughts"] — generate_thought_agents converts it — and
  the remaining problem text is returned. Without a header the raw input
  is passed through untouched and the downstream default of 3 applies.
  """
  import re

  raw = state.data.get("input", "")
  header = re.match(r"\[num_thoughts=(\d+)\]\s*(.*)", raw, re.DOTALL)
  if header is None:
    return raw
  state.data["num_thoughts"] = header.group(1)
  return header.group(2).strip()
+ full_input = f"[num_thoughts={num_thoughts}] {problem}" + + graph = build_graph() + svc = InMemorySessionService() + runner = Runner( + app_name="parallel_group_example", agent=graph, session_service=svc + ) + await svc.create_session( + app_name="parallel_group_example", user_id="user", session_id="s1" + ) + + final = "" + async for event in runner.run_async( + user_id="user", + session_id="s1", + new_message=types.Content( + role="user", parts=[types.Part(text=full_input)] + ), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text: + continue + if event.author == "observability": + # Show node execution trace (from create_nested_observability_callback) + print(f" → {text}") + elif not text.startswith("[GraphMetadata]"): + final = text + return final + + +# --------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main(): + problem = ( + "How can a small startup with limited budget compete with large " + "incumbents in the enterprise software market?" + ) + print(f"Problem: {problem}\n") + print("Generating 4 parallel thought paths...\n") + result = await run(problem, num_thoughts=4) + print(f"Selected best approach:\n{result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/run_all_examples.sh b/contributing/samples/graph_examples/run_all_examples.sh new file mode 100755 index 0000000000..65272f5de9 --- /dev/null +++ b/contributing/samples/graph_examples/run_all_examples.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Run all GraphAgent examples + +set -e + +cd "$(dirname "$0")/../../.." 
+source venv/bin/activate + +echo "========================================" +echo "Running All GraphAgent Examples" +echo "========================================" +echo "" + +examples=( + "01_basic" + "02_conditional_routing" + "03_cyclic_execution" + "15_enhanced_routing" + "04_checkpointing" + "05_interrupts_basic" + "06_interrupts_reasoning" + "07_callbacks" + "08_rewind" + "09_parallel_wait_all" + "10_parallel_wait_any" + "11_parallel_wait_n" + "12_parallel_checkpointing" + "13_parallel_interrupts" + "14_parallel_rewind" +) + +for example in "${examples[@]}"; do + echo "----------------------------------------" + echo "Running: $example" + echo "----------------------------------------" + python -m "contributing.samples.graph_examples.${example}.agent" 2>&1 | grep -v "UserWarning" || true + echo "" +done + +echo "========================================" +echo "✅ All Examples Complete!" +echo "========================================" diff --git a/src/google/adk/agents/graph/__init__.py b/src/google/adk/agents/graph/__init__.py index 5ccc0f0c4d..8d9c3f79fc 100644 --- a/src/google/adk/agents/graph/__init__.py +++ b/src/google/adk/agents/graph/__init__.py @@ -14,6 +14,9 @@ - EdgeCallbackContext: Context for edge condition callbacks - NodeCallback: Type for node lifecycle callbacks - EdgeCallback: Type for edge condition callbacks +- DynamicNode: Runtime agent selection based on state +- NestedGraphNode: Hierarchical workflow composition (graph within graph) +- DynamicParallelGroup: Dynamic concurrent execution with variable agent count """ from __future__ import annotations @@ -43,6 +46,9 @@ from .graph_state import GraphState from .graph_state import PydanticJSONEncoder from .graph_state import StateReducer +from .patterns import DynamicNode +from .patterns import DynamicParallelGroup +from .patterns import NestedGraphNode # Sentinel constants for graph boundaries START = "__start__" @@ -74,6 +80,9 @@ "export_graph_with_execution", 
"export_execution_timeline", "rewind_to_node", + "DynamicNode", + "NestedGraphNode", + "DynamicParallelGroup", "START", "END", ] diff --git a/src/google/adk/agents/graph/graph_agent.py b/src/google/adk/agents/graph/graph_agent.py index 3ebde3012f..b47e7729c4 100644 --- a/src/google/adk/agents/graph/graph_agent.py +++ b/src/google/adk/agents/graph/graph_agent.py @@ -405,6 +405,13 @@ def _get_node_agent(self, node: "GraphNode") -> Optional[BaseAgent]: """ if node.agent is not None: return node.agent + from .patterns import DynamicNode + from .patterns import NestedGraphNode + + if isinstance(node, NestedGraphNode): + return node.graph_agent + if isinstance(node, DynamicNode): + return node.fallback_agent return None def _register_node_agents(self, node: "GraphNode") -> None: diff --git a/src/google/adk/agents/graph/patterns.py b/src/google/adk/agents/graph/patterns.py new file mode 100644 index 0000000000..7872f6226f --- /dev/null +++ b/src/google/adk/agents/graph/patterns.py @@ -0,0 +1,327 @@ +"""Advanced graph patterns for common agentic workflows. + +This module provides first-class APIs for advanced patterns: +- DynamicNode: Runtime agent selection based on state +- NestedGraphNode: Hierarchical workflow composition (graph within graph) +- DynamicParallelGroup: Dynamic concurrent execution with variable agent count +""" + +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +import uuid + +from google.genai import types + +from google import genai + +from ...utils.feature_decorator import experimental +from ..base_agent import BaseAgent +from ..invocation_context import InvocationContext +from .graph_node import GraphNode +from .graph_state import GraphState + + +@experimental +class DynamicNode(GraphNode): + """Node with runtime agent selection based on state. 
+ + Enables dynamic dispatch pattern where the agent to execute is selected + at runtime based on the current graph state. This is useful for: + - Task routing (route to specialized agents based on task type) + - Adaptive execution (select agent based on difficulty/complexity) + - Multi-agent orchestration (dispatch to different agents per request) + + Example: + ```python + def select_agent(state: GraphState) -> BaseAgent: + task_type = state.data.get("task_type", "simple") + return complex_agent if task_type == "complex" else simple_agent + + node = DynamicNode( + name="dispatcher", + agent_selector=select_agent, + fallback_agent=default_agent + ) + ``` + """ + + def __init__( + self, + name: str, + agent_selector: Callable[[GraphState], Optional[BaseAgent]], + fallback_agent: Optional[BaseAgent] = None, + **kwargs: Any, + ): + """Initialize dynamic node. + + Args: + name: Node name + agent_selector: Function that selects agent based on state + fallback_agent: Agent to use if selector returns None + **kwargs: Additional arguments passed to GraphNode (input_mapper, output_mapper, etc.) + """ + self.agent_selector = agent_selector + self.fallback_agent = fallback_agent + super().__init__( + name=name, agent=None, function=self._execute_dynamic, **kwargs + ) + + async def _execute_dynamic( + self, state: GraphState, ctx: InvocationContext + ) -> str: + """Execute selected agent based on state. 
+ + Args: + state: Current graph state + ctx: Invocation context + + Returns: + Agent output as string + + Raises: + ValueError: If no agent selected and no fallback + """ + selected = self.agent_selector(state) or self.fallback_agent + if not selected: + raise ValueError( + f"No agent selected for {self.name} and no fallback provided" + ) + + # Observability: track selected agent for debugging + state.data[f"_debug_{self.name}_selected_agent"] = selected.name + + node_input = self.input_mapper(state) + node_ctx = ctx.model_copy( + update={ + "user_content": types.Content( + role="user", parts=[types.Part(text=node_input)] + ) + } + ) + + output = "" + async for event in selected.run_async(node_ctx): + if event.content and event.content.parts: + text = "".join(p.text for p in event.content.parts if p.text) + if text: + output += text + return output + + +@experimental +class NestedGraphNode(GraphNode): + """Node that executes a GraphAgent as sub-workflow. + + Enables hierarchical workflow composition where a GraphAgent is executed + as a node within a parent graph. This is useful for: + - Multi-step validation (validation graph within main workflow) + - Conditional sub-workflows (execute different graphs based on conditions) + - Workflow reuse (share common sub-workflows across graphs) + + The nested graph automatically inherits telemetry config from parent via + context propagation, ensuring consistent observability. 
+ + Example: + ```python + # Create validation sub-workflow + validation_graph = GraphAgent(name="validation") + validation_graph.add_node(GraphNode(name="check1", agent=checker1)) + validation_graph.add_node(GraphNode(name="check2", agent=checker2)) + + # Embed in parent workflow + nested = NestedGraphNode( + name="validate", + graph_agent=validation_graph, + inherit_session=True + ) + parent_graph.add_node(nested) + ``` + """ + + def __init__( + self, + name: str, + graph_agent: "BaseAgent", # Avoid circular import, will be GraphAgent at runtime + inherit_session: bool = True, + **kwargs: Any, + ): + """Initialize nested graph node. + + Args: + name: Node name + graph_agent: GraphAgent to execute as nested workflow + inherit_session: If True, nested graph uses parent session (shares state/history) + If False, creates isolated session for nested execution + **kwargs: Additional arguments passed to GraphNode + """ + self.graph_agent = graph_agent + self.inherit_session = inherit_session + super().__init__( + name=name, agent=None, function=self._execute_nested, **kwargs + ) + + async def _execute_nested( + self, state: GraphState, ctx: InvocationContext + ) -> str: + """Execute nested graph workflow. 
+ + Args: + state: Current graph state + ctx: Invocation context + + Returns: + Final output from nested graph + """ + if self.inherit_session: + # Use parent session - nested graph sees parent state + nested_ctx = ctx.model_copy( + update={ + "user_content": types.Content( + role="user", parts=[types.Part(text=self.input_mapper(state))] + ) + } + ) + else: + # Create isolated session - nested graph has clean state + # Note: This requires access to session_service from parent context + nested_session_id = f"nested_{uuid.uuid4().hex[:8]}" + await ctx.session_service.create_session( + app_name=self.graph_agent.name, + user_id=ctx.session.user_id, + session_id=nested_session_id, + ) + nested_session = await ctx.session_service.get_session( + app_name=self.graph_agent.name, + user_id=ctx.session.user_id, + session_id=nested_session_id, + ) + nested_ctx = ctx.model_copy( + update={ + "session": nested_session, + "user_content": types.Content( + role="user", + parts=[types.Part(text=self.input_mapper(state))], + ), + } + ) + + # Execute nested graph + # Filter empty-text and [GraphMetadata] sentinel events; keep the last + # meaningful text output (agent nodes already emitted their content, but + # the inner GraphAgent's final response event has empty text for agent + # end-nodes, which would otherwise overwrite the real result). + final_output = "" + async for event in self.graph_agent.run_async(nested_ctx): + if event.content and event.content.parts: + text = "".join(p.text for p in event.content.parts if p.text) + if text and not text.startswith("[GraphMetadata]"): + final_output += text + + # Observability: track nested graph output (truncated) + state.data[f"_debug_{self.name}_output"] = final_output[:500] + + return final_output + + +@experimental +class DynamicParallelGroup(GraphNode): + """Node that executes multiple agents in parallel with dynamic concurrency. 
+ + Enables dynamic parallel execution where the number of parallel agents + is determined at runtime based on state. This is useful for: + - Tree of Thoughts (generate N thoughts in parallel) + - Parallel search (search multiple sources concurrently) + - Batch processing (process variable-size batches) + + Example: + ```python + def gen_agents(state: GraphState) -> List[BaseAgent]: + num_thoughts = state.data.get("num_thoughts", 3) + return [thought_generator for _ in range(num_thoughts)] + + def aggregate(results: List[str], state: GraphState) -> str: + return "\\n---\\n".join(f"Thought {i}: {r}" for i, r in enumerate(results)) + + node = DynamicParallelGroup( + name="generate_thoughts", + agent_generator=gen_agents, + aggregator=aggregate, + max_parallelism=5 # Limit concurrent execution + ) + ``` + """ + + def __init__( + self, + name: str, + agent_generator: Callable[[GraphState], List[BaseAgent]], + aggregator: Callable[[List[str], GraphState], str], + max_parallelism: int = 5, + **kwargs: Any, + ): + """Initialize dynamic parallel group. + + Args: + name: Node name + agent_generator: Function that generates list of agents based on state + aggregator: Function that aggregates all agent outputs into single result + max_parallelism: Maximum number of agents to execute concurrently (default: 5) + **kwargs: Additional arguments passed to GraphNode + """ + self.agent_generator = agent_generator + self.aggregator = aggregator + self.max_parallelism = max_parallelism + super().__init__( + name=name, agent=None, function=self._execute_parallel, **kwargs + ) + + async def _execute_parallel( + self, state: GraphState, ctx: InvocationContext + ) -> str: + """Execute agents in parallel with concurrency limit. 
+ + Args: + state: Current graph state + ctx: Invocation context + + Returns: + Aggregated output from all agents + """ + agents = self.agent_generator(state) + + # Observability: track parallel execution count + state.data[f"_debug_{self.name}_parallel_count"] = len(agents) + + if not agents: + return self.aggregator([], state) + + semaphore = asyncio.Semaphore(self.max_parallelism) + + async def run_agent(agent: BaseAgent) -> str: + """Execute single agent under the shared concurrency semaphore.""" + async with semaphore: + node_input = self.input_mapper(state) + node_ctx = ctx.model_copy( + update={ + "user_content": types.Content( + role="user", parts=[types.Part(text=node_input)] + ) + } + ) + + output = "" + async for event in agent.run_async(node_ctx): + if event.content and event.content.parts: + text = "".join(p.text for p in event.content.parts if p.text) + if text: + output += text + return output + + results = await asyncio.gather(*[run_agent(a) for a in agents]) + + return self.aggregator(results, state) diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py index fb313f51a4..7f93c9dc94 100644 --- a/src/google/adk/cli/agent_graph.py +++ b/src/google/adk/cli/agent_graph.py @@ -21,6 +21,9 @@ from ..agents.base_agent import BaseAgent from ..agents.graph.graph_agent import GraphAgent +from ..agents.graph.patterns import DynamicNode +from ..agents.graph.patterns import DynamicParallelGroup +from ..agents.graph.patterns import NestedGraphNode from ..agents.llm_agent import LlmAgent from ..agents.loop_agent import LoopAgent from ..agents.parallel_agent import ParallelAgent @@ -43,8 +46,11 @@ def _graph_node_id(node) -> str: """Get the graphviz node ID for a GraphNode. For agent nodes, uses agent.name (matches build_graph's node naming). - For function nodes, uses node.name. + For NestedGraphNode, uses graph_agent's cluster name. + For function/pattern nodes, uses node.name. 
""" + if isinstance(node, NestedGraphNode): + return node.graph_agent.name + ' (Graph Agent)' if node.agent is not None: return node.agent.name return node.name @@ -207,7 +213,27 @@ async def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str): elif isinstance(agent, GraphAgent): # Render all graph nodes inside the cluster for node_name, node in agent.nodes.items(): - if node.agent is not None: + if isinstance(node, NestedGraphNode): + await build_graph(child, node.graph_agent, highlight_pairs) + elif isinstance(node, DynamicNode): + child.node( + node_name, + node_name + ' (dynamic)', + shape='diamond', + style='rounded', + color=light_gray, + fontcolor=light_gray, + ) + elif isinstance(node, DynamicParallelGroup): + child.node( + node_name, + node_name + ' (parallel)', + shape='parallelogram', + style='rounded', + color=light_gray, + fontcolor=light_gray, + ) + elif node.agent is not None: # Agent node — render inside cluster await build_graph(child, node.agent, highlight_pairs) else: diff --git a/tests/unittests/agents/test_graph_agent.py b/tests/unittests/agents/test_graph_agent.py index 1ccf1070c0..61d096b3b0 100644 --- a/tests/unittests/agents/test_graph_agent.py +++ b/tests/unittests/agents/test_graph_agent.py @@ -33,6 +33,8 @@ from google.adk.agents.graph import rewind_to_node from google.adk.agents.graph import StateReducer from google.adk.agents.graph.graph_agent_state import GraphAgentState +from google.adk.agents.graph.patterns import DynamicNode +from google.adk.agents.graph.patterns import NestedGraphNode from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.run_config import RunConfig from google.adk.apps import ResumabilityConfig @@ -2222,9 +2224,7 @@ class OutputSchema(BaseModel): # _validate_node_configuration should log warning import logging - with patch( - "google.adk.agents.graph.graph_agent.logger" - ) as mock_logger: + with patch("google.adk.agents.graph.graph_agent.logger") as mock_logger: 
graph.add_node(node) mock_logger.warning.assert_called_once() @@ -2313,7 +2313,9 @@ def test_string_with_weight_only_creates_edge_condition(self): def test_invalid_target_type_raises(self): """Non-str non-EdgeCondition target raises TypeError.""" graph = self._make_graph() - with pytest.raises(TypeError, match="target_node must be str or EdgeCondition"): + with pytest.raises( + TypeError, match="target_node must be str or EdgeCondition" + ): graph.add_edge("src", 42) @@ -2335,9 +2337,7 @@ async def test_before_node_callback_returns_event(self): async def before_cb(ctx: NodeCallbackContext): return Event( author="before_cb", - content=types.Content( - parts=[types.Part(text="before_event")] - ), + content=types.Content(parts=[types.Part(text="before_event")]), ) graph = GraphAgent(name="g", before_node_callback=before_cb) @@ -2352,9 +2352,7 @@ async def before_cb(ctx: NodeCallbackContext): async for event in runner.run_async( user_id="u", session_id="s", - new_message=types.Content( - role="user", parts=[types.Part(text="go")] - ), + new_message=types.Content(role="user", parts=[types.Part(text="go")]), ): yielded_events.append(event) @@ -2362,7 +2360,9 @@ async def before_cb(ctx: NodeCallbackContext): before_texts = [ e.content.parts[0].text for e in yielded_events - if e.content and e.content.parts and e.content.parts[0].text == "before_event" + if e.content + and e.content.parts + and e.content.parts[0].text == "before_event" ] assert len(before_texts) == 1 @@ -2373,9 +2373,7 @@ async def test_after_node_callback_returns_event(self): async def after_cb(ctx: NodeCallbackContext): return Event( author="after_cb", - content=types.Content( - parts=[types.Part(text="after_event")] - ), + content=types.Content(parts=[types.Part(text="after_event")]), ) graph = GraphAgent(name="g", after_node_callback=after_cb) @@ -2391,16 +2389,16 @@ async def after_cb(ctx: NodeCallbackContext): async for event in runner.run_async( user_id="u", session_id="s", - 
new_message=types.Content( - role="user", parts=[types.Part(text="go")] - ), + new_message=types.Content(role="user", parts=[types.Part(text="go")]), ): yielded_events.append(event) after_texts = [ e.content.parts[0].text for e in yielded_events - if e.content and e.content.parts and e.content.parts[0].text == "after_event" + if e.content + and e.content.parts + and e.content.parts[0].text == "after_event" ] assert len(after_texts) == 1 @@ -2459,9 +2457,7 @@ def _cov_make_ctx( session_service=svc, invocation_id="inv-1", agent=agent, - user_content=types.Content( - role="user", parts=[types.Part(text="test")] - ), + user_content=types.Content(role="user", parts=[types.Part(text="test")]), ) if resumable: ctx.resumability_config = ResumabilityConfig(is_resumable=True) @@ -2488,6 +2484,32 @@ def _cov_linear_graph(name, agents, names): class TestGetNodeAgent: + + def test_nested_graph_node_returns_graph_agent(self): + inner = GraphAgent(name="inner") + inner_agent = SimpleTestAgent("inner_step", ["inner_ok"]) + inner.add_node(GraphNode(name="s", agent=inner_agent)) + inner.set_start("s") + inner.set_end("s") + nested = NestedGraphNode(name="nest", graph_agent=inner) + outer = GraphAgent(name="outer") + assert outer._get_node_agent(nested) is inner + + def test_dynamic_node_returns_fallback_agent(self): + fallback = SimpleTestAgent("fallback", ["fb_ok"]) + dyn = DynamicNode( + name="dyn", agent_selector=lambda s: None, fallback_agent=fallback + ) + outer = GraphAgent(name="outer") + assert outer._get_node_agent(dyn) is fallback + + def test_dynamic_node_no_fallback_returns_none(self): + dyn = DynamicNode( + name="dyn", agent_selector=lambda s: None, fallback_agent=None + ) + outer = GraphAgent(name="outer") + assert outer._get_node_agent(dyn) is None + def test_regular_node_returns_agent(self): agent = SimpleTestAgent("a", ["ok"]) node = GraphNode(name="n", agent=agent) @@ -2497,14 +2519,19 @@ def test_regular_node_returns_agent(self): @pytest.mark.asyncio class 
TestDomainDataFromSession: + async def test_session_state_populates_domain_data(self): agent_a = SimpleTestAgent("a", ["done"]) graph = _cov_linear_graph("g", [agent_a], ["nA"]) - ctx = _cov_make_ctx(graph, session_state={"my_key": "my_value", "another": 42}) + ctx = _cov_make_ctx( + graph, session_state={"my_key": "my_value", "another": 42} + ) events = await _cov_collect(graph, ctx) final = [ - e for e in events - if e.actions and e.actions.state_delta + e + for e in events + if e.actions + and e.actions.state_delta and "graph_data" in (e.actions.state_delta or {}) ] assert len(final) == 1 @@ -2526,8 +2553,10 @@ async def test_internal_keys_excluded_from_domain_data(self): ) events = await _cov_collect(graph, ctx) final = [ - e for e in events - if e.actions and e.actions.state_delta + e + for e in events + if e.actions + and e.actions.state_delta and "graph_data" in (e.actions.state_delta or {}) ] assert len(final) == 1 @@ -2541,6 +2570,7 @@ async def test_internal_keys_excluded_from_domain_data(self): @pytest.mark.asyncio @pytest.mark.asyncio class TestBeforeNodeCallbackException: + async def test_before_callback_failure_continues_execution(self): agent_a = SimpleTestAgent("a", ["a_out"]) graph = _cov_linear_graph("g", [agent_a], ["nA"]) @@ -2558,6 +2588,7 @@ async def failing_callback(ctx): @pytest.mark.asyncio @pytest.mark.asyncio class TestOutputMapperNoneFallback: + async def test_output_mapper_returning_none_uses_prev_state(self): agent_a = SimpleTestAgent("a", ["a_out"]) graph = GraphAgent(name="g") @@ -2576,16 +2607,21 @@ async def test_output_mapper_returning_none_uses_prev_state(self): events = await _cov_collect(graph, ctx) assert agent_a.call_count == 1 final = [ - e for e in events - if e.actions and e.actions.state_delta + e + for e in events + if e.actions + and e.actions.state_delta and "graph_data" in (e.actions.state_delta or {}) ] assert len(final) == 1 - assert final[0].actions.state_delta["graph_data"].get("custom_key") == "a_out" + assert ( 
+ final[0].actions.state_delta["graph_data"].get("custom_key") == "a_out" + ) @pytest.mark.asyncio class TestAfterNodeCallbackException: + async def test_after_callback_failure_continues_execution(self): agent_a = SimpleTestAgent("a", ["a_out"]) agent_b = SimpleTestAgent("b", ["b_out"]) @@ -2603,6 +2639,7 @@ async def failing_callback(ctx): @pytest.mark.asyncio class TestNodeExecutionException: + async def test_node_exception_raises_and_records_metrics(self): failing = _CovFailingAgent("fail", "node_error") graph = _cov_linear_graph("g", [failing], ["nA"]) @@ -2625,4 +2662,3 @@ async def test_condition_eval_failure_logs_with_exc_info(self): assert result is False mock_logger.error.assert_called_once() assert mock_logger.error.call_args[1].get("exc_info") is True - diff --git a/tests/unittests/agents/test_graph_patterns.py b/tests/unittests/agents/test_graph_patterns.py new file mode 100644 index 0000000000..0b9129ce15 --- /dev/null +++ b/tests/unittests/agents/test_graph_patterns.py @@ -0,0 +1,741 @@ +"""Tests for GraphAgent first-class pattern APIs. + +Tests: +- DynamicNode: Runtime agent selection based on state +- NestedGraphNode: Hierarchical workflow composition (graph within graph) +- DynamicParallelGroup: Dynamic concurrent execution with variable agent count + +Uses Runner/GraphAgent for end-to-end integration testing per ADK conventions. +Manual InvocationContext construction is wrong because it requires internal +fields (invocation_id, agent) that Runner fills automatically. 
import asyncio
from typing import List
import uuid

from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.graph import DynamicNode
from google.adk.agents.graph import DynamicParallelGroup
from google.adk.agents.graph import GraphAgent
from google.adk.agents.graph import GraphNode
from google.adk.agents.graph import GraphState
from google.adk.agents.graph import NestedGraphNode
from google.adk.events.event import Event
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
import pytest

# ============================================================================
# Test Helpers
# ============================================================================


class SimpleTestAgent(BaseAgent):
  """Minimal agent that replays a fixed list of responses in order."""

  model_config = {"extra": "allow", "arbitrary_types_allowed": True}

  def __init__(self, name: str, responses: list, delay: float = 0.0):
    super().__init__(name=name)
    # object.__setattr__ bypasses pydantic's field validation on BaseAgent.
    object.__setattr__(self, "_responses", responses)
    object.__setattr__(self, "_call_count", 0)
    object.__setattr__(self, "_delay", delay)

  async def _run_async_impl(self, ctx):
    await asyncio.sleep(object.__getattribute__(self, "_delay"))
    scripted = object.__getattribute__(self, "_responses")
    seen = object.__getattribute__(self, "_call_count")
    # Clamp to the last response when called more times than scripted.
    reply = scripted[min(seen, len(scripted) - 1)]
    object.__setattr__(self, "_call_count", seen + 1)
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=reply)]),
    )

  @property
  def call_count(self):
    return object.__getattribute__(self, "_call_count")


def make_runner(graph):
  """Return a (Runner, InMemorySessionService) pair for the given graph."""
  service = InMemorySessionService()
  return Runner(app_name="test", agent=graph, session_service=service), service


async def run_graph(runner, svc, message):
  """Create a fresh session, run the graph, return the final non-metadata text."""
  session_id = f"s_{uuid.uuid4().hex[:8]}"
  await svc.create_session(app_name="test", user_id="u", session_id=session_id)
  prompt = types.Content(role="user", parts=[types.Part(text=message)])
  latest = ""
  async for event in runner.run_async(
      user_id="u", session_id=session_id, new_message=prompt
  ):
    if not (event.content and event.content.parts):
      continue
    text = event.content.parts[0].text or ""
    if text and not text.startswith("[GraphMetadata]"):
      latest = text
  return latest


async def run_graph_with_state(runner, svc, message):
  """Run the graph and return (final_text, graph_data_dict)."""
  session_id = f"s_{uuid.uuid4().hex[:8]}"
  await svc.create_session(app_name="test", user_id="u", session_id=session_id)
  prompt = types.Content(role="user", parts=[types.Part(text=message)])
  latest = ""
  async for event in runner.run_async(
      user_id="u", session_id=session_id, new_message=prompt
  ):
    if not (event.content and event.content.parts):
      continue
    text = event.content.parts[0].text or ""
    if text and not text.startswith("[GraphMetadata]"):
      latest = text
  session = await svc.get_session(
      app_name="test", user_id="u", session_id=session_id
  )
  graph_data = session.state.get("graph_data", {}) if session else {}
  return latest, graph_data


# ============================================================================
# DynamicNode
# ============================================================================


class TestDynamicNode:
  """DynamicNode selects which agent to run based on GraphState at runtime."""

  @pytest.mark.asyncio
  async def test_selects_agent_based_on_input(self):
    """Selector reads state.data and returns appropriate agent."""
    simple = SimpleTestAgent("simple", ["SIMPLE"])
    complex_ = SimpleTestAgent("complex", ["COMPLEX"])

    def selector(state):
      if "complex" in state.data.get("input", ""):
        return complex_
      return simple

    graph = GraphAgent(name="g")
    graph.add_node(DynamicNode(name="d", agent_selector=selector))
    graph.set_start("d")
    graph.set_end("d")

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "simple task") == "SIMPLE"
    assert simple.call_count == 1
    assert complex_.call_count == 0

    assert await run_graph(runner, svc, "complex task") == "COMPLEX"
    assert complex_.call_count == 1

  @pytest.mark.asyncio
  async def test_fallback_when_selector_returns_none(self):
    """When selector returns None, fallback_agent is used."""
    fallback = SimpleTestAgent("fallback", ["FALLBACK"])

    node = DynamicNode(
        name="d",
        agent_selector=lambda _: None,
        fallback_agent=fallback,
    )
    graph = GraphAgent(name="g")
    graph.add_node(node)
    graph.set_start("d")
    graph.set_end("d")

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "any") == "FALLBACK"
    assert fallback.call_count == 1

  @pytest.mark.asyncio
  async def test_raises_when_no_agent_and_no_fallback(self):
    """ValueError raised when selector returns None and no fallback set."""
    graph = GraphAgent(name="g")
    graph.add_node(DynamicNode(name="d", agent_selector=lambda _: None))
    graph.set_start("d")
    graph.set_end("d")

    runner, svc = make_runner(graph)
    with pytest.raises(ValueError, match="No agent selected"):
      await run_graph(runner, svc, "any")

  @pytest.mark.asyncio
  async def test_different_agents_on_sequential_runs(self):
    """Selector can pick different agents on each graph invocation."""
    a = SimpleTestAgent("a", ["A"])
    b = SimpleTestAgent("b", ["B"])
    counter = {"count": 0}

    def selector(state):
      counter["count"] += 1
      # Odd invocations go to a, even ones to b.
      return a if counter["count"] % 2 == 1 else b

    graph = GraphAgent(name="g")
    graph.add_node(DynamicNode(name="d", agent_selector=selector))
    graph.set_start("d")
    graph.set_end("d")

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "run1") == "A"
    assert await run_graph(runner, svc, "run2") == "B"
    assert a.call_count == 1
    assert b.call_count == 1


# ============================================================================
# NestedGraphNode
# ============================================================================


def _single_node_graph(node, graph_name="g"):
  """Build a GraphAgent whose entire topology is the given single node."""
  graph = GraphAgent(name=graph_name)
  graph.add_node(node)
  graph.set_start(node.name)
  graph.set_end(node.name)
  return graph


class TestNestedGraphNode:
  """NestedGraphNode executes a full GraphAgent as a single node step."""

  @pytest.mark.asyncio
  async def test_nested_graph_runs_all_steps(self):
    """All nodes in the nested graph execute; all outputs are accumulated."""
    a1 = SimpleTestAgent("n1", ["STEP1"])
    a2 = SimpleTestAgent("n2", ["STEP2"])

    inner = GraphAgent(name="inner")
    inner.add_node(GraphNode(name="s1", agent=a1))
    inner.add_node(GraphNode(name="s2", agent=a2))
    inner.add_edge("s1", "s2")
    inner.set_start("s1")
    inner.set_end("s2")

    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner), "outer"
    )

    runner, svc = make_runner(outer)
    assert await run_graph(runner, svc, "go") == "STEP1STEP2"
    assert a1.call_count == 1
    assert a2.call_count == 1

  @pytest.mark.asyncio
  async def test_inherit_session_true(self):
    """With inherit_session=True the nested graph shares the parent session."""
    inner_agent = SimpleTestAgent("inner", ["INNER"])
    inner = _single_node_graph(
        GraphNode(name="p", agent=inner_agent), "inner_graph"
    )

    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner, inherit_session=True),
        "outer",
    )

    runner, svc = make_runner(outer)
    assert await run_graph(runner, svc, "go") == "INNER"

  @pytest.mark.asyncio
  async def test_inherit_session_false_isolated(self):
    """With inherit_session=False the nested graph gets its own session."""
    inner_agent = SimpleTestAgent("inner", ["ISOLATED"])
    inner = _single_node_graph(
        GraphNode(name="p", agent=inner_agent), "inner_graph"
    )

    outer = _single_node_graph(
        NestedGraphNode(
            name="nested", graph_agent=inner, inherit_session=False
        ),
        "outer",
    )

    runner, svc = make_runner(outer)
    assert await run_graph(runner, svc, "go") == "ISOLATED"

  @pytest.mark.asyncio
  async def test_multi_node_inner_graph(self):
    """Inner graph with 3 nodes; outer gets all accumulated output."""
    agents = [SimpleTestAgent(f"n{i}", [f"INNER{i}"]) for i in range(3)]
    inner = GraphAgent(name="inner")
    for i, a in enumerate(agents):
      inner.add_node(GraphNode(name=f"n{i}", agent=a))
    inner.add_edge("n0", "n1")
    inner.add_edge("n1", "n2")
    inner.set_start("n0")
    inner.set_end("n2")

    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner), "outer"
    )

    runner, svc = make_runner(outer)
    assert await run_graph(runner, svc, "go") == "INNER0INNER1INNER2"
    for a in agents:
      assert a.call_count == 1


# ============================================================================
# DynamicParallelGroup
# ============================================================================


class TestDynamicParallelGroup:
  """DynamicParallelGroup generates agent list at runtime and runs concurrently."""

  @pytest.mark.asyncio
  async def test_all_agents_run_and_aggregated(self):
    """All generated agents execute; aggregator receives all results."""
    agents = [SimpleTestAgent(f"a{i}", [f"R{i}"]) for i in range(3)]

    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: agents,
            aggregator=lambda results, _: "|".join(results),
        )
    )

    runner, svc = make_runner(graph)
    result = await run_graph(runner, svc, "go")
    for i in range(3):
      assert f"R{i}" in result
    for a in agents:
      assert a.call_count == 1

  @pytest.mark.asyncio
  async def test_empty_list_handled(self):
    """Empty agent list: aggregator receives [] and returns gracefully."""
    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: [],
            aggregator=lambda results, _: f"count={len(results)}",
        )
    )

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "go") == "count=0"

  @pytest.mark.asyncio
  async def test_max_parallelism_all_complete(self):
    """All agents complete even when max_parallelism < agent count."""
    agents = [SimpleTestAgent(f"a{i}", [f"R{i}"], delay=0.02) for i in range(6)]

    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: agents,
            aggregator=lambda results, _: f"count={len(results)}",
            max_parallelism=2,
        )
    )

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "go") == "count=6"

  @pytest.mark.asyncio
  async def test_variable_count_from_state(self):
    """Generator reads state.data.input to determine how many agents to spawn."""

    def gen(state):
      n = int(state.data.get("input", "3"))
      return [SimpleTestAgent(f"a{i}", ["x"]) for i in range(n)]

    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=gen,
            aggregator=lambda results, _: str(len(results)),
        )
    )

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "5") == "5"
    assert await run_graph(runner, svc, "2") == "2"

  @pytest.mark.asyncio
  async def test_custom_aggregator(self):
    """Aggregator fully controls result combination (e.g. sum)."""
    agents = [SimpleTestAgent(f"a{i}", [str((i + 1) * 10)]) for i in range(3)]

    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: agents,
            aggregator=lambda results, _: str(sum(int(r) for r in results)),
        )
    )

    runner, svc = make_runner(graph)
    assert await run_graph(runner, svc, "go") == "60"  # 10+20+30


# ============================================================================
# Pattern Integration
# ============================================================================


class TestPatternIntegration:
  """Patterns compose naturally within a single GraphAgent."""

  @pytest.mark.asyncio
  async def test_dynamic_node_then_parallel_group(self):
    """DynamicNode routes first, then DynamicParallelGroup runs downstream."""
    router_agent = SimpleTestAgent("router", ["routed"])
    parallel_agents = [SimpleTestAgent(f"p{i}", [f"P{i}"]) for i in range(3)]

    graph = GraphAgent(name="g")
    graph.add_node(
        DynamicNode(name="router", agent_selector=lambda _: router_agent)
    )
    graph.add_node(
        DynamicParallelGroup(
            name="parallel",
            agent_generator=lambda _: parallel_agents,
            aggregator=lambda results, _: "-".join(results),
        )
    )
    graph.add_edge("router", "parallel")
    graph.set_start("router")
    graph.set_end("parallel")

    runner, svc = make_runner(graph)
    result = await run_graph(runner, svc, "go")
    for i in range(3):
      assert f"P{i}" in result

  @pytest.mark.asyncio
  async def test_nested_graph_with_dynamic_node_inside(self):
    """NestedGraphNode wraps a sub-graph that itself uses DynamicNode."""
    inner_agent = SimpleTestAgent("inner_a", ["INNER_DYNAMIC"])

    inner = _single_node_graph(
        DynamicNode(name="dyn", agent_selector=lambda _: inner_agent), "inner"
    )
    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner), "outer"
    )

    runner, svc = make_runner(outer)
    assert await run_graph(runner, svc, "go") == "INNER_DYNAMIC"


# ============================================================================
# Observability (_debug_ keys in state.data)
# ============================================================================


class TestPatternObservability:
  """Pattern nodes write _debug_ keys to state.data for observability."""

  @pytest.mark.asyncio
  async def test_dynamic_node_records_selected_agent(self):
    """DynamicNode records _debug_{name}_selected_agent in state.data."""
    agent_a = SimpleTestAgent("agent_a", ["A"])

    graph = _single_node_graph(
        DynamicNode(name="d", agent_selector=lambda _: agent_a)
    )

    runner, svc = make_runner(graph)
    result, data = await run_graph_with_state(runner, svc, "go")
    assert result == "A"
    assert data.get("_debug_d_selected_agent") == "agent_a"

  @pytest.mark.asyncio
  async def test_parallel_group_records_count(self):
    """DynamicParallelGroup records _debug_{name}_parallel_count."""
    agents = [SimpleTestAgent(f"a{i}", [f"R{i}"]) for i in range(3)]

    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: agents,
            aggregator=lambda results, _: "|".join(results),
        )
    )

    runner, svc = make_runner(graph)
    result, data = await run_graph_with_state(runner, svc, "go")
    assert "R0" in result
    assert data.get("_debug_p_parallel_count") == 3

  @pytest.mark.asyncio
  async def test_parallel_group_records_zero_count(self):
    """DynamicParallelGroup records count=0 for empty agent list."""
    graph = _single_node_graph(
        DynamicParallelGroup(
            name="p",
            agent_generator=lambda _: [],
            aggregator=lambda results, _: "empty",
        )
    )

    runner, svc = make_runner(graph)
    result, data = await run_graph_with_state(runner, svc, "go")
    assert result == "empty"
    assert data.get("_debug_p_parallel_count") == 0

  @pytest.mark.asyncio
  async def test_nested_graph_records_output(self):
    """NestedGraphNode records _debug_{name}_output in state.data."""
    inner_agent = SimpleTestAgent("inner_agent", ["NESTED_OUTPUT"])
    inner = _single_node_graph(GraphNode(name="s", agent=inner_agent), "inner")
    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner), "outer"
    )

    runner, svc = make_runner(outer)
    result, data = await run_graph_with_state(runner, svc, "go")
    assert result == "NESTED_OUTPUT"
    assert data.get("_debug_nested_output") == "NESTED_OUTPUT"

  @pytest.mark.asyncio
  async def test_nested_graph_output_truncated(self):
    """NestedGraphNode truncates _debug_ output to 500 chars."""
    long_text = "x" * 1000
    inner_agent = SimpleTestAgent("inner_agent", [long_text])
    inner = _single_node_graph(GraphNode(name="s", agent=inner_agent), "inner")
    outer = _single_node_graph(
        NestedGraphNode(name="nested", graph_agent=inner), "outer"
    )

    runner, svc = make_runner(outer)
    _, data = await run_graph_with_state(runner, svc, "go")
    assert len(data.get("_debug_nested_output", "")) == 500


# ============================================================================
# Sub-Agent Registration for Pattern Nodes
# ============================================================================


class TestPatternSubAgentRegistration:
  """Test GraphAgent.sub_agents registration for pattern node types."""

  def test_nested_graph_node_registers_graph_agent(self):
    """NestedGraphNode's graph_agent should appear in outer graph sub_agents."""
    inner_agent = SimpleTestAgent("inner_a", ["ok"])
    inner_graph = _single_node_graph(
        GraphNode(name="step", agent=inner_agent), "inner_graph"
    )

    outer = GraphAgent(name="outer")
    outer.add_node(NestedGraphNode(name="nested_step", graph_agent=inner_graph))

    assert inner_graph in outer.sub_agents
    assert inner_graph.parent_agent is outer

  def test_dynamic_node_registers_fallback_agent(self):
    """DynamicNode's fallback_agent should appear in sub_agents when provided."""
    fallback = SimpleTestAgent("fallback", ["fb"])

    outer = GraphAgent(name="g")
    outer.add_node(
        DynamicNode(
            name="dispatcher",
            agent_selector=lambda _: None,
            fallback_agent=fallback,
        )
    )

    assert fallback in outer.sub_agents
    assert fallback.parent_agent is outer

  def test_dynamic_node_no_fallback_no_registration(self):
    """DynamicNode without fallback should not add anything to sub_agents."""
    outer = GraphAgent(name="g")
    outer.add_node(DynamicNode(name="dispatcher", agent_selector=lambda _: None))

    assert len(outer.sub_agents) == 0

  def test_dynamic_parallel_no_registration(self):
    """DynamicParallelGroup should not register anything (runtime-only)."""

    def gen(state):
      return [SimpleTestAgent(f"tmp_{i}", ["x"]) for i in range(2)]

    outer = GraphAgent(name="g")
    outer.add_node(
        DynamicParallelGroup(
            name="par",
            agent_generator=gen,
            aggregator=lambda results, _: ",".join(results),
        )
    )

    assert len(outer.sub_agents) == 0

  def test_find_agent_through_nested_graph_node(self):
    """outer.find_agent should traverse into NestedGraphNode's graph_agent."""
    inner_agent = SimpleTestAgent("deep_agent", ["ok"])
    inner_graph = _single_node_graph(
        GraphNode(name="step", agent=inner_agent), "inner_graph"
    )

    outer = GraphAgent(name="outer")
    outer.add_node(NestedGraphNode(name="nested_step", graph_agent=inner_graph))

    # Find inner graph itself
    assert outer.find_agent("inner_graph") is inner_graph
    # Find deeply-nested agent
    assert outer.find_agent("deep_agent") is inner_agent


# ============================================================================
# Multi-event output accumulation tests
# ============================================================================


class MultiEventAgent(BaseAgent):
  """Agent that yields multiple content events (simulates streaming)."""

  model_config = {"extra": "allow", "arbitrary_types_allowed": True}

  def __init__(self, name: str, texts: list):
    super().__init__(name=name)
    object.__setattr__(self, "_texts", texts)

  async def _run_async_impl(self, ctx):
    for text in object.__getattribute__(self, "_texts"):
      yield Event(
          author=self.name,
          content=types.Content(parts=[types.Part(text=text)]),
      )


@pytest.mark.asyncio
class TestPatternOutputAccumulation:
  """Pattern nodes must accumulate output from multi-event agents."""

  async def test_dynamic_node_accumulates_multi_event_output(self):
    """DynamicNode must concatenate all event texts, not keep only last."""
    agent = MultiEventAgent(name="streamer", texts=["Hello ", "World"])
    graph = _single_node_graph(
        DynamicNode(name="dyn", agent_selector=lambda s: agent)
    )

    result = await run_graph(*make_runner(graph), "test")
    assert "Hello " in result
    assert "World" in result

  async def test_dynamic_parallel_group_accumulates_multi_event_output(self):
    """DynamicParallelGroup must concatenate all event texts per agent."""
    # NOTE(review): this test continues past the visible chunk boundary;
    # only the fragment below was reproducible from this view.
    agent = MultiEventAgent(name="s", texts=["part1", "part2"])
    graph = GraphAgent(name="g")
graph.add_node( + DynamicParallelGroup( + name="dpg", + agent_generator=lambda s: [agent], + aggregator=lambda results, s: " ".join(results), + ) + ) + graph.set_start("dpg") + graph.set_end("dpg") + + result = await run_graph(*make_runner(graph), "test") + assert "part1" in result + assert "part2" in result + + +class MultiPartEventAgent(BaseAgent): + """Agent that yields a single event with multiple parts.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, texts: list): + super().__init__(name=name) + object.__setattr__(self, "_texts", texts) + + async def _run_async_impl(self, ctx): + parts = [ + types.Part(text=t) for t in object.__getattribute__(self, "_texts") + ] + yield Event( + author=self.name, + content=types.Content(parts=parts), + ) + + +@pytest.mark.asyncio +class TestNestedGraphNodeMultiPart: + """NestedGraphNode must properly join multi-part event content.""" + + async def test_nested_graph_node_multi_part_output(self): + """Multi-part event parts are joined, not just parts[0].""" + agent = MultiPartEventAgent( + name="mp_agent", texts=["alpha", " beta", " gamma"] + ) + inner = GraphAgent(name="inner_mp") + inner.add_node(GraphNode(name="step", agent=agent)) + inner.set_start("step") + inner.set_end("step") + + outer = GraphAgent(name="outer_mp") + outer.add_node(NestedGraphNode(name="nested", graph_agent=inner)) + outer.set_start("nested") + outer.set_end("nested") + + result = await run_graph(*make_runner(outer), "go") + assert "alpha" in result + assert "beta" in result + assert "gamma" in result + assert result == "alpha beta gamma" diff --git a/tests/unittests/cli/test_agent_graph.py b/tests/unittests/cli/test_agent_graph.py index 62c7a868f0..6eef831739 100644 --- a/tests/unittests/cli/test_agent_graph.py +++ b/tests/unittests/cli/test_agent_graph.py @@ -9,6 +9,11 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.graph import GraphAgent, GraphNode, GraphState 
+from google.adk.agents.graph.patterns import ( + DynamicNode, + DynamicParallelGroup, + NestedGraphNode, +) from google.adk.cli.agent_graph import build_graph, get_agent_graph from google.adk.events.event import Event from google.genai import types @@ -106,6 +111,27 @@ async def test_graph_agent_loop(self): assert "reason" in src assert "observe" in src + @pytest.mark.asyncio + async def test_graph_agent_nested(self): + """NestedGraphNode renders inner graph as sub-cluster.""" + inner_a = SimpleTestAgent(name="inner_step") + inner_g = GraphAgent(name="inner") + inner_g.add_node(GraphNode(name="is", agent=inner_a)) + inner_g.set_start("is") + inner_g.set_end("is") + + outer = GraphAgent(name="outer") + outer.add_node(NestedGraphNode(name="nested", graph_agent=inner_g)) + outer.set_start("nested") + outer.set_end("nested") + + dg = _make_digraph() + await build_graph(dg, outer, highlight_pairs=None) + src = dg.source + + assert "inner" in src + assert "inner_step" in src + @pytest.mark.asyncio async def test_graph_agent_function_node(self): """Function-only node rendered as box shape.""" @@ -125,6 +151,47 @@ async def my_func(state, ctx): assert "fn_node" in src assert "box" in src + @pytest.mark.asyncio + async def test_graph_agent_dynamic_node(self): + """DynamicNode rendered as diamond shape.""" + g = GraphAgent(name="wf") + dyn = DynamicNode( + name="dispatcher", + agent_selector=lambda _: None, + ) + g.add_node(dyn) + g.set_start("dispatcher") + g.set_end("dispatcher") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "dispatcher" in src + assert "diamond" in src + assert "(dynamic)" in src + + @pytest.mark.asyncio + async def test_graph_agent_dynamic_parallel_group(self): + """DynamicParallelGroup rendered as parallelogram shape.""" + g = GraphAgent(name="wf") + dpg = DynamicParallelGroup( + name="fan_out", + agent_generator=lambda _: [], + aggregator=lambda r, _: "", + ) + g.add_node(dpg) + 
g.set_start("fan_out") + g.set_end("fan_out") + + dg = _make_digraph() + await build_graph(dg, g, highlight_pairs=None) + src = dg.source + + assert "fan_out" in src + assert "parallelogram" in src + assert "(parallel)" in src + @pytest.mark.asyncio async def test_get_agent_graph_returns_digraph(self): """get_agent_graph returns a graphviz.Digraph for GraphAgent."""