diff --git a/contributing/samples/graph_agent_advanced/README.md b/contributing/samples/graph_agent_advanced/README.md new file mode 100644 index 0000000000..4ce30fccc8 --- /dev/null +++ b/contributing/samples/graph_agent_advanced/README.md @@ -0,0 +1,406 @@ +# GraphAgent Advanced Example - All Features + +This example demonstrates **all features** of the GraphAgent interrupt and observability framework through a realistic research paper writing workflow. + +--- + +## Features Demonstrated + +### 1. **Checkpointing** βœ… +- Automatic checkpoint creation at each node +- Resume from any checkpoint +- Checkpoint listing and management +- State preservation across restarts + +**Code**: See `scenario_3_checkpointing_and_resume()` + +```python +# Enable checkpointing +graph = GraphAgent( + name="research_workflow", + checkpointing=True, + checkpoint_service=checkpoint_service, +) + +# List checkpoints +checkpoints = await checkpoint_service.list_checkpoints(session) + +# Restore from checkpoint +restored = await checkpoint_service.restore_checkpoint(session, checkpoint_id) +``` + +--- + +### 2. 
**LLM-based Interrupt Reasoning** βœ… +- InterruptReasoner analyzes interrupt messages +- Context-aware decisions (uses current node, state, execution path) +- Available actions: continue, rerun, pause, defer, skip +- Extensible via custom_actions + +**Code**: See `scenario_2_interrupt_with_reasoning()` + +```python +# Create LLM-based reasoner +reasoner = InterruptReasoner( + config=InterruptReasonerConfig( + model="gemini-2.0-flash-exp", + available_actions=["continue", "rerun", "pause", "defer", "skip"], + instruction="You are an interrupt reasoning agent...", + ) +) + +# Use in GraphAgent +graph = GraphAgent( + name="research_workflow", + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, + reasoner=reasoner, # LLM decides actions + ), +) + +# Send interrupt - LLM will analyze and decide +await interrupt_service.send_interrupt( + session_id=session.id, + text="The literature review missed key papers on neural architecture search", + action="defer", # Suggestion, but LLM may override +) +``` + +--- + +### 3. 
**Callback-based Observability** βœ… +- Custom before/after node callbacks +- Full access to state, iteration, invocation context +- Rich multi-content events (text, JSON, metadata) +- Developers control format (no hardcoded strings) + +**Code**: See `research_observability_callback()` + +```python +async def research_observability_callback(ctx: NodeCallbackContext) -> Optional[Event]: + """Custom observability with rich content.""" + parts = [ + types.Part(text=f"πŸ“ **Executing**: {ctx.node.name}"), + types.Part(text=f"Progress: {progress:.1f}%"), + types.Part(text=f"**State**:\n```json\n{json.dumps(ctx.state.data, indent=2)}\n```"), + ] + + return Event( + author="observability", + content=types.Content(parts=parts), + actions=EventActions( + state_delta={ + "observability_node": ctx.node.name, + "observability_progress": progress, + }, + ), + ) + +# Use in GraphAgent +graph = GraphAgent( + name="research_workflow", + before_node_callback=research_observability_callback, + after_node_callback=create_nested_observability_callback(), +) +``` + +--- + +### 4. 
**Flexible Interrupt Timings** βœ… +- **BEFORE**: Validate before node execution (pre-conditions) +- **AFTER**: Correct after node execution (retrospective feedback) +- **BOTH**: Both before and after +- Per-node configuration + +**Code**: See `scenario_5_all_interrupt_timings()` + +```python +# Interrupt AFTER node execution (default, retrospective) +graph = GraphAgent( + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, # Check after each node + nodes=None, # All nodes (or specify specific nodes) + ), +) + +# Interrupt BEFORE specific nodes (validation) +graph = GraphAgent( + interrupt_config=InterruptConfig( + mode=InterruptMode.BEFORE, # Check before execution + nodes=["peer_review"], # Only before peer_review + ), +) + +# Interrupt BOTH before and after +graph = GraphAgent( + interrupt_config=InterruptConfig( + mode=InterruptMode.BOTH, # Check before AND after + nodes=["write_paper", "peer_review"], + ), +) +``` + +--- + +### 5. **Immediate Cancellation (ESC-like)** βœ… +- Cancels **during** node execution (not just between nodes) +- State preservation on cancel (partial results, execution path) +- Resume capability after cancellation +- Clean session cleanup + +**Code**: See `scenario_4_immediate_cancellation()` + +```python +# Cancel immediately (ESC-like) +await interrupt_service.cancel_session(session.id) + +# Check preserved state +print(session.state.get("graph_cancelled")) # True +print(session.state.get("graph_cancelled_at_node")) # "write_paper" +print(session.state.get("graph_can_resume")) # True +print(session.state.get("graph_data")) # Partial domain data saved +``` + +**Cancellation Paths**: +1. **Between nodes**: Cancels at iteration start +2. **During node execution**: Cancels mid-execution (TRUE immediate) +3. 
**Task cancellation**: Handles `asyncio.CancelledError` + +All paths save: +- `graph_state`: Partial execution state +- `graph_cancelled_at_node`: Where cancellation occurred +- `graph_path`: Execution path so far +- `graph_partial_output`: Partial node output (if mid-execution) +- `graph_can_resume`: Resume capability flag + +--- + +### 6. **All Interrupt Actions** βœ… + +| Action | Description | Example Use Case | +|--------|-------------|------------------| +| `continue` | Proceed normally | "Looks good, continue" | +| `rerun` | Re-execute current node with guidance | "Rerun with more details" | +| `pause` | Pause execution (escalate=True) | "Wait for human approval" | +| `defer` | Save for later (add to todos) | "Good idea, but not urgent" | +| `skip` | Skip current node | "No need for peer review" | + +**Code**: See `scenario_2_interrupt_with_reasoning()` + +```python +# Defer action - saves to session.state["_interrupt_todos"] +await interrupt_service.send_interrupt( + session_id=session.id, + text="Add section on ethical implications", + action="defer", +) + +# Check deferred todos +todos = session.state.get("_interrupt_todos", []) +print(f"Deferred: {len(todos)} items") + +# Rerun action - adds guidance to state metadata +await interrupt_service.send_interrupt( + session_id=session.id, + text="Rerun with more focus on practical applications", + action="rerun", +) + +# Pause action - escalates to pause execution +await interrupt_service.send_interrupt( + session_id=session.id, + text="Pause for team review", + action="pause", +) +``` + +--- + +## Workflow Overview + +**Research Paper Writing Workflow**: + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Literature Review β”‚ ──> Review existing papers +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Generate Hypotheses β”‚ ──> Propose testable hypotheses 
+β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Design Methodology β”‚ ──> Plan experimental methods +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Analyze Results β”‚ ──> Run simulated experiments +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Write Paper β”‚ ◄─┐ Write academic paper +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β”‚ β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ Peer Review β”‚ β”‚ Review quality +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β”‚ β”‚ + [needs revision?] β”€β”€β”€β”˜ (loop back if score < 7/10) + β”‚ + [accept or max revisions] + β”‚ + [END] +``` + +**Checkpoints created**: +- After each major node +- Stored in session.state +- Can resume from any checkpoint + +**Interrupt points**: +- AFTER each node (default) +- Can send interrupts at any time +- LLM reasons about best action + +--- + +## Running the Examples + +### Prerequisites + +```bash +# Install dependencies +pip install google-adk + +# Set Gemini API key +export GOOGLE_API_KEY="your-api-key" +``` + +### Run All Scenarios + +```bash +python -m contributing.samples.graph_agent_advanced.agent +``` + +### Run Individual Scenarios + +```python +# In Python REPL +from contributing.samples.graph_agent_advanced.agent import * + +# Scenario 1: Basic execution with observability +await scenario_1_basic_execution() + +# Scenario 2: Interrupt with LLM reasoning +await scenario_2_interrupt_with_reasoning() + +# Scenario 3: Checkpointing and resume +await scenario_3_checkpointing_and_resume() + +# Scenario 4: Immediate cancellation +await scenario_4_immediate_cancellation() + +# 
Scenario 5: All interrupt timings +await scenario_5_all_interrupt_timings() +``` + +--- + +## Expected Output + +### Scenario 1: Basic Execution +``` +================================================================================== +SCENARIO 1: Basic Execution with Observability +================================================================================== + +Running research workflow... + +[observability] πŸ“ **Executing**: literature_review +Progress: 5.0% (iteration 1) +**State**: +```json +{ + "topic": "Impact of AI on software development" +} +``` + +[literature_review] Based on recent literature, key papers include... + +[observability] πŸ“ **Executing**: generate_hypotheses +Progress: 10.0% (iteration 2) +... + +βœ… Workflow completed! +Final state keys: ['topic', 'literature_review', 'hypotheses', 'methodology', 'analysis_results', 'paper', 'peer_review'] +``` + +### Scenario 2: Interrupt with Reasoning +``` +πŸ”” Sending interrupt: 'The literature review missed key papers on neural architecture search' + +πŸ”Ά [interrupt_reasoner] Analyzed interrupt: The researcher is providing important feedback about missing references. Action: defer (save for later revision) + +πŸ“Š LLM Decision: defer - The feedback is valuable but doesn't require immediate action. Save for the revision phase. + +πŸ“ Deferred todos: 1 items + First todo: The literature review missed key papers on neural architecture search. Please include them. +``` + +### Scenario 3: Checkpointing +``` +⏸️ Pausing workflow after methodology design... + +πŸ“¦ Checkpoints created: 3 + - checkpoint_001: literature_review + - checkpoint_002: generate_hypotheses + - checkpoint_003: design_methodology + +▢️ Resuming from checkpoint: checkpoint_003 +βœ… Restored state keys: ['topic', 'literature_review', 'hypotheses', 'methodology'] +``` + +### Scenario 4: Immediate Cancellation +``` +πŸ›‘ Cancelling workflow immediately (ESC)... 
+ +⚠️ Cancellation event received: ⚠️ Execution cancelled during node 'write_paper' + +πŸ“Š Session state after cancel: + - Cancelled: True + - Cancelled at node: write_paper + - Can resume: True + - Partial state saved: True + - Partial state keys: ['topic', 'literature_review', 'hypotheses', 'methodology', 'analysis_results'] +``` + +--- + +## Key Takeaways + +1. **Observability**: Developers control event format via callbacks (no hardcoded strings) +2. **Interrupt Reasoning**: LLM analyzes context and decides best action +3. **Flexible Timings**: BEFORE (validate), AFTER (correct), BOTH (comprehensive) +4. **Immediate Cancel**: TRUE immediate interrupt (cancels during execution, not just between nodes) +5. **State Preservation**: All cancellation paths save partial state for resume +6. **Extensible Actions**: continue, rerun, pause, defer, skip (+ custom actions) + +--- + +## Next Steps + +**Try modifying the example**: +1. Add your own custom callback with different observability +2. Create custom interrupt actions via `InterruptReasonerConfig.custom_actions` +3. Experiment with different interrupt timings (BEFORE/AFTER/BOTH) +4. Test resume from checkpoint after cancellation +5. Build your own workflow with conditional routing + +**Explore the codebase**: +- `src/google/adk/agents/graph/graph_agent.py` - Core orchestration +- `src/google/adk/agents/graph/interrupt_reasoner.py` - LLM reasoning +- `src/google/adk/agents/graph/callbacks.py` - Callback infrastructure +- `src/google/adk/agents/graph/interrupt_service.py` - Interrupt management +- `src/google/adk/checkpoints/checkpoint_service.py` - Checkpoint management + +**Questions?** See the design docs or run the tests to understand the implementation. 
diff --git a/contributing/samples/graph_agent_advanced/__init__.py b/contributing/samples/graph_agent_advanced/__init__.py new file mode 100644 index 0000000000..03953f46c8 --- /dev/null +++ b/contributing/samples/graph_agent_advanced/__init__.py @@ -0,0 +1,10 @@ +"""Advanced GraphAgent example showcasing all features. + +This example demonstrates: +- Checkpointing with resume capability +- LLM-based interrupt reasoning +- Callback-based observability +- Flexible interrupt timings (BEFORE/AFTER/BOTH) +- Immediate cancellation with state preservation +- All interrupt actions (continue, pause, defer, rerun, skip) +""" diff --git a/contributing/samples/graph_agent_advanced/agent.py b/contributing/samples/graph_agent_advanced/agent.py new file mode 100644 index 0000000000..9421885e08 --- /dev/null +++ b/contributing/samples/graph_agent_advanced/agent.py @@ -0,0 +1,689 @@ +"""Advanced GraphAgent example with all features. + +This example demonstrates a research paper writing workflow with: +- Checkpointing (save/resume) +- LLM-based interrupt reasoning +- Custom observability callbacks +- Flexible interrupt timings +- Immediate cancellation +- All interrupt actions + +Run: + python -m contributing.samples.graph_agent_advanced.agent +""" + +import asyncio +import json +import os +from typing import Optional + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph.callbacks import create_nested_observability_callback +from google.adk.agents.graph.callbacks import NodeCallbackContext +from google.adk.agents.graph.interrupt_reasoner import InterruptReasoner +from google.adk.agents.graph.interrupt_reasoner import InterruptReasonerConfig +from google.adk.agents.graph.interrupt_service import InterruptService +from google.adk.agents.llm_agent import LlmAgent +from google.adk.checkpoints.checkpoint_service 
import CheckpointService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + +from google import genai + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# ============================================================================== +# Custom Observability Callback (Rich, Multi-Content Events) +# ============================================================================== + + +async def research_observability_callback( + ctx: NodeCallbackContext, +) -> Optional[Event]: + """Custom observability for research workflow. + + Emits rich events with: + - Node execution info + - Current state snapshot + - Progress indicator + - Execution metadata + """ + # Calculate progress + total_nodes = len(ctx.invocation_context.agent.nodes) + current_iteration = ctx.iteration + progress = ( + current_iteration / ctx.invocation_context.agent.max_iterations + ) * 100 + + # Build rich content + parts = [ + # Header with emoji and node name + types.Part(text=f"πŸ“ **Executing**: {ctx.node.name}"), + # Progress indicator + types.Part( + text=f"Progress: {progress:.1f}% (iteration {current_iteration})" + ), + # Current state (formatted as JSON) + # Automatic Pydantic serialization + types.Part(text=f"**State**:\n{ctx.state.data_to_json()}"), + # Execution metadata + types.Part( + text=( + "**Metadata**:" + f" Agent={ctx.node.agent.name if ctx.node.agent else 'function'}," + f" Total nodes={total_nodes}" + ) + ), + ] + + return Event( + author="observability", + content=types.Content(parts=parts), + actions=EventActions( + escalate=False, + state_delta={ + "observability_node": ctx.node.name, + "observability_iteration": ctx.iteration, + "observability_progress": progress, + "observability_timestamp": asyncio.get_event_loop().time(), + }, + ), + ) + + +# 
============================================================================== +# Research Workflow Nodes (Agents) +# ============================================================================== + + +def _create_research_agents(): + """Create fresh agent instances (avoids single-parent conflicts).""" + return ( + LlmAgent( + name="literature_review", + model=_MODEL, + instruction=( + "You are a research literature reviewer. Search for and summarize" + " key papers related to the given topic. Output a JSON list of" + " papers with title, authors, and key findings." + ), + ), + LlmAgent( + name="hypothesis_generator", + model=_MODEL, + instruction=( + "You are a research hypothesis generator. Based on the literature" + " review, propose 3-5 testable hypotheses. Output a JSON list of" + " hypotheses with rationale." + ), + ), + LlmAgent( + name="methodology_designer", + model=_MODEL, + instruction=( + "You are a research methodology expert. Design experimental" + " methods to test the hypotheses. Output a JSON dict with" + " methodology sections: participants, materials, procedure." + ), + ), + LlmAgent( + name="paper_writer", + model=_MODEL, + instruction=( + "You are an academic paper writer. Write a research paper with" + " sections: Abstract, Introduction, Methods, Results, Discussion." + " Use the literature review, hypotheses, methodology, and results" + " from the state. Output well-structured academic prose." + ), + ), + LlmAgent( + name="peer_reviewer", + model=_MODEL, + instruction=( + "You are a peer reviewer. Review the paper for: clarity," + " scientific rigor, statistical validity, and writing quality." + " Provide a review with scores (1-10) and specific suggestions" + " for improvement. Output JSON with scores and comments." + ), + ), + ) + + +# 4. Results Analyzer Agent (Simulated) +def analyze_results(state: GraphState, ctx) -> GraphState: + """Simulate data analysis (in real scenario, this would run experiments). 
+ + Args: + state: Current graph state + ctx: Invocation context (provides session, session_service, user_content, etc.) + + Returns: + Updated GraphState with analysis results + """ + # Simulate analysis results + results = { + "hypothesis_1": { + "supported": True, + "p_value": 0.023, + "effect_size": 0.42, + }, + "hypothesis_2": { + "supported": False, + "p_value": 0.156, + "effect_size": 0.18, + }, + "hypothesis_3": { + "supported": True, + "p_value": 0.001, + "effect_size": 0.67, + }, + } + # Update state with results + state.data["analysis_results"] = results + return state + + +# ============================================================================== +# Conditional Routing Functions +# ============================================================================== + + +def needs_revision(state: GraphState) -> bool: + """Check if paper needs revision based on peer review.""" + review = state.data.get("peer_review", {}) + # LLM agents store output as JSON string; parse if needed + if isinstance(review, str): + try: + review = json.loads(review) + except (json.JSONDecodeError, TypeError): + return False + avg_score = sum(review.get("scores", {}).values()) / max( + len(review.get("scores", {})), 1 + ) + return avg_score < 7.0 # Needs revision if average score < 7/10 + + +def revision_count_ok(state: GraphState) -> bool: + """Check if we've exceeded max revisions.""" + return state.data.get("revision_count", 0) < 3 + + +# ============================================================================== +# Build Research Workflow Graph +# ============================================================================== + + +def build_research_workflow( + session_service: InMemorySessionService, + checkpoint_service: CheckpointService, + interrupt_service: InterruptService, +) -> GraphAgent: + """Build advanced research workflow with all features enabled.""" + + # Create LLM-based interrupt reasoner + interrupt_reasoner = InterruptReasoner( + 
config=InterruptReasonerConfig( + model=_MODEL, + available_actions=["continue", "rerun", "pause", "defer", "skip"], + instruction=( + "You are an interrupt reasoning agent for a research paper" + " writing workflow. Analyze interrupt messages from researchers" + " and decide the best action. Consider: Is the feedback about" + " quality? Should we rerun with guidance? Should we pause for" + " human review? Should we defer the feedback for later?" + ), + ) + ) + + # Create GraphAgent with all features enabled + graph = GraphAgent( + name="research_workflow", + description=( + "Advanced research paper writing workflow with interrupt &" + " observability" + ), + max_iterations=20, + checkpointing=True, + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, # Check for interrupts after each node + nodes=None, # All nodes (can specify specific nodes like ["peer_reviewer"]) + reasoner=interrupt_reasoner, # Use LLM to reason about interrupts + ), + # Custom observability callback + before_node_callback=research_observability_callback, + # Nested observability (shows hierarchy) + after_node_callback=create_nested_observability_callback(), + ) + + # Create fresh agents (avoids single-parent conflicts across scenarios) + ( + literature_agent, + hypothesis_agent, + methodology_agent, + paper_writer_agent, + peer_reviewer_agent, + ) = _create_research_agents() + + # Add nodes + graph.add_node("literature_review", agent=literature_agent) + graph.add_node("generate_hypotheses", agent=hypothesis_agent) + graph.add_node("design_methodology", agent=methodology_agent) + graph.add_node("analyze_results", function=analyze_results) + graph.add_node("write_paper", agent=paper_writer_agent) + graph.add_node("peer_review", agent=peer_reviewer_agent) + + # Add edges (sequential workflow with revision loop) + graph.add_edge("literature_review", "generate_hypotheses") + graph.add_edge("generate_hypotheses", "design_methodology") + 
graph.add_edge("design_methodology", "analyze_results") + graph.add_edge("analyze_results", "write_paper") + graph.add_edge("write_paper", "peer_review") + + # Conditional routing: if review is poor, revise + graph.add_edge( + "peer_review", + "write_paper", # Loop back to rewrite + condition=lambda s: needs_revision(s) and revision_count_ok(s), + ) + + # If review is good or max revisions reached, finish at peer_review + # No edge needed - peer_review naturally becomes an end node when no edge matches + + # Set start and end + graph.set_start("literature_review") + graph.set_end("peer_review") # End at peer_review when revision not needed + + return graph + + +# ============================================================================== +# Example Usage Scenarios +# ============================================================================== + + +async def scenario_1_basic_execution(): + """Scenario 1: Basic execution with observability.""" + print("\n" + "=" * 80) + print("SCENARIO 1: Basic Execution with Observability") + print("=" * 80 + "\n") + + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + interrupt_service = InterruptService() + + graph = build_research_workflow( + session_service, checkpoint_service, interrupt_service + ) + + # Create session + session = await session_service.create_session( + app_name="research_workflow", user_id="researcher_1" + ) + + # Create runner + runner = Runner( + app_name="research_workflow", + agent=graph, + session_service=session_service, + auto_create_session=False, # Session already created + ) + + # Run workflow + print("Running research workflow...\n") + async for event in runner.run_async( + user_id="researcher_1", + session_id=session.id, + new_message=types.Content( + parts=[ + types.Part( + text="Research topic: Impact of AI on software development" + ) + ] + ), + ): + if event.content and event.content.parts: + for part in 
event.content.parts: + if part.text: + print(f"[{event.author}] {part.text[:200]}...") + + print("\nβœ… Workflow completed!") + print(f"Final state keys: {list(session.state.get('graph_data', {}).keys())}") + + +async def scenario_2_interrupt_with_reasoning(): + """Scenario 2: Send interrupt and let LLM reason about it.""" + print("\n" + "=" * 80) + print("SCENARIO 2: Interrupt with LLM Reasoning") + print("=" * 80 + "\n") + + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + interrupt_service = InterruptService() + + graph = build_research_workflow( + session_service, checkpoint_service, interrupt_service + ) + + session = await session_service.create_session( + app_name="research_workflow", user_id="researcher_2" + ) + + # Send interrupt after 2 seconds (simulating human feedback) + async def send_interrupt_after_delay(): + await asyncio.sleep(2) + print( + "\nπŸ”” Sending interrupt: 'The literature review missed key papers on" + " neural architecture search'" + ) + await interrupt_service.send_interrupt( + session_id=session.id, + text=( + "The literature review missed key papers on neural architecture" + " search. Please include them." 
+ ), + action="defer", # Suggest defer, but LLM will decide + metadata={"feedback_type": "missing_references"}, + ) + + # Run both concurrently + interrupt_task = asyncio.create_task(send_interrupt_after_delay()) + + # Create runner + runner = Runner( + app_name="research_workflow", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + async for event in runner.run_async( + user_id="researcher_2", + session_id=session.id, + new_message=types.Content( + parts=[types.Part(text="Research topic: Neural Architecture Search")] + ), + ): + if event.content and event.content.parts: + for part in event.content.parts: + if "interrupt" in part.text.lower() or "defer" in part.text.lower(): + print(f"\nπŸ”Ά [{event.author}] {part.text}") + + await interrupt_task + + # Check interrupt decision + decision = session.state.get("_last_interrupt_decision", {}) + print( + f"\nπŸ“Š LLM Decision: {decision.get('action')} -" + f" {decision.get('reasoning')}" + ) + + # Check deferred todos + todos = session.state.get("_interrupt_todos", []) + print(f"πŸ“ Deferred todos: {len(todos)} items") + if todos: + print(f" First todo: {todos[0].get('message', '')[:100]}...") + + +async def scenario_3_checkpointing_and_resume(): + """Scenario 3: Create checkpoints and resume from them.""" + print("\n" + "=" * 80) + print("SCENARIO 3: Checkpointing and Resume") + print("=" * 80 + "\n") + + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + interrupt_service = InterruptService() + + graph = build_research_workflow( + session_service, checkpoint_service, interrupt_service + ) + + session = await session_service.create_session( + app_name="research_workflow", user_id="researcher_3" + ) + + # Run workflow and pause after "design_methodology" + async def pause_after_methodology(): + await asyncio.sleep(3) + print("\n⏸️ Pausing workflow after methodology design...") + await 
interrupt_service.send_interrupt( + session_id=session.id, + text="Pause for team review", + action="pause", + ) + + pause_task = asyncio.create_task(pause_after_methodology()) + + # Create runner + runner = Runner( + app_name="research_workflow", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + print("Running workflow (will pause)...\n") + async for event in runner.run_async( + user_id="researcher_3", + session_id=session.id, + new_message=types.Content( + parts=[types.Part(text="Research topic: Quantum Machine Learning")] + ), + ): + pass + + await pause_task + + # List checkpoints + checkpoints = await checkpoint_service.list_checkpoints(session) + print(f"\nπŸ“¦ Checkpoints created: {len(checkpoints)}") + for cp in checkpoints: + print( + f" - {cp.checkpoint_id}: {cp.metadata.get('graph_node', 'unknown')}" + ) + + # Resume from last checkpoint + if checkpoints: + last_checkpoint = checkpoints[-1] + print(f"\n▢️ Resuming from checkpoint: {last_checkpoint.checkpoint_id}") + + restored_state = await checkpoint_service.restore_checkpoint( + session, last_checkpoint.checkpoint_id + ) + print(f"βœ… Restored state keys: {list(restored_state.keys())}") + + +async def scenario_4_immediate_cancellation(): + """Scenario 4: Immediate cancellation (ESC-like) with state preservation.""" + print("\n" + "=" * 80) + print("SCENARIO 4: Immediate Cancellation with State Preservation") + print("=" * 80 + "\n") + + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + interrupt_service = InterruptService() + + graph = build_research_workflow( + session_service, checkpoint_service, interrupt_service + ) + + session = await session_service.create_session( + app_name="research_workflow", user_id="researcher_4" + ) + + # Cancel immediately after 1.5 seconds + async def cancel_immediately(): + await asyncio.sleep(1.5) + print("\nπŸ›‘ Cancelling workflow immediately (ESC)...") + await 
interrupt_service.cancel_session(session.id) + + cancel_task = asyncio.create_task(cancel_immediately()) + + # Create runner + runner = Runner( + app_name="research_workflow", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + print("Running workflow (will cancel immediately)...\n") + async for event in runner.run_async( + user_id="researcher_4", + session_id=session.id, + new_message=types.Content( + parts=[types.Part(text="Research topic: Large Language Models")] + ), + ): + if "cancelled" in str(event.content).lower(): + print(f"\n⚠️ Cancellation event received: {event.content.parts[0].text}") + + await cancel_task + + # Check preserved state + print(f"\nπŸ“Š Session state after cancel:") + print(f" - Cancelled: {session.state.get('graph_cancelled', False)}") + print( + " - Cancelled at node:" + f" {session.state.get('graph_cancelled_at_node', 'unknown')}" + ) + print(f" - Can resume: {session.state.get('graph_can_resume', False)}") + print(f" - Partial state saved: {bool(session.state.get('graph_data'))}") + + # Show partial domain data + if session.state.get("graph_data"): + partial_data = session.state["graph_data"] + print(f" - Partial data keys: {list(partial_data.keys())}") + + +async def scenario_5_all_interrupt_timings(): + """Scenario 5: Demonstrate all interrupt timings (BEFORE/AFTER/BOTH).""" + print("\n" + "=" * 80) + print("SCENARIO 5: All Interrupt Timings") + print("=" * 80 + "\n") + + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + interrupt_service = InterruptService() + + # Create local agent instances for this scenario + _, _, _, paper_writer_agent, peer_reviewer_agent = _create_research_agents() + + # Test BEFORE mode + print("Testing InterruptMode.BEFORE (validate before execution)...\n") + + reasoner = InterruptReasoner( + config=InterruptReasonerConfig( + model=_MODEL, + available_actions=["continue", "skip", "pause"], + ) + ) + + 
async def main(run_all: bool = False):
  """Run example scenarios.

  Args:
    run_all: If True, runs all 5 scenarios (requires multiple LLM calls,
      may be slow). If False, runs only scenario 1 for quick validation.
      Set RUN_ALL_SCENARIOS=1 env var or pass --all flag to run all.
  """
  # Normalize the env var so common truthy spellings ("1", "true", "True",
  # "TRUE") all enable the full run; previously only exact lowercase matched.
  env_flag = os.getenv("RUN_ALL_SCENARIOS", "").strip().lower()
  run_all = run_all or env_flag in ("1", "true")

  print("\n" + "=" * 60)
  print("GraphAgent Advanced Examples")
  print("=" * 60)

  # Scenario 1 always runs as a quick smoke test.
  await scenario_1_basic_execution()

  if run_all:
    await scenario_2_interrupt_with_reasoning()
    await scenario_3_checkpointing_and_resume()
    await scenario_4_immediate_cancellation()
    await scenario_5_all_interrupt_timings()
    print("\nAll 5 scenarios completed.")
  else:
    print(
        "\nScenario 1 completed. Set RUN_ALL_SCENARIOS=1 to run all 5"
        " scenarios."
    )
agents.reviser + + - name: publish + sub_agents: + - code: agents.publisher + +# Edge definitions with conditional routing and priorities +edges: + # Research -> Outline (always) + - source_node: research + target_node: outline + priority: 1 + + # Outline -> Draft (always) + - source_node: outline + target_node: draft + priority: 1 + + # Draft -> Review (always) + - source_node: draft + target_node: review + priority: 1 + + # Review -> Publish (if approved) + - source_node: review + target_node: publish + condition: "data.get('approved', False) is True" + priority: 2 + + # Review -> Revise (if not approved) + - source_node: review + target_node: revise + condition: "data.get('approved', False) is False" + priority: 1 + + # Revise -> Review (retry review after revision) + - source_node: revise + target_node: review + priority: 1 + +# Parallel execution groups (optional) +parallel_groups: [] + +# Callbacks (optional) +# before_node_callback_ref: module.path.to.before_node_callback +# after_node_callback_ref: module.path.to.after_node_callback +# on_edge_condition_callback_ref: module.path.to.on_edge_condition_callback diff --git a/contributing/samples/graph_agent_agent_driven_checkpoint/README.md b/contributing/samples/graph_agent_agent_driven_checkpoint/README.md new file mode 100644 index 0000000000..445cfbc09b --- /dev/null +++ b/contributing/samples/graph_agent_agent_driven_checkpoint/README.md @@ -0,0 +1,52 @@ +# graph_agent_agent_driven_checkpoint + +Demonstrates the `checkpoint_request_key` pattern: an LLM agent proposes +checkpoints via a boolean flag in its structured output schema, and +`GraphCheckpointCallback` creates the checkpoint only when the flag is set. 
`checkpoint_request_key` accepts a dotted state path (segments separated by `"."`, e.g. `analyzer.checkpoint_requested`).
+ +Pattern: +- LlmAgent output_schema includes checkpoint_requested: bool = False +- GraphCheckpointCallback(checkpoint_request_key="analyzer.checkpoint_requested") + reads the flag after the "analyzer" node finishes +- Checkpoint created only when LLM sets the flag (e.g., for high-risk findings) +- Flag clears automatically: StateReducer.OVERWRITE replaces output each run + +Flow: + analyzer β†’ processor β†’ reporter β†’ END + (may set checkpoint_requested=True based on task risk) + +Why this pattern vs checkpointing=True? +- checkpointing=True: checkpoint after EVERY node unconditionally +- checkpoint_request_key: LLM reasons about criticality, checkpoints selectively + +Run (requires Vertex AI or GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_agent_driven_checkpoint.agent +""" + +import asyncio +import json +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import StateReducer +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.graph.checkpoint_callback import GraphCheckpointCallback +from google.adk.checkpoints import CheckpointService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class AnalysisOutput(BaseModel): + """Structured analysis output that includes a checkpoint proposal.""" + + finding: str # What was found + risk_level: str # "low" | "medium" | "high" + justification: str # Why this risk level + checkpoint_requested: bool = False # LLM sets True for high-risk findings + + +# --------------------------------------------------------------------------- +# Agents +# 
def _create_agents():
  """Return a fresh (analyzer, processor, reporter) agent triple.

  New instances are built on every call so no agent object ever ends up
  attached to more than one parent graph.
  """
  # Risk analyzer: emits structured JSON and may set the checkpoint flag.
  analyzer = LlmAgent(
      name="analyzer",
      model=_MODEL,
      instruction=(
          "You are a risk analyzer. Analyze the input task and return"
          " structured JSON. Set checkpoint_requested=true ONLY when risk_level"
          " is 'high' (irreversible or destructive operations). For"
          " low/medium risk, set checkpoint_requested=false to avoid"
          ' unnecessary overhead. Return {"finding": "...", "risk_level":'
          ' "low|medium|high", "justification": "...", "checkpoint_requested":'
          " true|false}."
      ),
      output_schema=AnalysisOutput,
  )
  # Action executor: consumes the analyzer's finding, reports the action.
  processor = LlmAgent(
      name="processor",
      model=_MODEL,
      instruction=(
          "You are an action executor. Based on the analyzer's finding,"
          " describe what action was taken. Be concise (1 sentence)."
      ),
      output_key="processor_result",
  )
  # Audit reporter: one-sentence summary of the whole run.
  reporter = LlmAgent(
      name="reporter",
      model=_MODEL,
      instruction=(
          "You are a reporter. Summarize the analysis and action taken in one"
          " sentence for an audit log."
      ),
      output_key="report",
  )
  return analyzer, processor, reporter
async def run_scenario(
    task: str,
    session_service: InMemorySessionService,
    graph: GraphAgent,
    scenario_id: str,
) -> None:
  """Run a single scenario end-to-end and print the results.

  Args:
    task: Natural-language task fed to the workflow as the user message.
    session_service: Session service shared with the graph's checkpoint
      callback.
    graph: The agent-checkpoint workflow graph to execute.
    scenario_id: Used as the session id so each scenario stays isolated.
  """
  session = await session_service.create_session(
      app_name="agent_checkpoint_workflow",
      user_id="user1",
      session_id=scenario_id,
  )

  runner = Runner(
      app_name="agent_checkpoint_workflow",
      agent=graph,
      session_service=session_service,
      auto_create_session=False,
  )

  assessment: dict = {}

  async for event in runner.run_async(
      user_id="user1",
      session_id=scenario_id,
      new_message=types.Content(parts=[types.Part(text=task)]),
  ):
    if not event.content or not event.content.parts:
      continue
    text = event.content.parts[0].text or ""
    if not text:
      continue

    if event.author == "analyzer":
      # Analyzer emits structured JSON (AnalysisOutput); fall back to an
      # empty dict when the model returns something unparseable.
      try:
        assessment = json.loads(text)
      except (json.JSONDecodeError, TypeError):
        assessment = {}
      risk = assessment.get("risk_level", "?").upper()
      requested = assessment.get("checkpoint_requested", False)
      flag = " [CHECKPOINT REQUESTED]" if requested else ""
      print(f"  Analyzer: risk={risk}{flag}")
      print(f"    {assessment.get('justification', '')[:100]}")
    elif event.author == "reporter":
      # NOTE: the original also accumulated reporter text into an unused
      # local (report_text); removed as dead code.
      print(f"  Report: {text[:120]}")

  # Re-fetch the session so we observe state written during the run; the
  # local `session` object may be stale.
  fresh = await session_service.get_session(
      app_name="agent_checkpoint_workflow",
      user_id="user1",
      session_id=scenario_id,
  )
  if fresh is None:
    # Fixed: was an f-string with no placeholders (ruff F541).
    print("WARNING: session_service.get_session returned None, using stale copy")
    fresh = session
  checkpoints = fresh.state.get("_checkpoint_index", {})
  # Filter to checkpoints whose key carries the "-requested" suffix —
  # assumes agent-requested checkpoints are keyed that way; TODO confirm
  # against GraphCheckpointCallback.
  requested_cps = {k for k in checkpoints if k.endswith("-requested")}
  print(
      f"  Checkpoints: {len(checkpoints)} total,"
      f" {len(requested_cps)} agent-requested"
  )
+ ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_agent_driven_topology/README.md b/contributing/samples/graph_agent_agent_driven_topology/README.md new file mode 100644 index 0000000000..b666261d2f --- /dev/null +++ b/contributing/samples/graph_agent_agent_driven_topology/README.md @@ -0,0 +1,59 @@ +# graph_agent_agent_driven_topology + +Demonstrates the **function-node-with-closure** pattern: a function node that +holds a reference to the `GraphAgent` calls `graph.add_node()` / `graph.add_edge()` +from *inside* the runner, injecting the topology entirely within the execution loop. + +Contrasts with `graph_agent_dynamic_topology` where the outer event loop modifies +the graph between emitted events. + +## Pattern + +```python +def make_topology_applier(graph: GraphAgent): + async def topology_applier(state: GraphState, ctx: InvocationContext) -> str: + design = json.loads(state.data.get("planner", "{}")) + prev = "topology_applier" + for step_name in design.get("steps", []): + node_name = f"step_{step_name}" + if node_name not in graph.nodes: + graph.add_node(node_name, agent=make_step_agent(step_name)) + graph.add_edge(prev, node_name) + prev = node_name + graph.set_end(prev) + return f"Injected nodes: {design.get('steps')}" + return topology_applier + +graph.add_node("topology_applier", function=make_topology_applier(graph)) +``` + +## Why this is safe (cooperative asyncio) + +`add_node` / `add_edge` perform plain dict writes with no locks. Python's +cooperative scheduler suspends the graph loop while the function node runs +synchronously, so no other coroutine can interleave. 
+ +## Comparison + +| | `dynamic_topology` | `agent_driven_topology` | +|-|-------------------|------------------------| +| Where mutation happens | Outer event loop | Inside runner (function node) | +| Graph reference | Passed in from outside | Closure over `graph` | +| Mediation | Event loop pauses between events | None β€” runs in one async frame | + +## Flow + +``` +planner ──▢ topology_applier (fn, closes over graph) + ──▢ step_validate ──▢ step_transform ──▢ ... ──▢ END + (nodes added at runtime by topology_applier) +``` + +## Run + +```bash +GOOGLE_API_KEY= python -m contributing.samples.graph_agent_agent_driven_topology.agent +# or Vertex AI: +GOOGLE_CLOUD_PROJECT= GOOGLE_CLOUD_LOCATION= GOOGLE_GENAI_USE_VERTEXAI=true \ + python -m contributing.samples.graph_agent_agent_driven_topology.agent +``` diff --git a/contributing/samples/graph_agent_agent_driven_topology/agent.py b/contributing/samples/graph_agent_agent_driven_topology/agent.py new file mode 100644 index 0000000000..466fd7142e --- /dev/null +++ b/contributing/samples/graph_agent_agent_driven_topology/agent.py @@ -0,0 +1,277 @@ +"""GraphAgent agent-driven topology: function node injects nodes inside the runner. + +Demonstrates the function-node-with-closure pattern for runtime topology +injection, contrasting with graph_agent_dynamic_topology which modifies the +graph from the OUTSIDE (event loop). + +Key difference vs graph_agent_dynamic_topology: +- dynamic_topology: outer event loop calls graph.add_node() between events +- THIS sample: a function node inside the runner calls graph.add_node(), + so topology injection happens entirely within the runner's execution. + +How it works: +1. LlmAgent (planner) proposes steps via output_schema β†’ stored in state +2. Function node "topology_applier" closes over the graph reference +3. topology_applier reads state.data["planner"], calls graph.add_node() and + graph.add_edge() to connect itself β†’ step_1 β†’ step_2 β†’ ... β†’ step_N +4. 
Graph evaluates the newly added edges on its next routing decision and + executes the injected step nodes + +Why this is safe (cooperative asyncio): +- graph.add_node/add_edge have no locks; they do plain dict writes +- The function node runs synchronously within the async event loop +- No other coroutine can interleave while the function node executes + +Flow: + planner ──▢ topology_applier (fn, closes over graph) + ──▢ step_validate ──▢ step_transform ──▢ ... ──▢ END + (added at runtime by topology_applier) + +Run (requires Vertex AI or GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_agent_driven_topology.agent +""" + +import asyncio +import json +import os +from typing import Callable + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.graph.checkpoint_callback import GraphCheckpointCallback +from google.adk.checkpoints import CheckpointService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class PipelineDesign(BaseModel): + """Structured pipeline design from planner agent.""" + + steps: list[str] # Ordered step names discovered at runtime + description: str # Pipeline overview + + +# --------------------------------------------------------------------------- +# Planner agent +# --------------------------------------------------------------------------- + +planner = LlmAgent( + name="planner", + model=_MODEL, + instruction=( + "You are a data pipeline designer. 
def make_step_agent(step_name: str) -> LlmAgent:
  """Build the LlmAgent that executes one named pipeline step.

  Args:
    step_name: Single-word step identifier proposed by the planner.

  Returns:
    An agent named ``<step_name>_agent`` that writes its result to the
    state key ``step_<step_name>_result``.
  """
  step_instruction = (
      f"You perform the '{step_name}' step of a data pipeline."
      f" Describe what the '{step_name}' step does and confirm it"
      " completed successfully. Be concise (1-2 sentences)."
  )
  step_agent = LlmAgent(
      name=f"{step_name}_agent",
      model=_MODEL,
      instruction=step_instruction,
      output_key=f"step_{step_name}_result",
  )
  return step_agent
+ """ + + async def topology_applier(state: GraphState, ctx: InvocationContext) -> str: + """Read planner output from state and inject nodes into the graph.""" + raw = state.data.get("planner", "{}") + if isinstance(raw, str): + try: + design = json.loads(raw) + except (json.JSONDecodeError, TypeError): + design = {} + else: + design = raw if isinstance(raw, dict) else {} + + steps = design.get("steps", []) + description = design.get("description", "") + + if not steps: + return "No steps discovered; pipeline ends here." + + prev = "topology_applier" + injected = [] + + for step_name in steps: + node_name = f"step_{step_name}" + if node_name not in graph.nodes: + graph.add_node(node_name, agent=make_step_agent(step_name)) + graph.add_edge(prev, node_name) + injected.append(node_name) + print(f" [INJECT] {prev} β†’ {node_name}") + else: + print(f" [SKIP] {node_name} already exists") + prev = node_name + + # Mark last step as the terminal node + graph.set_end(prev) + print(f" [END] {prev}") + + return ( + f"Injected {len(injected)} nodes for pipeline: {description}." 
+ f" Steps: {steps}" + ) + + return topology_applier + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + + +def build_agent_topology_graph( + session_service: InMemorySessionService, +) -> GraphAgent: + """Build graph with agent-driven topology injection via function node.""" + checkpoint_service = CheckpointService(session_service=session_service) + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_before=False, + checkpoint_after=True, + # checkpoint_nodes=None β†’ checkpoint ALL nodes including injected ones + ) + + graph = GraphAgent( + name="agent_topology_pipeline", + description=( + "Adaptive pipeline where a function node injects steps at runtime" + ), + max_iterations=20, + after_node_callback=checkpoint_callback.after_node, + ) + + # Static nodes: planner + topology_applier only + graph.add_node("planner", agent=planner) + graph.add_node( + "topology_applier", + function=make_topology_applier(graph), # closes over graph + ) + + graph.set_start("planner") + graph.add_edge("planner", "topology_applier") + # topology_applier adds its own edges + set_end at runtime + + return graph + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + print( + "=== Agent-Driven Topology: Function Node Injects Steps Inside Runner" + " ===\n" + ) + print( + "Contrast with graph_agent_dynamic_topology where the event loop\n" + "modifies the graph from OUTSIDE the runner.\n" + ) + + session_service = InMemorySessionService() + graph = build_agent_topology_graph(session_service) + session_id = "agent-topology-1" + + session = await session_service.create_session( + app_name="agent_topology_pipeline", + user_id="user1", + session_id=session_id, + ) + + runner = Runner( + 
app_name="agent_topology_pipeline", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + task = ( + "Process a customer dataset: validate schema, enrich with geo data," + " aggregate by region, and export to Parquet" + ) + print(f"Task: {task}\n") + + step_count = 0 + + async for event in runner.run_async( + user_id="user1", + session_id=session_id, + new_message=types.Content(parts=[types.Part(text=task)]), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + author = event.author + + if author == "topology_applier": + print(f"[topology_applier] {text[:200]}\n") + elif author.endswith("_agent"): + step_count += 1 + print(f"[Step {step_count}] {author}: {text[:150]}") + + # Re-fetch session for checkpoint count + fresh = await session_service.get_session( + app_name="agent_topology_pipeline", + user_id="user1", + session_id=session_id, + ) + if fresh is None: + print( + f"WARNING: session_service.get_session returned None, using stale copy" + ) + fresh = session + checkpoints = fresh.state.get("_checkpoint_index", {}) + print(f"\nPipeline complete. Steps executed: {step_count}") + print(f"Checkpoints created: {len(checkpoints)}") + print(f"Graph nodes after execution: {list(graph.nodes.keys())}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_basic/README.md b/contributing/samples/graph_agent_basic/README.md new file mode 100644 index 0000000000..7b60001fc6 --- /dev/null +++ b/contributing/samples/graph_agent_basic/README.md @@ -0,0 +1,36 @@ +# GraphAgent Basic Example β€” Conditional Routing + +This example demonstrates a data validation pipeline using **conditional routing** based on runtime +state. 
The validator checks input quality and branches to either a processor (success path) or an +error handler (failure path), showing how GraphAgent enables state-dependent decision logic that +sequential or parallel agent composition alone cannot achieve. + +## When to Use This Pattern + +- Any workflow requiring "if X then A, else B" branching on agent output +- Input validation before expensive downstream processing +- Quality-gate patterns where the next step depends on a score or classification + +## How to Run + +```bash +adk run contributing/samples/graph_agent_basic +``` + +## Graph Structure + +``` +validate ──(valid=True)──▢ process + ──(valid=False)─▢ error +``` + +## Key Code Walkthrough + +- **`GraphNode(name="validate", agent=validator_agent)`** β€” wraps an `LlmAgent` as a graph node +- **`add_edge("validate", "process", condition=lambda s: s.data["valid"] == True)`** β€” conditional + edge that only fires when the validation flag is set +- **Two end nodes** (`process` and `error`) β€” GraphAgent can have multiple terminal nodes +- **State propagation** β€” each node's output is written to `state.data[node_name]` and read by + downstream condition functions +- **No cycles** β€” this is a simple directed acyclic graph; for loops see `graph_agent_dynamic_queue` + diff --git a/contributing/samples/graph_agent_basic/__init__.py b/contributing/samples/graph_agent_basic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_basic/agent.py b/contributing/samples/graph_agent_basic/agent.py new file mode 100644 index 0000000000..f7b72bc83c --- /dev/null +++ b/contributing/samples/graph_agent_basic/agent.py @@ -0,0 +1,149 @@ +"""Basic GraphAgent example demonstrating conditional routing. + +This example shows how GraphAgent enables conditional workflow routing +based on runtime state, which cannot be achieved with SequentialAgent +or ParallelAgent composition. 
class ValidationResult(BaseModel):
    """Structured output emitted by the ``validator`` agent.

    Used as the validator's ``output_schema`` so downstream edge conditions
    (see ``is_valid_json``) can read a parsed boolean instead of raw text.
    """

    # True when the user input contained valid JSON.
    valid: bool
    # Failure reason; per the validator's instruction this is only set
    # when ``valid`` is False.
    error: str | None = None
+ """, +) + + +# --- Edge Condition Functions --- +def is_valid_json(state: GraphState) -> bool: + """Check if JSON is valid from structured output.""" + result = state.get_parsed("validator", ValidationResult) + return result.valid if result else False + + +# --- Create GraphAgent with Conditional Routing --- +def build_validation_graph() -> GraphAgent: + """Build the validation pipeline graph.""" + g = GraphAgent(name="validation_pipeline") + + # Add nodes + g.add_node("validate", agent=validator) + g.add_node("process", agent=processor) + g.add_node("error", agent=error_handler) + + # Add conditional edges + # If validation passes (state.data["validator"]["valid"] == True) -> process + g.add_edge( + "validate", + "process", + condition=is_valid_json, + ) + + # If validation fails (state.data["validator"]["valid"] == False) -> error handler + g.add_edge( + "validate", + "error", + condition=lambda state: not is_valid_json(state), + ) + + # Define workflow + g.set_start("validate") + g.set_end("process") # Success path ends at process + g.set_end("error") # Error path ends at error handler + + return g + + +# --- Run the workflow --- + + +async def main(): + graph = build_validation_graph() + runner = Runner( + app_name="validation_pipeline", + agent=graph, + session_service=InMemorySessionService(), + auto_create_session=True, + ) + + # Example: Valid input + print("=== Testing with valid JSON ===") + async for event in runner.run_async( + user_id="user_1", + session_id="session_1", + new_message=types.Content( + role="user", + parts=[types.Part(text='{"name": "John", "age": 30}')], + ), + ): + if event.content and event.content.parts: + print(f"{event.author}: {event.content.parts[0].text}") + + # Example: Invalid input + print("\n=== Testing with invalid JSON ===") + async for event in runner.run_async( + user_id="user_1", + session_id="session_2", + new_message=types.Content( + role="user", + parts=[types.Part(text='{"name": "Invalid data')], + ), + ): + if 
event.content and event.content.parts: + print(f"{event.author}: {event.content.parts[0].text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_basic/root_agent.yaml b/contributing/samples/graph_agent_basic/root_agent.yaml new file mode 100644 index 0000000000..d27a79ab1a --- /dev/null +++ b/contributing/samples/graph_agent_basic/root_agent.yaml @@ -0,0 +1,42 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/google/adk-python/refs/heads/main/src/google/adk/agents/config_schemas/AgentConfig.json + +agent_class: GraphAgent +name: validation_pipeline +description: Data validation pipeline with conditional routing + +# Define start and end nodes +start_node: validate +end_nodes: + - process + - error + +# Maximum iterations for cyclic graphs (default: 20) +max_iterations: 10 + +# Node definitions +nodes: + - name: validate + sub_agents: + - code: agents.validator + + - name: process + sub_agents: + - code: agents.processor + + - name: error + sub_agents: + - code: agents.error_handler + +# Edge definitions with conditional routing +edges: + # If validation passes -> process + - source_node: validate + target_node: process + condition: "data.get('valid', False) is True" + priority: 1 + + # If validation fails -> error handler + - source_node: validate + target_node: error + condition: "data.get('valid', False) is False" + priority: 1 diff --git a/contributing/samples/graph_agent_dynamic_queue/README.md b/contributing/samples/graph_agent_dynamic_queue/README.md new file mode 100644 index 0000000000..c5269443ec --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_queue/README.md @@ -0,0 +1,53 @@ +# GraphAgent Dynamic Task Queue Example + +This example demonstrates the **Dynamic Task Queue** pattern for GraphAgent, enabling AI Co-Scientist and similar workflows where tasks are generated and processed dynamically at runtime. 
+
+## Pattern Overview
+
+The dynamic task queue pattern uses a function node with runtime agent dispatch:
+- **Task Queue**: Maintained in GraphState, grows/shrinks dynamically
+- **Agent Dispatch**: Different agents selected based on task type
+- **Dynamic Task Generation**: Agents generate new tasks from their outputs
+- **State-Based Loop**: Continues until queue is empty
+
+## What This Example Shows
+
+1. **Mock Agents**: Three agents (generation, review, experiment) for demonstration
+2. **Task Parsing**: Extract TODO items from agent outputs to create new tasks
+3. **Dynamic Dispatch**: Select agent based on task type at runtime
+4. **Queue Management**: Process tasks until queue is empty
+
+## Architecture Support
+
+This pattern supports workflows whose task structure is not known at build time, for example:
+- AI Co-Scientist (dynamic hypothesis generation and testing)
+- Research paper writing (dynamic outline → research → writing loops)
+- Multi-agent task orchestration
+
+## Running the Example
+
+```bash
+cd /path/to/adk-python
+source venv/bin/activate
+python contributing/samples/graph_agent_dynamic_queue/agent.py
+```
+
+The example will:
+1. Start with 2 initial tasks (generate hypothesis 1 and 2)
+2. Process each task with the appropriate agent
+3. Parse agent outputs for new tasks (TODO: review X, TODO: experiment Y)
+4. Add new tasks to the queue dynamically
+5. Continue until the queue is empty
+
+## Adapting This Pattern
+
+Replace the mock agents with real agents:
+```python
+from your_agents import GenerationAgent, ReviewAgent, ExperimentAgent
+
+generation_agent = GenerationAgent(name="generation")
+review_agent = ReviewAgent(name="review")
+experiment_agent = ExperimentAgent(name="experiment")
+```
+
+Customize the task parsing logic in `parse_new_tasks_from_result()` to match your agent outputs. 
diff --git a/contributing/samples/graph_agent_dynamic_queue/agent.py b/contributing/samples/graph_agent_dynamic_queue/agent.py new file mode 100644 index 0000000000..a476726ee3 --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_queue/agent.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +"""Dynamic Task Queue Pattern Example + +Demonstrates how to implement AI Co-Scientist pattern using GraphAgent +with a function node that dynamically dispatches to different agents. + +This example shows: +1. Dynamic task queue management +2. Runtime agent dispatch based on task type +3. Dynamic task generation from agent outputs +4. State-based loop control +""" + +import asyncio +import re +from typing import Any +from typing import Dict +from typing import List + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +# Mock agents for demonstration (replace with real agents) +class MockGenerationAgent(BaseAgent): + """Mock agent that generates hypotheses.""" + + async def _run_async_impl(self, ctx: InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate hypothesis generation + hypothesis = f"Hypothesis: {input_text} leads to interesting results" + + # Generate follow-up tasks + output = f"""{hypothesis} + +TODO: review this hypothesis +TODO: experiment with variation A""" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=output)] + ), + ) + + +class MockReviewAgent(BaseAgent): + """Mock agent that reviews hypotheses.""" + + async def _run_async_impl(self, ctx: 
InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate review + review = f"Review: {input_text} - APPROVED with score 8/10" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=review)] + ), + ) + + +class MockExperimentAgent(BaseAgent): + """Mock agent that runs experiments.""" + + async def _run_async_impl(self, ctx: InvocationContext): + input_text = "" + if ctx.user_content and ctx.user_content.parts: + input_text = ctx.user_content.parts[0].text or "" + + # Simulate experiment + result = f"Experiment: {input_text} - SUCCESS (confidence: 0.92)" + + yield Event( + author=self.name, + content=types.Content( + role="assistant", parts=[types.Part(text=result)] + ), + ) + + +# Initialize worker agents +generation_agent = MockGenerationAgent(name="generation_agent") +review_agent = MockReviewAgent(name="review_agent") +experiment_agent = MockExperimentAgent(name="experiment_agent") + + +def parse_new_tasks_from_result(result: str) -> List[Dict[str, str]]: + """Extract TODO tasks from agent output. + + Looks for lines like: + - TODO: review X + - TODO: experiment with Y + + Returns: + List of task dicts: [{"type": "review", "data": "X"}, ...] + """ + tasks = [] + todo_pattern = r"TODO:\s*(review|experiment)\s+(.+)" + + for match in re.finditer(todo_pattern, result, re.IGNORECASE): + task_type = match.group(1).lower() + task_data = match.group(2).strip() + tasks.append({"type": task_type, "data": task_data}) + + return tasks + + +async def dynamic_task_dispatcher( + state: GraphState, ctx: InvocationContext +) -> Dict[str, Any]: + """Dispatch to agents based on dynamic task queue. + + This function: + 1. Reads task queue from state + 2. Pops next task + 3. Dispatches to appropriate agent + 4. Updates queue with any new tasks generated + 5. 
Returns updated state + """ + task_queue = state.data.get("task_queue", []) + + if not task_queue: + print("βœ… Task queue empty - all tasks complete!") + return {"all_complete": True, "tasks_remaining": 0} + + # Pop next task + next_task = task_queue.pop(0) + task_type = next_task["type"] + task_data = next_task["data"] + + print(f"\nπŸ”„ Processing task: [{task_type}] {task_data}") + + # Dynamic agent dispatch based on task type + if task_type == "generate": + agent = generation_agent + elif task_type == "review": + agent = review_agent + elif task_type == "experiment": + agent = experiment_agent + else: + raise ValueError(f"Unknown task type: {task_type}") + + # Create context for agent with task data + agent_ctx = ctx.model_copy( + update={ + "user_content": types.Content( + role="user", parts=[types.Part(text=task_data)] + ) + } + ) + + # Execute agent and collect result + result = "" + async for event in agent.run_async(agent_ctx): + if event.content and event.content.parts: + result += event.content.parts[0].text or "" + + print(f" Result: {result[:100]}...") + + # Parse result for new tasks (dynamic task generation!) + new_tasks = parse_new_tasks_from_result(result) + if new_tasks: + print(f" Generated {len(new_tasks)} new tasks: {new_tasks}") + task_queue.extend(new_tasks) + + # Update state + state.data["task_queue"] = task_queue + state.data["last_result"] = result + completed = state.data.setdefault("completed_tasks", []) + completed.append({"type": task_type, "data": task_data, "result": result}) + + print(f" Tasks remaining: {len(task_queue)}") + + return {"tasks_remaining": len(task_queue), "completed_count": len(completed)} + + +def build_dynamic_task_queue_graph() -> GraphAgent: + """Build GraphAgent with dynamic task queue pattern. + + The graph has a single node that loops, processing tasks from a queue + that can grow dynamically based on agent outputs. 
+ """ + graph = GraphAgent( + name="ai_co_scientist", + max_iterations=20, # Prevent infinite loops + description="Dynamic task queue with agent dispatch", + ) + + # Single dispatcher node that processes queue + graph.add_node("task_dispatcher", function=dynamic_task_dispatcher) + + # Loop back to dispatcher while tasks remain. + # Check task_queue directly (updated via output_mapper return value). + # The return dict {"tasks_remaining": N} is stored under state.data["task_dispatcher"] + # by the output mapper, so state.data.get("tasks_remaining") would always be 0. + graph.add_edge( + "task_dispatcher", + "task_dispatcher", + condition=lambda state: len(state.data.get("task_queue", [])) > 0, + ) + + graph.set_start("task_dispatcher") + graph.set_end("task_dispatcher") # Terminal when no edge condition matches + + return graph + + +async def main(): + """Run dynamic task queue example.""" + print("=" * 70) + print("Dynamic Task Queue Pattern - AI Co-Scientist Example") + print("=" * 70) + + # Build graph + graph = build_dynamic_task_queue_graph() + + # Create session service + session_service = InMemorySessionService() + + # Initialize with starting tasks + initial_state = GraphState( + data={ + "task_queue": [ + {"type": "generate", "data": "quantum computing approach"}, + {"type": "generate", "data": "machine learning approach"}, + ], + "completed_tasks": [], + } + ) + + # Seed domain data via create_session so it survives the deepcopy: + # InMemorySessionService always deepcopies on get/create, so setting + # session.state after create_session() would only mutate the returned copy. 
+ session = await session_service.create_session( + app_name="dynamic_queue_demo", + user_id="demo_user", + state=initial_state.data, + ) + + # Create runner + runner = Runner( + app_name="dynamic_queue_demo", + agent=graph, + session_service=session_service, + auto_create_session=False, # Session already created + ) + + # Run graph - dispatcher will process queue dynamically + print("\nπŸ“‹ Initial task queue:") + for task in initial_state.data["task_queue"]: + print(f" - [{task['type']}] {task['data']}") + + print("\n" + "=" * 70) + print("Starting execution...") + print("=" * 70) + + async for event in runner.run_async( + user_id="demo_user", + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part(text="Start task queue processing")] + ), + ): + if event.content and event.content.parts: + text = event.content.parts[0].text or "" + if text and "final_output" in text.lower(): + print(f"\nπŸ“Š {text}") + + # Print final statistics (re-fetch β€” create_session returned a deepcopy) + fresh_session = await session_service.get_session( + app_name="dynamic_queue_demo", user_id="demo_user", session_id=session.id + ) + if fresh_session is None: + print("\n⚠️ Could not retrieve final session state to print statistics.") + return + final_session = fresh_session + final_data = final_session.state.get("graph_data", {}) + final_state = GraphState(data=final_data) if final_data else GraphState() + + print("\n" + "=" * 70) + print("Execution Complete!") + print("=" * 70) + print( + "Total tasks completed:" + f" {len(final_state.data.get('completed_tasks', []))}" + ) + print(f"\nCompleted tasks:") + for i, task in enumerate(final_state.data.get("completed_tasks", []), 1): + print(f"{i}. 
[{task['type']}] {task['data']}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_dynamic_topology/README.md b/contributing/samples/graph_agent_dynamic_topology/README.md new file mode 100644 index 0000000000..10e4c0183b --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_topology/README.md @@ -0,0 +1,34 @@ +# GraphAgent Dynamic Topology β€” Runtime Node and Edge Discovery + +This example demonstrates a pipeline where an LLM planner first decides which +processing steps are needed, then those steps are added as new nodes and edges +to the graph before execution continues β€” the graph shape is unknown at build time. + +## When to Use This Pattern + +- Adaptive pipelines where required steps depend on the input or task type +- ETL workflows where an LLM determines the transformation chain at runtime +- Any scenario requiring a variable number of agents chosen dynamically + +## How to Run + +```bash +GOOGLE_API_KEY=your_key python -m contributing.samples.graph_agent_dynamic_topology.agent +``` + +## Graph Structure + +``` +planner ──(discovers steps)──▢ step_1 ──▢ step_2 ──▢ ... 
──▢ step_N ──▢ END + β”‚ (nodes and edges added at runtime) +checkpoint checkpoint checkpoint checkpoint +``` + +## Key Code Walkthrough + +- **`graph.add_node()` / `graph.add_edge()` post-construction** β€” called inside `extend_graph_with_steps()` after the planner fires; GraphAgent supports topology mutations mid-run +- **`output_schema=PipelineDesign`** β€” planner returns a structured `steps: list[str]` that drives topology extension +- **`make_step_agent(step_name)`** β€” factory creates a specialized `LlmAgent` per discovered step +- **`graph.set_end(prev_node)`** β€” end node is reassigned to the last dynamically added step +- **`GraphCheckpointCallback(checkpoint_nodes=None)`** β€” checkpoints every node (planner + all steps) since the full set is not known ahead of time + diff --git a/contributing/samples/graph_agent_dynamic_topology/__init__.py b/contributing/samples/graph_agent_dynamic_topology/__init__.py new file mode 100644 index 0000000000..574f0423ae --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_topology/__init__.py @@ -0,0 +1 @@ +"""GraphAgent dynamic topology sample.""" diff --git a/contributing/samples/graph_agent_dynamic_topology/agent.py b/contributing/samples/graph_agent_dynamic_topology/agent.py new file mode 100644 index 0000000000..ba9a18295b --- /dev/null +++ b/contributing/samples/graph_agent_dynamic_topology/agent.py @@ -0,0 +1,262 @@ +"""GraphAgent dynamic topology: add nodes and edges at runtime. + +Demonstrates adding entirely new nodes and edges AFTER graph construction +based on runtime decisions, with checkpointing of each processing step. + +Difference from DynamicNode (patterns.py): +- DynamicNode: selects WHICH agent runs at a FIXED node position +- DynamicTopology: adds NEW nodes/edges to the graph at runtime + +Use case: Adaptive pipeline that discovers required processing steps at +runtime (e.g., an LLM decides ETL steps needed for a given dataset). 
+ +Flow: + planner ──(discovers steps)──→ [step_1, step_2, ...step_N] + | | + checkpoint checkpoint (each step) + | + END + +Why GraphAgent (not SequentialAgent)? +- SequentialAgent: fixed sequence of nodes defined at construction time +- GraphAgent: nodes and edges can be added at runtime before re-running + +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_dynamic_topology.agent +""" + +import asyncio +import json +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.graph.checkpoint_callback import GraphCheckpointCallback +from google.adk.checkpoints import CheckpointService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class PipelineDesign(BaseModel): + """Structured pipeline design from planner agent.""" + + steps: list[str] # Ordered list of processing step names + description: str # Brief description of the overall pipeline + + +# --------------------------------------------------------------------------- +# Planner agent (designs the pipeline steps at runtime) +# --------------------------------------------------------------------------- + +planner = LlmAgent( + name="planner", + model=_MODEL, + instruction=( + "You are a data pipeline designer. Given a data processing task," + " identify 2-4 distinct processing steps required (e.g., 'validate'," + " 'transform', 'enrich', 'aggregate', 'export'). Return {\"steps\":" + ' ["step1", "step2", ...], "description": "pipeline overview"}.' 
+ ), + output_schema=PipelineDesign, + # output_key auto-defaults to "planner" (agent name) +) + + +# --------------------------------------------------------------------------- +# Step agent factory +# --------------------------------------------------------------------------- + + +def make_step_agent(step_name: str) -> LlmAgent: + """Create a specialized agent for a specific pipeline step.""" + safe_name = step_name.replace(" ", "_").replace("-", "_") + return LlmAgent( + name=f"{safe_name}_agent", + model=_MODEL, + instruction=( + f"You are a data processing agent performing the '{step_name}' step." + f" Describe what the '{step_name}' step does to the input data and" + " confirm it was completed successfully. Be concise (1-2 sentences)." + ), + output_key=f"step_{step_name}_result", + ) + + +# --------------------------------------------------------------------------- +# Graph builder with dynamic topology +# --------------------------------------------------------------------------- + + +def build_base_graph( + session_service: InMemorySessionService, +) -> GraphAgent: + """Build the initial graph with only the planner node.""" + checkpoint_service = CheckpointService(session_service=session_service) + # Checkpoint after planner (topology decision) and after each dynamic step + # We pass all=True initially; will filter in callback based on step names + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_before=False, + checkpoint_after=True, + # checkpoint_nodes=None means: checkpoint ALL nodes (planner + all steps) + ) + + graph = GraphAgent( + name="dynamic_pipeline", + description="Adaptive data pipeline with runtime-discovered steps", + max_iterations=20, + after_node_callback=checkpoint_callback.after_node, + ) + + graph.add_node("planner", agent=planner) + graph.set_start("planner") + # No edges or end node yet - will be added dynamically + + return graph + + +async def extend_graph_with_steps( + graph: GraphAgent, + steps: 
list[str], +) -> None: + """Add discovered pipeline steps as nodes and edges to the graph. + + This modifies the graph BEFORE the runner continues execution. + + Args: + graph: The GraphAgent to extend + steps: Ordered list of step names discovered by planner + """ + if not steps: + return + + # Remove existing end nodes (planner was a dead end without steps) + # Set the first dynamic step as the next node after planner + prev_node = "planner" + + for i, step_name in enumerate(steps): + safe_step = step_name.replace(" ", "_").replace("-", "_") + node_name = f"step_{safe_step}" + + # Skip if node already exists (idempotent) + if node_name in graph.nodes: + print(f" [TOPOLOGY] Node '{node_name}' already exists, skipping") + prev_node = node_name + continue + + # Create and add agent for this step + step_agent = make_step_agent(step_name) + graph.add_node(node_name, agent=step_agent) + print(f" [TOPOLOGY] Added node '{node_name}'") + + # Connect to previous node + graph.add_edge(prev_node, node_name) + print(f" [TOPOLOGY] Added edge '{prev_node}' β†’ '{node_name}'") + + prev_node = node_name + + # Set the last step as the end node + graph.set_end(prev_node) + print(f" [TOPOLOGY] Set '{prev_node}' as end node") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + print("=== Dynamic Topology: Runtime Pipeline Discovery ===\n") + + session_service = InMemorySessionService() + graph = build_base_graph(session_service) + session_id = "dynamic-pipeline-1" + + session = await session_service.create_session( + app_name="dynamic_pipeline", user_id="user1", session_id=session_id + ) + + runner = Runner( + app_name="dynamic_pipeline", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + task = ( + "Process a CSV dataset: validate schema, transform types, enrich with" + " external API, and export to 
JSON" + ) + print(f"Task: {task}\n") + + topology_extended = False + step_count = 0 + + async for event in runner.run_async( + user_id="user1", + session_id=session_id, + new_message=types.Content(parts=[types.Part(text=task)]), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + author = event.author + + if author == "planner" and not topology_extended: + # Planner emits structured JSON as event text (output_schema serialises to string) + try: + design = json.loads(text) + except (json.JSONDecodeError, TypeError): + design = {} + steps = design.get("steps", []) + description = design.get("description", "") + + print(f"Planner designed pipeline: {description}") + print(f"Steps discovered: {steps}\n") + + # Dynamically add discovered steps to the graph + await extend_graph_with_steps(graph, steps) + topology_extended = True + print() + + elif author.startswith("step_") or author.endswith("_agent"): + step_count += 1 + print(f"[Step {step_count}] {author}: {text[:150]}") + + elif author == "checkpoint_service": + pass # Silent checkpoint acknowledgment + + # Show final state + final_state = session.state.get("graph_data", {}) + print(f"\nPipeline completed. Steps executed: {step_count}") + + # Re-fetch session: InMemorySessionService returns deepcopies so the local + # `session` reference is stale. 
+ fresh_session = await session_service.get_session( + app_name="dynamic_pipeline", user_id="user1", session_id=session_id + ) + if fresh_session is None: + print( + f"WARNING: session_service.create_session returned None, using stale" + f" copy" + ) + fresh_session = session + checkpoints = fresh_session.state.get("_checkpoint_index", {}) + print(f"Checkpoints created: {len(checkpoints)}") + print("\nStep results:") + for key, value in final_state.items(): + if key.startswith("step_") and key.endswith("_result"): + step_name = key[5:-7] # strip "step_" prefix and "_result" suffix + print(f" {step_name}: {str(value)[:100]}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_hitl/README.md b/contributing/samples/graph_agent_hitl/README.md new file mode 100644 index 0000000000..891d2db259 --- /dev/null +++ b/contributing/samples/graph_agent_hitl/README.md @@ -0,0 +1,37 @@ +# GraphAgent Human-In-The-Loop (HITL) β€” Risk-Gated Approval + +This example demonstrates a risk-aware workflow where an agent assesses action +risk at runtime and conditionally pauses for human approval before proceeding, +using `InterruptService`, `InterruptReasoner`, and `GraphCheckpointCallback`. 
+ +## When to Use This Pattern + +- Any workflow requiring human sign-off before irreversible actions +- Automated pipelines that must escalate high-risk decisions to a reviewer +- Audit trails where state must be preserved across human interaction windows + +## How to Run + +```bash +GOOGLE_API_KEY=your_key python -m contributing.samples.graph_agent_hitl.agent +``` + +## Graph Structure + +``` +analyze ──(always)──▢ execute ──▢ END + β”‚ β–² + β”‚ [interrupt_config fires if risk == "high"] + β”‚ β”‚ + └──── InterruptReasoner processes human feedback β”€β”€β”€β”€β”€β”€β”€β”˜ + "continue" β†’ proceed | "pause" β†’ stop +``` + +## Key Code Walkthrough + +- **`InterruptConfig(mode=BEFORE, nodes=["execute"])`** β€” pauses before `execute` when risk is high +- **`InterruptReasoner`** β€” LLM that interprets human feedback and returns `continue` or `pause` +- **`InterruptService.send_message()`** β€” delivers human approval into the running graph +- **`GraphCheckpointCallback(checkpoint_nodes={"analyze","execute"})`** β€” saves state only at critical nodes, not every node +- **`output_schema=RiskAssessment`** β€” structured output lets condition logic read `risk_level` from state + diff --git a/contributing/samples/graph_agent_hitl/__init__.py b/contributing/samples/graph_agent_hitl/__init__.py new file mode 100644 index 0000000000..f889f7c92e --- /dev/null +++ b/contributing/samples/graph_agent_hitl/__init__.py @@ -0,0 +1 @@ +"""GraphAgent HITL (Human-In-The-Loop) sample.""" diff --git a/contributing/samples/graph_agent_hitl/agent.py b/contributing/samples/graph_agent_hitl/agent.py new file mode 100644 index 0000000000..1fb80c7e48 --- /dev/null +++ b/contributing/samples/graph_agent_hitl/agent.py @@ -0,0 +1,300 @@ +"""GraphAgent Human-In-The-Loop (HITL) with automatic checkpointing. 
+ +Demonstrates: +- Agent-driven risk assessment (agent decides when human approval is needed) +- InterruptService for runtime human messages +- InterruptReasoner for LLM-based approval decisions +- GraphCheckpointCallback for selective node-level checkpointing +- State preservation across human interactions + +Flow: + analyze β†’ (checkpoint) β†’ execute β†’ (checkpoint) β†’ END + ↓ + [interrupt_config fires if risk_level == "high"] + ↓ + InterruptReasoner processes human feedback + ↓ + "continue" β†’ execute; "pause" β†’ stop + +Why GraphAgent (not SequentialAgent)? +- SequentialAgent: cannot inspect agent output to conditionally interrupt +- GraphAgent: interrupt_config + interrupt_service read state β†’ route or pause + +Why GraphCheckpointCallback over checkpointing=True? +- checkpointing=True: checkpoints EVERY node +- GraphCheckpointCallback(checkpoint_nodes=...): checkpoints ONLY critical nodes + +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_hitl.agent +""" + +import asyncio +import json +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph import InterruptReasoner +from google.adk.agents.graph import InterruptReasonerConfig +from google.adk.agents.graph import InterruptService +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.graph.checkpoint_callback import GraphCheckpointCallback +from google.adk.checkpoints import CheckpointService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# 
--------------------------------------------------------------------------- + + +class RiskAssessment(BaseModel): + """Structured risk assessment from analyzer agent.""" + + action: str # Description of the action to take + risk_level: str # "low" | "medium" | "high" + justification: str # Why this risk level was assigned + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + + +def _create_agents(): + """Create fresh agent instances (avoids single-parent conflicts).""" + _analyzer = LlmAgent( + name="analyzer", + model=_MODEL, + instruction=( + "You are a risk assessment agent. Analyze the requested action and" + " determine its risk level." + ' Return {"action": "description of action", "risk_level":' + ' "low|medium|high", "justification": "reason for risk level"}.' + " High risk: irreversible actions (delete, overwrite, deploy to" + " production). Low risk: read-only or reversible actions." + ), + output_schema=RiskAssessment, + ) + _executor = LlmAgent( + name="executor", + model=_MODEL, + instruction=( + "You are an action executor. Confirm the action has been executed" + " successfully and describe what was done." + ), + output_key="execution_result", + ) + return _analyzer, _executor + + +# --------------------------------------------------------------------------- +# Interrupt reasoner +# --------------------------------------------------------------------------- + +# LLM-based reasoner: processes human feedback and decides next action +# "continue" β†’ proceed with execution +# "pause" β†’ stop execution (escalate=True) +approval_reasoner = InterruptReasoner( + InterruptReasonerConfig( + model=_MODEL, + available_actions=["continue", "pause"], + instruction=( + "You process human approval decisions for a risk-aware workflow." + " If the human approves the action, return 'continue'." 
+ " If the human rejects or requests more review, return 'pause'." + ), + ) +) + + +# --------------------------------------------------------------------------- +# Graph +# --------------------------------------------------------------------------- + + +def build_hitl_graph( + session_service: InMemorySessionService, +) -> tuple[GraphAgent, InterruptService]: + """Build HITL graph with checkpointing and interrupt support.""" + interrupt_service = InterruptService() + + checkpoint_service = CheckpointService(session_service=session_service) + # Only checkpoint the two critical nodes (not every node) + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_before=False, # Only after + checkpoint_after=True, + checkpoint_nodes={"analyze", "execute"}, + ) + + graph = GraphAgent( + name="hitl_workflow", + description=( + "Risk-aware workflow with human approval for high-risk actions" + ), + max_iterations=10, + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + # Interrupt BEFORE execute for high-risk actions + mode=InterruptMode.BEFORE, + nodes=["execute"], # Only check before "execute" node + reasoner=approval_reasoner, + ), + after_node_callback=checkpoint_callback.after_node, + ) + + analyzer, executor = _create_agents() + graph.add_node("analyze", agent=analyzer) + graph.add_node( + "execute", + agent=executor, + input_mapper=lambda s: ( + f"Execute this action: {s.data.get('analyzer', {}).get('action', '')}" + ), + ) + + graph.set_start("analyze") + graph.add_edge("analyze", "execute") + graph.set_end("execute") + + return graph, interrupt_service + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def run_with_approval( + request: str, + approval_message: str, + session_id: str, +) -> None: + """Run HITL workflow with simulated human approval. 
+ + In production: the approval_message would come from a real human + (e.g., via API endpoint, Slack bot, or review UI). + + Args: + request: The action to analyze and potentially execute + approval_message: Human's approval feedback (simulated here) + session_id: Unique session identifier + """ + session_service = InMemorySessionService() + graph, interrupt_service = build_hitl_graph(session_service) + + session = await session_service.create_session( + app_name="hitl_workflow", user_id="user1", session_id=session_id + ) + + runner = Runner( + app_name="hitl_workflow", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + # Register session with interrupt service BEFORE running + interrupt_service.register_session(session_id) + + # Pre-queue human approval message (simulates human reviewing and approving) + # In production: send this message from a separate process/API after + # the graph escalates and notifies a human reviewer. + await interrupt_service.send_message( + session_id, + approval_message, + action="continue", # "continue" = approve, "pause" = reject + ) + + print(f"\nRequest: {request}") + print(f"Human approval pre-queued: '{approval_message}'") + print("-" * 50) + + last_assessment: dict = {} + last_execution: str = "" + + async for event in runner.run_async( + user_id="user1", + session_id=session_id, + new_message=types.Content(parts=[types.Part(text=request)]), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text: + continue + + author = event.author + if author == "analyzer": + # output_schema serialises structured output as JSON text + try: + last_assessment = json.loads(text) + except (json.JSONDecodeError, TypeError): + last_assessment = {} + risk = last_assessment.get("risk_level", "unknown").upper() + action = last_assessment.get("action", "") + print(f"[Analyzer] Risk: {risk} | Action: {action}") + elif author == "executor": + 
+        last_execution = text
+        print(f"[Executor] {text[:200]}")
+      elif author == "interrupt_reasoner":
+        print(f"[Interrupt Reasoner] Decision: {text[:100]}")
+      elif author == "checkpoint_service":
+        print("[Checkpoint] Saved state at current node")
+
+  # Show final state (tracked from events to avoid stale session reference)
+  print(f"\nRisk level: {last_assessment.get('risk_level', 'N/A')}")
+  print(
+      f"Result: {last_execution[:200] if last_execution else '(not executed)'}"
+  )
+
+  # Re-fetch session: InMemorySessionService returns deepcopies so the local
+  # `session` reference is stale. The runner's internal copy holds checkpoint state.
+  fresh_session = await session_service.get_session(
+      app_name="hitl_workflow", user_id="user1", session_id=session_id
+  )
+  if fresh_session is None:
+    print(
+        "WARNING: session_service.get_session returned None, using stale"
+        " copy"
+    )
+    fresh_session = session
+  checkpoints = fresh_session.state.get("_checkpoint_index", {})
+  print(f"Checkpoints created: {len(checkpoints)}")
+
+  # Cleanup
+  interrupt_service.unregister_session(session_id)
+
+
+async def main() -> None:
+  print(
+      "=== HITL Workflow: Agent-Driven Risk Assessment + Human Approval ===\n"
+  )
+
+  # Low risk: auto-approved (agent decides low risk, interrupt fires but
+  # approval is already in queue)
+  await run_with_approval(
+      request="Read the contents of the config.yaml file",
+      approval_message="Approved - read-only operation is safe",
+      session_id="hitl-low-risk",
+  )
+
+  print("\n" + "=" * 60 + "\n")
+
+  # High risk: human must review (pre-queued approval simulates human review)
+  await run_with_approval(
+      request="Delete all records from the production database",
+      approval_message="Approved with caution - backup confirmed",
+      session_id="hitl-high-risk",
+  )
+
+
+if __name__ == "__main__":
+  asyncio.run(main())
diff --git a/contributing/samples/graph_agent_hitl_orchestrated/README.md b/contributing/samples/graph_agent_hitl_orchestrated/README.md
new file mode 100644 index 0000000000..05c468a776 --- /dev/null +++ b/contributing/samples/graph_agent_hitl_orchestrated/README.md @@ -0,0 +1,51 @@ +# Composable HITL Orchestrated Pipeline + +Demonstrates how to compose HITL review loops as reusable `NestedGraphNode` building blocks in a larger orchestrated pipeline. + +## Graph Structure + +**Outer graph** (`document_pipeline`): +``` +[classify] --> [process] --> [aggregate] +``` + +**Inner graph** (`review_stage`, wrapped as `NestedGraphNode`): +``` +[execute] --> [review_gate] --> approved? --> [done] + | + v rejected + [revise] --> [review_gate] (loop) +``` + +## Key Concepts + +- **Reusable HITL block**: The inner review graph is built once and wrapped in `NestedGraphNode` +- **Clean abstraction**: Outer orchestrator doesn't know about inner HITL details +- **Independent review cycles**: Each inner graph has its own interrupt timing +- **Observability**: `_debug_process_output` tracks nested graph output (via Part B observability) +- **Rule-based classification**: `classify` node determines stages without LLM + +## Running + +```bash +# Without LLM (deterministic fallback): +python -m contributing.samples.graph_agent_hitl_orchestrated.agent + +# With LLM: +export GOOGLE_API_KEY="your-key" +python -m contributing.samples.graph_agent_hitl_orchestrated.agent +``` + +## How It Works + +1. `classify` reads input document, determines processing stages (extract/summarize/translate) +2. `process` (NestedGraphNode) runs the inner review graph with HITL loop +3. Inner `execute` performs the stage task, `review_gate` pauses for human approval +4. If rejected: inner `revise` incorporates feedback, loops back +5. If approved: inner `done` returns output to outer graph +6. 
`aggregate` combines results into final output + +## Differences from `graph_agent_hitl_review` + +- `graph_agent_hitl_review`: Standalone HITL review loop +- `graph_agent_hitl_orchestrated`: Wraps the review pattern as a NestedGraphNode in a larger pipeline, showing composability diff --git a/contributing/samples/graph_agent_hitl_orchestrated/__init__.py b/contributing/samples/graph_agent_hitl_orchestrated/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_hitl_orchestrated/agent.py b/contributing/samples/graph_agent_hitl_orchestrated/agent.py new file mode 100644 index 0000000000..364156d097 --- /dev/null +++ b/contributing/samples/graph_agent_hitl_orchestrated/agent.py @@ -0,0 +1,400 @@ +"""Composable HITL Orchestrated Pipeline β€” Multi-stage document processing. + +Demonstrates how to compose HITL review loops as reusable NestedGraphNode +building blocks within a larger orchestrated pipeline. + +Scenario: Multi-stage document pipeline +1. Orchestrator receives a document processing request +2. Classifier determines which processing stages are needed +3. Each stage is a self-contained GraphAgent with its own HITL review loop +4. Stages execute in sequence; each can loop internally for human review +5. Final stage aggregates results + +Key concepts: +- HITL review as a **reusable building block** (NestedGraphNode) +- Inner graphs have independent interrupt timing and review cycles +- Outer orchestrator doesn't know about inner HITL details β€” clean abstraction +- Models real enterprise workflows: compliance review, multi-department approval + +Outer graph: + [classify] --> [process] --> [aggregate] + +Inner graph (per stage): + [execute] --> [review_gate] --> approved? 
--> [done] + | + v rejected + [revise] --> [review_gate] (loop) + +Run: + python -m contributing.samples.graph_agent_hitl_orchestrated.agent +""" + +import asyncio +import os + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph import InterruptService +from google.adk.agents.graph import NestedGraphNode +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +# --------------------------------------------------------------------------- +# LLM availability +# --------------------------------------------------------------------------- + + +def _has_llm() -> bool: + return bool( + os.environ.get("GOOGLE_API_KEY") + or os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") + ) + + +# --------------------------------------------------------------------------- +# Fallback agents (deterministic, no LLM) +# --------------------------------------------------------------------------- + + +class StageAgent(BaseAgent): + """Deterministic agent that performs a named stage operation.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + def __init__(self, name: str, stage_type: str): + super().__init__(name=name) + object.__setattr__(self, "_stage_type", stage_type) + + async def _run_async_impl(self, ctx): + stage_type = object.__getattribute__(self, "_stage_type") + graph_data = ctx.session.state.get("graph_data", {}) + doc = graph_data.get("document", "") + + if stage_type == "extract": + result = f"Extracted key entities from: {doc[:100]}" + elif stage_type == "summarize": + result = f"Summary of 
document: {doc[:80]}..." + elif stage_type == "translate": + result = f"Translated content: [{doc[:60]}] (translated)" + else: + result = f"Processed ({stage_type}): {doc[:80]}" + + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=result)]), + ) + + +class ReviseStageAgent(BaseAgent): + """Deterministic revise agent for stage content.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + async def _run_async_impl(self, ctx): + graph_data = ctx.session.state.get("graph_data", {}) + previous = graph_data.get("stage_content", "") + feedback = graph_data.get("stage_feedback", "") + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"{previous}\n[Revised: {feedback}]")] + ), + ) + + +# --------------------------------------------------------------------------- +# Inner review graph builder (reusable per stage) +# --------------------------------------------------------------------------- + + +def build_review_stage_graph( + stage_name: str, + stage_type: str, + interrupt_service: InterruptService | None = None, +) -> GraphAgent: + """Build a self-contained review graph for a single processing stage. 
+ + Args: + stage_name: Name for this stage graph (e.g., "extract_review") + stage_type: Type of processing ("extract", "summarize", "translate") + interrupt_service: Optional interrupt service for HITL + + Returns: + GraphAgent with execute -> review_gate -> done/revise flow + """ + execute_agent = StageAgent( + name=f"{stage_type}_executor", stage_type=stage_type + ) + revise_agent = ReviseStageAgent(name=f"{stage_type}_reviser") + + interrupt_config = None + if interrupt_service: + interrupt_config = InterruptConfig( + mode=InterruptMode.BEFORE, + nodes=["stage_review"], + ) + + graph = GraphAgent( + name=stage_name, + description=f"Review stage for {stage_type}", + max_iterations=6, + interrupt_service=interrupt_service, + interrupt_config=interrupt_config, + ) + + # Execute stage + graph.add_node( + "execute", + agent=execute_agent, + output_mapper=lambda output, s: ( + s.data.update({"stage_content": output}), + s, + )[1], + ) + + # Review gate β€” checks interrupt queue for approval + async def review_gate_fn(state: GraphState, ctx) -> str: + svc = getattr(ctx, "_interrupt_service", None) + sid = ctx.session.id if ctx.session else None + if svc and sid: + msg = await svc.check_interrupt(sid) + if msg: + approved = msg.action == "approve" + state.data["stage_approved"] = approved + state.data["stage_feedback"] = msg.text + return ( + f"[{stage_type} review]" + f" {'APPROVED' if approved else 'REVISE: ' + msg.text}" + ) + # Auto-approve if no interrupt + state.data["stage_approved"] = True + state.data["stage_feedback"] = "" + return f"[{stage_type} review] Auto-approved" + + graph.add_node(GraphNode(name="stage_review", function=review_gate_fn)) + + # Revise node + graph.add_node( + "revise", + agent=revise_agent, + input_mapper=lambda s: ( + f"Revise: {s.data.get('stage_content', '')}\n" + f"Feedback: {s.data.get('stage_feedback', '')}" + ), + output_mapper=lambda output, s: ( + s.data.update({"stage_content": output}), + s, + )[1], + ) + + # Done node β€” 
outputs final stage content + async def done_fn(state: GraphState, ctx) -> str: + return state.data.get("stage_content", "") + + graph.add_node(GraphNode(name="done", function=done_fn)) + + # Edges + graph.set_start("execute") + graph.add_edge("execute", "stage_review") + graph.add_edge( + "stage_review", + "done", + condition=lambda s: s.data.get("stage_approved") is True, + ) + graph.add_edge( + "stage_review", + "revise", + condition=lambda s: s.data.get("stage_approved") is False, + ) + graph.add_edge("revise", "stage_review") + graph.set_end("done") + + return graph + + +# --------------------------------------------------------------------------- +# Outer orchestrator graph +# --------------------------------------------------------------------------- + + +def build_orchestrator( + interrupt_service: InterruptService | None = None, +) -> GraphAgent: + """Build the outer orchestration graph. + + Structure: + [classify] --> [process] --> [aggregate] + + The `process` node is a NestedGraphNode wrapping a review-stage graph. 
+ + Args: + interrupt_service: Optional interrupt service for inner HITL + + Returns: + Configured outer GraphAgent + """ + + # Classify: determine required stages from document + async def classify_fn(state: GraphState, ctx) -> str: + doc = state.data.get("input", "") + state.data["document"] = doc + + # Simple rule-based classifier + stages = [] + if any(kw in doc.lower() for kw in ["data", "entity", "extract"]): + stages.append("extract") + if any(kw in doc.lower() for kw in ["long", "summary", "summarize"]): + stages.append("summarize") + if any(kw in doc.lower() for kw in ["translate", "language", "i18n"]): + stages.append("translate") + if not stages: + stages = ["extract", "summarize"] # Default + + state.data["stages"] = stages + return f"Classified: stages={stages}" + + # Build inner review graph for the primary stage + # (In a production version, you'd dynamically select based on stages[0]) + inner_graph = build_review_stage_graph( + stage_name="review_stage", + stage_type="extract", + interrupt_service=interrupt_service, + ) + + # Aggregate: combine results + async def aggregate_fn(state: GraphState, ctx) -> str: + stages = state.data.get("stages", []) + stage_content = state.data.get("stage_content", "") + document = state.data.get("document", "") + + result = ( + "Pipeline complete.\n" + f"Document: {document[:100]}\n" + f"Stages run: {stages}\n" + f"Output: {stage_content[:200]}" + ) + state.data["pipeline_result"] = result + return result + + # Build outer graph + graph = GraphAgent( + name="document_pipeline", + description="Multi-stage document processing with HITL review", + max_iterations=20, + ) + + graph.add_node(GraphNode(name="classify", function=classify_fn)) + graph.add_node( + NestedGraphNode( + name="process", + graph_agent=inner_graph, + inherit_session=True, + input_mapper=lambda s: s.data.get("document", ""), + ) + ) + graph.add_node(GraphNode(name="aggregate", function=aggregate_fn)) + + graph.set_start("classify") + 
graph.add_edge("classify", "process") + graph.add_edge("process", "aggregate") + graph.set_end("aggregate") + + return graph + + +# --------------------------------------------------------------------------- +# Main β€” simulates multi-stage human review +# --------------------------------------------------------------------------- + + +async def main() -> None: + """Run orchestrated HITL pipeline with simulated human interaction.""" + print("=" * 60) + print("Composable HITL Orchestrated Pipeline") + print("=" * 60) + + session_service = InMemorySessionService() + interrupt_service = InterruptService() + session_id = "hitl-orchestrated-demo" + + graph = build_orchestrator(interrupt_service=interrupt_service) + + await session_service.create_session( + app_name="document_pipeline", + user_id="reviewer", + session_id=session_id, + ) + + runner = Runner( + app_name="document_pipeline", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + interrupt_service.register_session(session_id) + + # Simulate: reject extract stage once, then approve + await interrupt_service.send_message( + session_id, + "Missing key entity 'revenue' β€” please include financial data", + action="revise", + ) + await interrupt_service.send_message( + session_id, + "Extraction looks complete now", + action="approve", + ) + + document = ( + "Q4 2025 Financial Report: Revenue grew 23% YoY. " + "Key entities include quarterly data, revenue metrics, " + "and department-level breakdowns. Extract all financial entities." 
+ ) + + print(f"\nDocument: {document[:80]}...") + print("-" * 40) + + async for event in runner.run_async( + user_id="reviewer", + session_id=session_id, + new_message=types.Content(parts=[types.Part(text=document)]), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text or text.startswith("[GraphMetadata]"): + continue + author = event.author or "system" + print(f"[{author}] {text[:300]}") + + # Show final state + session = await session_service.get_session( + app_name="document_pipeline", + user_id="reviewer", + session_id=session_id, + ) + if session: + graph_data = session.state.get("graph_data", {}) + print( + f"\nPipeline result: {graph_data.get('pipeline_result', 'N/A')[:200]}" + ) + print(f"Stages: {graph_data.get('stages', [])}") + debug_output = graph_data.get("_debug_process_output", "") + if debug_output: + print(f"Debug (nested output): {debug_output[:100]}") + + interrupt_service.unregister_session(session_id) + print("\nDone.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_hitl_review/README.md b/contributing/samples/graph_agent_hitl_review/README.md new file mode 100644 index 0000000000..a2c9b94f3d --- /dev/null +++ b/contributing/samples/graph_agent_hitl_review/README.md @@ -0,0 +1,49 @@ +# HITL Content Review Workflow + +Demonstrates a complete human-in-the-loop review workflow using GraphAgent. + +## Graph Structure + +``` +[draft] --> [review_gate] --> approved? 
--> [publish] + | + v rejected + [revise] --> [review_gate] (loop) +``` + +## Key Concepts + +- **InterruptService** pauses execution at the `review_gate` node for human input +- **Conditional routing** routes to `publish` (approved) or `revise` (rejected) based on `state.data["approved"]` +- **Fallback mode**: runs without LLM when no API key is configured (deterministic string templates) +- **Review loop**: `revise -> review_gate` loop with `max_iterations=10` safety limit + +## Running + +```bash +# Without LLM (deterministic fallback): +python -m contributing.samples.graph_agent_hitl_review.agent + +# With LLM: +export GOOGLE_API_KEY="your-key" +python -m contributing.samples.graph_agent_hitl_review.agent + +# With Vertex AI: +export GOOGLE_GENAI_USE_VERTEXAI=1 +python -m contributing.samples.graph_agent_hitl_review.agent +``` + +## How It Works + +1. `draft` node generates initial content (LLM or template) +2. `review_gate` node pauses via InterruptService, waits for human message +3. Human sends approval (`action="approve"`) or revision request (`action="revise"`) +4. If rejected: `revise` node incorporates feedback, loops back to `review_gate` +5. If approved: `publish` node finalizes content + +In the demo, human messages are pre-queued to simulate the review interaction. 
+ +## Differences from `graph_agent_hitl` + +- `graph_agent_hitl`: Demonstrates interrupt *mechanics* (risk assessment, InterruptReasoner) +- `graph_agent_hitl_review`: Demonstrates a *workflow pattern* where human approval is a required graph step with conditional routing diff --git a/contributing/samples/graph_agent_hitl_review/__init__.py b/contributing/samples/graph_agent_hitl_review/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_hitl_review/agent.py b/contributing/samples/graph_agent_hitl_review/agent.py new file mode 100644 index 0000000000..6056faaf1b --- /dev/null +++ b/contributing/samples/graph_agent_hitl_review/agent.py @@ -0,0 +1,345 @@ +"""GraphAgent HITL Review Workflow β€” Content generation with human approval loop. + +Demonstrates a complete review workflow pattern where human approval is a +first-class part of the graph execution flow (not an ad-hoc interruption): + +1. LLM agent drafts content from a topic +2. Human reviews and approves or requests revision +3. If revision needed: loop back with feedback +4. If approved: proceed to publish + +Graph structure: + [draft] --> [review_gate] --> approved? 
--> [publish] + | + v rejected + [revise] --> [review_gate] (loop) + +Key concepts: +- InterruptService for structured pause/resume at deterministic points +- Conditional routing (approved -> publish, rejected -> revise) +- Execution path tracking shows review iteration history +- Fallback mode: runs without LLM (string templates) when no API key is set + +Run: + # With LLM (set GOOGLE_API_KEY or GOOGLE_GENAI_USE_VERTEXAI=1): + python -m contributing.samples.graph_agent_hitl_review.agent + + # Without LLM (deterministic fallback, no API key needed): + python -m contributing.samples.graph_agent_hitl_review.agent +""" + +import asyncio +import os + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphNode +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph import InterruptService +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + + +# --------------------------------------------------------------------------- +# LLM availability check +# --------------------------------------------------------------------------- + + +def _has_llm() -> bool: + """Check if LLM backend is configured via environment variables.""" + return bool( + os.environ.get("GOOGLE_API_KEY") + or os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") + ) + + +# --------------------------------------------------------------------------- +# Fallback agents (no LLM required) +# --------------------------------------------------------------------------- + + +class DraftAgent(BaseAgent): + """Deterministic draft agent β€” produces content from a 
template.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + async def _run_async_impl(self, ctx): + topic = ctx.session.state.get("graph_data", {}).get( + "input", "unknown topic" + ) + content = ( + f"Draft: This is an article about '{topic}'.\n" + "It covers the key concepts, benefits, and practical applications " + f"of {topic} in modern software engineering." + ) + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=content)]), + ) + + +class ReviseAgent(BaseAgent): + """Deterministic revise agent β€” appends feedback-based revision.""" + + model_config = {"extra": "allow", "arbitrary_types_allowed": True} + + async def _run_async_impl(self, ctx): + graph_data = ctx.session.state.get("graph_data", {}) + previous = graph_data.get("content", "") + feedback = graph_data.get("review_feedback", "no feedback") + revision = ( + f"{previous}\n\n" + f"[Revised based on feedback: {feedback}]\n" + "Additional details and corrections have been incorporated." + ) + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=revision)]), + ) + + +# --------------------------------------------------------------------------- +# Node functions +# --------------------------------------------------------------------------- + + +async def review_gate_fn(state: GraphState, ctx) -> str: + """Pause for human review and record approval decision. + + This is the core HITL mechanism: the function pauses execution via + the interrupt service, waits for a human message, and records the + decision in state.data for conditional routing. 
+ """ + interrupt_service = getattr(ctx, "_interrupt_service", None) + session_id = ctx.session.id if ctx.session else None + + if interrupt_service and session_id: + # Check for pre-queued interrupt message + message = await interrupt_service.check_interrupt(session_id) + if message: + approved = message.action == "approve" + state.data["approved"] = approved + state.data["review_feedback"] = message.text + decision = ( + "APPROVED" if approved else f"REVISION REQUESTED: {message.text}" + ) + return f"[Review] {decision}" + + # No interrupt service or no message β€” auto-approve (for testing) + state.data["approved"] = True + state.data["review_feedback"] = "" + return "[Review] Auto-approved (no interrupt service)" + + +async def publish_fn(state: GraphState, ctx) -> str: + """Finalize and publish the approved content.""" + content = state.data.get("content", "(no content)") + state.data["published"] = True + return f"[Published] {content[:200]}" + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + + +def build_review_graph( + interrupt_service: InterruptService | None = None, + use_llm: bool = False, +) -> GraphAgent: + """Build the content review workflow graph. + + Args: + interrupt_service: Optional interrupt service for HITL. If None, + review_gate auto-approves. + use_llm: If True and LLM is available, use LlmAgent for draft/revise. + Otherwise use deterministic fallback agents. + + Returns: + Configured GraphAgent with draft -> review -> publish flow. + """ + # Select agents based on LLM availability + if use_llm and _has_llm(): + draft_agent = LlmAgent( + name="drafter", + model=_MODEL, + instruction=( + "You are a content writer. Write a short article (2-3 paragraphs) " + "about the topic provided in the user input. Be informative and " + "engaging." 
+ ), + output_key="content", + ) + revise_agent = LlmAgent( + name="reviser", + model=_MODEL, + instruction=( + "You are a content editor. Revise the following draft based on " + "the reviewer's feedback.\n\n" + "Current draft: {content}\n" + "Feedback: {review_feedback}\n\n" + "Produce an improved version." + ), + output_key="content", + ) + else: + draft_agent = DraftAgent(name="drafter") + revise_agent = ReviseAgent(name="reviser") + + interrupt_config = None + if interrupt_service: + interrupt_config = InterruptConfig( + mode=InterruptMode.BEFORE, + nodes=["review_gate"], + ) + + graph = GraphAgent( + name="content_review", + description="Content generation with human review loop", + max_iterations=10, + interrupt_service=interrupt_service, + interrupt_config=interrupt_config, + ) + + # Nodes + graph.add_node( + "draft", + agent=draft_agent, + output_mapper=lambda output, s: (s.data.update({"content": output}), s)[ + 1 + ], + ) + graph.add_node( + GraphNode( + name="review_gate", + function=review_gate_fn, + ) + ) + graph.add_node( + "revise", + agent=revise_agent, + input_mapper=lambda s: ( + "Revise this draft based on feedback.\n" + f"Draft: {s.data.get('content', '')}\n" + f"Feedback: {s.data.get('review_feedback', '')}" + ), + output_mapper=lambda output, s: (s.data.update({"content": output}), s)[ + 1 + ], + ) + graph.add_node( + GraphNode( + name="publish", + function=publish_fn, + ) + ) + + # Edges + graph.set_start("draft") + graph.add_edge("draft", "review_gate") + graph.add_edge( + "review_gate", + "publish", + condition=lambda s: s.data.get("approved") is True, + ) + graph.add_edge( + "review_gate", + "revise", + condition=lambda s: s.data.get("approved") is False, + ) + graph.add_edge("revise", "review_gate") + graph.set_end("publish") + + return graph + + +# --------------------------------------------------------------------------- +# Main β€” simulates human reviewer +# --------------------------------------------------------------------------- 
+ + +async def main() -> None: + """Run content review workflow with simulated human interaction.""" + print("=" * 60) + print("HITL Content Review Workflow") + print("=" * 60) + + session_service = InMemorySessionService() + interrupt_service = InterruptService() + session_id = "hitl-review-demo" + + graph = build_review_graph( + interrupt_service=interrupt_service, + use_llm=_has_llm(), + ) + + await session_service.create_session( + app_name="content_review", + user_id="reviewer", + session_id=session_id, + ) + + runner = Runner( + app_name="content_review", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + interrupt_service.register_session(session_id) + + # Simulate human: first reject with feedback, then approve on revision + # Pre-queue messages (in production these come from a UI/API) + await interrupt_service.send_message( + session_id, + "Needs more concrete examples and a conclusion paragraph", + action="revise", + ) + await interrupt_service.send_message( + session_id, + "Looks good now", + action="approve", + ) + + print("\nTopic: 'Graph-based AI Agent Workflows'") + print("-" * 40) + + async for event in runner.run_async( + user_id="reviewer", + session_id=session_id, + new_message=types.Content( + parts=[types.Part(text="Graph-based AI Agent Workflows")] + ), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text or text.startswith("[GraphMetadata]"): + continue + author = event.author or "system" + print(f"[{author}] {text[:300]}") + + # Show final state + session = await session_service.get_session( + app_name="content_review", + user_id="reviewer", + session_id=session_id, + ) + if session: + graph_data = session.state.get("graph_data", {}) + print(f"\nPublished: {graph_data.get('published', False)}") + print(f"Review iterations: {graph_data.get('review_feedback', 'N/A')}") + + interrupt_service.unregister_session(session_id) + 
print("\nDone.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_multi_agent/README.md b/contributing/samples/graph_agent_multi_agent/README.md new file mode 100644 index 0000000000..107121cf2c --- /dev/null +++ b/contributing/samples/graph_agent_multi_agent/README.md @@ -0,0 +1,98 @@ +# GraphAgent Multi-Agent Research Workflow + +Demonstrates a **multi-agent coordination pattern** using GraphAgent: +parallel research branches, sequential coordination, and a quality-review loop. + +## Graph Structure + +``` +coordinator + ↓ +[researcher_a β•‘ researcher_b] ← ParallelNodeGroup (WAIT_ALL) + ↓ +merger + ↓ +critic ──REVISE──→ merger + β”‚ + APPROVED + ↓ + END +``` + +## What Each Agent Does + +| Agent | Role | Output key | +|-------|------|-----------| +| coordinator | Splits topic into two subtopics | `subtopics` | +| researcher_a | Investigates subtopic A concurrently | `research_a` | +| researcher_b | Investigates subtopic B concurrently | `research_b` | +| merger | Synthesises findings into one report | `merged_report` | +| critic | Peer-reviews; routes to merger (REVISE) or ends (APPROVED) | `review` | + +## When to Use + +- Tasks that decompose into independent parallel workstreams +- Workflows needing a quality-review loop after synthesis +- Any pattern mixing sequential coordination, parallelism, and conditional loops + +## Comparison with Other Workflow Agents + +| Capability | SequentialAgent | ParallelAgent | **GraphAgent** | +|------------|----------------|---------------|----------------| +| Run researcher_a and researcher_b concurrently | βœ— | βœ… | βœ… | +| Pre-coordination step before parallel work | βœ— | βœ— | βœ… | +| Post-merge step after parallel work | βœ— | βœ— | βœ… | +| Conditional quality loop (critic β†’ merger) | βœ— | βœ— | βœ… | +| Inspect state to decide routing | βœ— | βœ— | βœ… | + +**ParallelAgent** can fan out but cannot add a coordinator before or a +critic-loop after β€” it has no 
concept of entry/exit coordination nodes. +**SequentialAgent** executes in a fixed order and cannot parallelise the +two researchers. + +## Key Code + +```python +# Register parallel group: researcher_a and researcher_b run concurrently +graph.add_parallel_group( + "researchers", + ParallelNodeGroup( + nodes=["researcher_a", "researcher_b"], + join_strategy=JoinStrategy.WAIT_ALL, + ), +) + +# Edges fan-out from coordinator to both researchers +graph.add_edge("coordinator", "researcher_a") +graph.add_edge("coordinator", "researcher_b") + +# Both researchers converge at merger +graph.add_edge("researcher_a", "merger") +graph.add_edge("researcher_b", "merger") + +# Conditional quality loop +graph.add_edge("critic", "merger", condition=lambda s: s.data.get("review","").startswith("REVISE")) +graph.set_end("critic") +``` + +## State Isolation + +During parallel execution each branch (`researcher_a`, `researcher_b`) receives +an **isolated copy** of the shared state. Both write to independent output keys +(`research_a`, `research_b`), so there are no race conditions. After both +complete, states are merged automatically before `merger` runs. 
+ +## How to Run + +```bash +cd /path/to/adk-python +source venv/bin/activate +export GOOGLE_API_KEY= +python -m contributing.samples.graph_agent_multi_agent.agent +``` + +## Related Examples + +- `contributing/samples/graph_examples/09_parallel_wait_all` β€” parallel basics +- `contributing/samples/graph_examples/14_parallel_rewind` β€” parallel + rewind +- `contributing/samples/graph_agent_advanced` β€” full research workflow with interrupts diff --git a/contributing/samples/graph_agent_multi_agent/__init__.py b/contributing/samples/graph_agent_multi_agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_multi_agent/agent.py b/contributing/samples/graph_agent_multi_agent/agent.py new file mode 100644 index 0000000000..3c2f4a86a5 --- /dev/null +++ b/contributing/samples/graph_agent_multi_agent/agent.py @@ -0,0 +1,262 @@ +"""GraphAgent multi-agent research workflow example. + +Demonstrates a coordinator β†’ parallel researchers β†’ merger β†’ critic loop: + + coordinator + ↓ + [researcher_a || researcher_b] (ParallelNodeGroup, WAIT_ALL) + ↓ + merger + ↓ + critic ──REVISE──→ merger + β”‚ + APPROVED + ↓ + END + +Why GraphAgent (not ParallelAgent/SequentialAgent)? +- SequentialAgent: cannot run researcher_a and researcher_b concurrently. +- ParallelAgent: parallelises but cannot add coordinator before or critic+loop after. +- GraphAgent: combines sequential coordination, true parallel research, AND a + conditional quality-review loop in one declarative graph. 
+ +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_multi_agent.agent +""" + +import asyncio +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import JoinStrategy +from google.adk.agents.graph import ParallelNodeGroup +from google.adk.agents.graph import StateReducer +from google.adk.agents.llm_agent import LlmAgent +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class ReviewResult(BaseModel): + """Structured review output from critic agent.""" + + decision: str # "approve" or "revise" + feedback: str # Review comments + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + +coordinator = LlmAgent( + name="coordinator", + model=_MODEL, + instruction=( + "You are a research coordinator. Given a research topic, split it into" + " exactly two independent subtopics for parallel investigation. Output" + " each subtopic on its own line prefixed with 'SUBTOPIC A:' and" + " 'SUBTOPIC B:'." + ), + output_key="subtopics", +) + +researcher_a = LlmAgent( + name="researcher_a", + model=_MODEL, + instruction=( + "You are a researcher specialising in the first subtopic. " + "Write a concise research summary (3-5 sentences) with key findings." + ), + output_key="research_a", +) + +researcher_b = LlmAgent( + name="researcher_b", + model=_MODEL, + instruction=( + "You are a researcher specialising in the second subtopic. 
" + "Write a concise research summary (3-5 sentences) with key findings." + ), + output_key="research_b", +) + +merger = LlmAgent( + name="merger", + model=_MODEL, + instruction=( + "You are a synthesis expert. Merge the two research summaries into a " + "single coherent report. Highlight complementary insights." + ), + output_key="merged_report", +) + +critic = LlmAgent( + name="critic", + model=_MODEL, + instruction=( + "You are a peer reviewer. Evaluate the merged report for clarity, " + "completeness, and accuracy. " + 'Return {"decision": "approve", "feedback": "..."} if good, ' + 'or {"decision": "revise", "feedback": "explanation..."} if needs work.' + ), + output_schema=ReviewResult, # Structured output + # output_key auto-defaults to "critic" (agent name) +) + + +# --------------------------------------------------------------------------- +# Routing predicates +# --------------------------------------------------------------------------- + + +def _needs_revision(state: GraphState) -> bool: + """Check if critic requested revision using structured output.""" + review = state.get_parsed("critic", ReviewResult) + return review.decision.lower() == "revise" if review else False + + +# --------------------------------------------------------------------------- +# Graph +# --------------------------------------------------------------------------- + + +def build_multi_agent_graph() -> GraphAgent: + graph = GraphAgent( + name="research_graph", + description=( + "Multi-agent research with parallel execution and quality loop" + ), + max_iterations=20, + ) + + graph.add_node("coordinator", agent=coordinator) + + graph.add_node( + "researcher_a", + agent=researcher_a, + # Both researchers see the same coordinator output + input_mapper=lambda s: s.data.get("subtopics", ""), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "researcher_b", + agent=researcher_b, + input_mapper=lambda s: s.data.get("subtopics", ""), + reducer=StateReducer.OVERWRITE, + ) + + 
graph.add_node( + "merger", + agent=merger, + input_mapper=lambda s: ( + f"Research A:\n{s.data.get('research_a', '')}\n\n" + f"Research B:\n{s.data.get('research_b', '')}" + ), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node("critic", agent=critic) + + # Register parallel group so branches execute concurrently + graph.add_parallel_group( + "researchers", + ParallelNodeGroup( + nodes=["researcher_a", "researcher_b"], + join_strategy=JoinStrategy.WAIT_ALL, + ), + ) + + graph.set_start("coordinator") + graph.add_edge("coordinator", "researcher_a") + graph.add_edge("coordinator", "researcher_b") + graph.add_edge("researcher_a", "merger") + graph.add_edge("researcher_b", "merger") + graph.add_edge("merger", "critic") + + # Quality loop: revise if not approved + graph.add_edge("critic", "merger", condition=_needs_revision) + + graph.set_end("critic") + + return graph + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + session_service = InMemorySessionService() + graph = build_multi_agent_graph() + + session = await session_service.create_session( + app_name="research_graph", user_id="user1" + ) + + topic = "The impact of large language models on software engineering" + print(f"Research topic: {topic}\n") + + # Use Runner instead of manual invocation context + runner = Runner( + app_name="research_graph", + agent=graph, + session_service=session_service, + auto_create_session=False, # Session already created above + ) + + revision_count = 0 + async for event in runner.run_async( + user_id="user1", + session_id=session.id, + new_message=types.Content(parts=[types.Part(text=topic)]), + ): + if not event.content or not event.content.parts: + continue + author = event.author + text = event.content.parts[0].text or "" + if author == "coordinator": + print("Coordinator assigned subtopics.") + elif author in 
("researcher_a", "researcher_b"): + print(f" [{author}] research complete ({len(text)} chars)") + elif author == "merger": + revision_count += 1 + print(f"Merger produced report (revision {revision_count}).") + elif author == "critic": + # Parse critic output from the event text (JSON string) + try: + review = ReviewResult.model_validate_json(text.strip()) + decision = review.decision.upper() + except Exception: + decision = "UNKNOWN (parse error)" + print(f"Critic verdict: {decision}") + + # Re-fetch fresh session state (create_session returns a deepcopy) + fresh_session = await session_service.get_session( + app_name="research_graph", user_id="user1", session_id=session.id + ) + if fresh_session is None: + print( + "WARNING: session_service.get_session returned None, using stale copy" + ) + fresh_session = session + final_data = fresh_session.state.get("graph_data", {}) + final_state = GraphState(data=final_data) + + print("\nFinal merged report:") + print(final_state.get_str("merged_report", "(none)")[:500]) + print("\nFinal review:") + review = final_state.get_parsed("critic", ReviewResult) + print(f"Decision: {review.decision if review else 'none'}") + print(f"Feedback: {review.feedback[:200] if review else 'none'}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_parallel_features/README.md b/contributing/samples/graph_agent_parallel_features/README.md new file mode 100644 index 0000000000..6dda740fec --- /dev/null +++ b/contributing/samples/graph_agent_parallel_features/README.md @@ -0,0 +1,519 @@ +## GraphAgent Parallel Execution & Rewind Features + +Comprehensive demonstration of advanced GraphAgent features including parallel node execution, rewind integration, and edge cases. + +--- + +## Features Demonstrated + +### 1. **Parallel Node Execution** βœ… + +Execute independent nodes concurrently for better throughput. 
+ +**Join Strategies:** +- `WAIT_ALL`: Wait for all nodes to complete (default) +- `WAIT_ANY`: Proceed when first node completes (race condition) +- `WAIT_N`: Wait for N nodes to complete (e.g., 2 out of 3) + +**Error Policies:** +- `FAIL_FAST`: Cancel all on first error +- `CONTINUE`: Continue others on error +- `COLLECT`: Collect all errors + +**Code Example:** +```python +# Create parallel group +graph.add_parallel_group( + "fetch_group", + ParallelNodeGroup( + nodes=["fetch_users", "fetch_products", "fetch_orders"], + join_strategy=JoinStrategy.WAIT_ALL, + error_policy=ErrorPolicy.FAIL_FAST, + ), +) + +# All three nodes execute concurrently +graph.add_edge("validate", "fetch_users") +graph.add_edge("validate", "fetch_products") +graph.add_edge("validate", "fetch_orders") +``` + +--- + +### 2. **Rewind Integration with Parallel Workflows** βœ… + +Rewind to any node, even those that trigger parallel execution. + +**How It Works:** +- GraphAgent tracks invocation IDs per node +- `rewind_to_node()` restores session state to before node execution +- Re-execution from rewind point re-runs parallel groups + +**Code Example:** +```python +# Execute workflow +async for event in runner.run_async(...): + pass + +# Rewind to a node that triggers parallel execution +await graph.rewind_to_node( + session_service, + app_name="my_app", + user_id="user1", + session_id="session1", + node_name="fetch_users", # Part of parallel group + invocation_index=-1, # Last invocation +) + +# Re-execute - parallel group runs again +async for event in runner.run_async(...): + pass +``` + +--- + +### 3. **Checkpointing with Parallel Execution** βœ… + +Checkpoints capture state after parallel branches complete. 
+ +**Architecture:** +- Parallel branches have isolated state during execution +- After all branches complete, state is merged +- Checkpoint created after merge includes all results + +**Code Example:** +```python +# Enable checkpointing +graph = GraphAgent( + name="workflow", + checkpointing=True, +) + +# Checkpoints created automatically after each node +# Including after parallel groups complete +``` + +--- + +### 4. **Interrupts During Parallel Execution** βœ… + +Interrupts can cancel all parallel branches immediately. + +**Behavior:** +- InterruptService can mark session as cancelled +- GraphAgent checks for cancellation between events +- All parallel branches stop immediately +- Partial state preserved for potential resume + +**Code Example:** +```python +# During parallel execution, send interrupt +await interrupt_service.send_interrupt( + session_id=session.id, + text="User requested abort", + action="continue", # Or "pause", "rerun", etc. +) + +# GraphAgent detects cancellation, stops all branches +# State saved: {graph_cancelled: true, ...} +``` + +--- + +## Architectural Considerations + +### State Isolation in Parallel Branches + +**Problem:** How to prevent race conditions when multiple nodes modify state? + +**Solution:** Each parallel branch gets an **isolated copy** of the state. + +```python +# Parallel execution pseudocode +for node_name in parallel_group.nodes: + # Create isolated state copy + branch_state = GraphState( + data=state.data.copy() + ) + + # Execute node with isolated state + execute_node(node, branch_state, ctx) + +# After all complete, merge results +merged_state = merge(branch_states) +``` + +**Benefits:** +- No race conditions +- Deterministic behavior +- Branches can't interfere with each other + +--- + +### Rewind with Merged States + +**Question:** Can rewind work if parallel branches have been merged? + +**Answer:** YES! Here's how: + +1. 
**During Execution:** + - Parallel branches have isolated state + - Each branch emits events independently + - Results merged after all complete + +2. **Invocation Tracking:** + - GraphAgent tracks invocation IDs per node + - Parallel nodes each get their own invocation ID + - These IDs persist in session state + +3. **Rewind Process:** + - `rewind_to_node()` identifies the invocation ID + - Uses `Runner.rewind_async()` to restore session state + - State reverted to BEFORE node execution + - Re-execution re-runs parallel group from scratch + +**Example:** +```python +# Execution creates invocations: +# { +# "validate": ["inv_1"], +# "fetch_users": ["inv_2"], +# "fetch_products": ["inv_3"], +# "aggregate": ["inv_4"] +# } + +# Rewind to fetch_users (inv_2) +await graph.rewind_to_node(..., node_name="fetch_users", invocation_index=-1) + +# State restored to BEFORE inv_2 +# Re-execution will: +# 1. Run fetch_users (new invocation) +# 2. Run fetch_products in parallel (new invocation) +# 3. Run aggregate (new invocation) +``` + +--- + +### Session State Communication + +**Question:** What if session state is not communicated between parallel branches? + +**Answer:** This is BY DESIGN for safety! + +**Rationale:** +- Parallel branches should be **independent** +- Shared mutable state leads to race conditions +- Isolation ensures deterministic results + +**Communication Patterns:** + +1. **Before Parallel Execution:** + ```python + # All branches start with same initial state + state = GraphState(data={"shared_input": "value"}) + ``` + +2. **After Parallel Execution:** + ```python + # Merge results using StateReducer + merged_state = reducer(branch_results) + ``` + +3. **If Communication Needed:** + - Use separate coordination node + - Don't put nodes in parallel group + - Use sequential edges with conditions + +--- + +### Interrupts During Parallel Execution + +**Question:** What happens if we interrupt during parallel execution? 
+ +**Answer:** Clean cancellation with state preservation. + +**Flow:** +1. User sends interrupt (or timeout triggers) +2. `InterruptService` marks session as inactive +3. GraphAgent checks `is_active()` between events +4. All parallel branches detect cancellation +5. Tasks cancelled via `task.cancel()` +6. Partial state saved to session: + ```python + { + "graph_cancelled": True, + "graph_cancelled_at_node": "fetch_users", + "graph_iteration": 2, + "graph_data": {...}, # Partial domain data + "graph_can_resume": True + } + ``` + +**Resume After Interrupt:** +```python +# State preserved, can resume from checkpoint +# Or restart from beginning +# Or rewind to specific point +``` + +--- + +## Scenarios + +### Scenario 1: Parallel Execution (WAIT_ALL) +Fetch data from 3 sources concurrently, wait for all. + +**Workflow:** +``` +validate β†’ (fetch_users || fetch_products || fetch_orders) β†’ aggregate +``` + +### Scenario 2: Parallel Execution (WAIT_ANY) +Race 3 data sources, proceed with first to complete. + +**Workflow:** +``` +validate β†’ (fetch_cache || fetch_db || fetch_api) β†’ transform +``` + +### Scenario 3: Parallel Execution (WAIT_N) +Run 3 ML models, proceed when 2 out of 3 complete. + +**Workflow:** +``` +validate β†’ (model1 || model2 || model3) β†’ aggregate +``` + +### Scenario 4: Rewind with Parallel Execution +Execute workflow, rewind to parallel group, re-execute. + +**Demonstrates:** +- Invocation tracking across parallel nodes +- Rewind restores state before parallel execution +- Re-execution runs parallel group again + +### Scenario 5: Checkpointing with Parallel Execution +Enable checkpointing, execute workflow with parallel nodes. + +**Demonstrates:** +- Checkpoints created after each node +- Parallel group checkpoint captures merged state +- Resume from checkpoint works correctly + +### Scenario 6: Interrupts During Parallel Execution +Show interrupt behavior and considerations. 
+ +**Demonstrates:** +- How interrupts cancel all parallel branches +- State preservation on cancellation +- Architecture for resume capability + +### Scenario 7: State Isolation in Parallel Branches +Parallel branches modify same state key, show isolation. + +**Demonstrates:** +- Each branch has isolated state +- No race conditions +- Deterministic results + +--- + +## Running the Examples + +```bash +# Run all scenarios +python -m contributing.samples.graph_agent_parallel_features.agent + +# Or run from the adk-python directory +cd /path/to/adk-python +source venv/bin/activate +python -m contributing.samples.graph_agent_parallel_features.agent +``` + +--- + +## Expected Output + +``` +╔══════════════════════════════════════════════════════════════════════════════╗ +β•‘ β•‘ +β•‘ GraphAgent Parallel Execution & Rewind Features - Comprehensive Demo β•‘ +β•‘ β•‘ +β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β• + +================================================================================ +SCENARIO 1: Parallel Execution with WAIT_ALL +================================================================================ + +πŸ“Š Executing workflow with parallel fetch operations... + Strategy: WAIT_ALL (wait for all 3 fetches to complete) + + βœ… Data validation passed + βœ… Fetched 3 records from users_db + βœ… Fetched 3 records from products_db + βœ… Fetched 3 records from orders_db + βœ… Aggregated results from all sources + +βœ… Scenario 1 complete: 5 events emitted + Note: All 3 fetch operations ran concurrently! + +... (more scenarios) ... + +================================================================================ +βœ… ALL SCENARIOS COMPLETE +================================================================================ + +Key Takeaways: +1. 
Parallel execution works with WAIT_ALL, WAIT_ANY, WAIT_N strategies +2. Rewind integration works - can rewind to nodes that trigger parallel groups +3. Checkpointing captures state after parallel branches complete +4. Interrupts can cancel parallel execution (state preserved) +5. Parallel branches have isolated state (no race conditions) + +Architectural Answers: +- Q: Can rewind work with parallel execution? + A: YES! Rewind restores to before node execution, re-runs parallel group +- Q: What about session state communication? + A: Branches are isolated during execution, merged after completion +- Q: What if we interrupt during parallel execution? + A: All branches cancelled, partial state saved for resume +``` + +--- + +## Architecture Diagrams + +### Parallel Execution Flow + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ validate β”‚ +β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ β”‚ β”‚ + β–Ό β–Ό β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ fetch_users β”‚ β”‚fetch_productsβ”‚ β”‚ fetch_orders β”‚ +β”‚ (isolated) β”‚ β”‚ (isolated) β”‚ β”‚ (isolated) β”‚ +β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ aggregate β”‚ + β”‚ (merged state)β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### Rewind Flow + +``` +1. Initial Execution: + validate β†’ parallel_group β†’ aggregate + +2. Rewind to parallel_group: + [Session State] ← Restore ← Before parallel_group + +3. 
Re-execution: + parallel_group β†’ aggregate + (New invocations created) +``` + +### State Isolation + +``` +Main State: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ data: {input: "value"} β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”‚ copy() + β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β–Ό β–Ό β–Ό β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Branch 1 β”‚ β”‚ Branch 2 β”‚ β”‚ Branch 3 β”‚ β”‚ ... β”‚ + β”‚ (isolated)β”‚ β”‚ (isolated)β”‚ β”‚ (isolated)β”‚ β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜ + β”‚ β”‚ β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + Merge Results + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Merged State β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +--- + +## Performance Considerations + +### Parallel vs Sequential + +**Sequential Execution:** +``` +Total time = sum(all node execution times) +Example: 100ms + 150ms + 200ms = 450ms +``` + +**Parallel Execution (WAIT_ALL):** +``` +Total time = max(all node execution times) +Example: max(100ms, 150ms, 200ms) = 200ms +Speedup: 2.25x faster +``` + +**Parallel Execution (WAIT_ANY):** +``` +Total time = min(all node execution times) +Example: min(100ms, 150ms, 200ms) = 100ms +Speedup: 4.5x faster (but only uses first result) +``` + +--- + +## Best Practices + +1. 
**Use Parallel Execution When:** + - Nodes are independent (no data dependencies) + - Operations are I/O bound (API calls, DB queries) + - Order doesn't matter + +2. **Avoid Parallel Execution When:** + - Nodes have sequential dependencies + - Order matters for correctness + - Shared mutable resources (use locks) + +3. **Join Strategy Selection:** + - `WAIT_ALL`: When you need all results + - `WAIT_ANY`: When any result is acceptable (cache/DB/API fallback) + - `WAIT_N`: When you need quorum (ML ensemble, consensus) + +4. **Error Handling:** + - `FAIL_FAST`: When any failure invalidates the entire operation + - `CONTINUE`: When partial results are acceptable + - `COLLECT`: When you need to analyze all failures + +5. **State Management:** + - Don't rely on shared mutable state in parallel branches + - Use isolated state copies (automatic) + - Merge results after completion + +--- + +## Related Examples + +- `graph_agent_basic` - Basic GraphAgent workflow +- `graph_agent_advanced` - Interrupts, checkpointing, callbacks +- `graph_agent_builder` - Graph construction patterns + +--- + +## References + +- GraphAgent: `src/google/adk/agents/graph/graph_agent.py` +- Parallel Execution: `src/google/adk/agents/graph/parallel.py` +- Rewind Integration: `src/google/adk/agents/graph/graph_agent.py:rewind_to_node()` +- Tests: `tests/unittests/agents/test_graph_parallel.py` diff --git a/contributing/samples/graph_agent_parallel_features/__init__.py b/contributing/samples/graph_agent_parallel_features/__init__.py new file mode 100644 index 0000000000..7b8c6ccbe0 --- /dev/null +++ b/contributing/samples/graph_agent_parallel_features/__init__.py @@ -0,0 +1,9 @@ +"""GraphAgent Parallel Execution Features Example. 
+ +Demonstrates: +- Parallel node execution (WAIT_ALL, WAIT_ANY, WAIT_N) +- Rewind integration with parallel workflows +- Interrupts during parallel execution +- Checkpointing with parallel branches +- Edge cases and architectural considerations +""" diff --git a/contributing/samples/graph_agent_parallel_features/agent.py b/contributing/samples/graph_agent_parallel_features/agent.py new file mode 100644 index 0000000000..f243dfc4e1 --- /dev/null +++ b/contributing/samples/graph_agent_parallel_features/agent.py @@ -0,0 +1,820 @@ +"""GraphAgent Parallel Execution & Rewind Features - Comprehensive Examples. + +This example demonstrates advanced GraphAgent features including: +1. Parallel node execution with different join strategies +2. Rewind integration with parallel workflows +3. Interrupts during parallel execution +4. Checkpointing with parallel branches +5. Edge cases and architectural considerations + +Run with: python -m contributing.samples.graph_agent_parallel_features.agent +""" + +import asyncio +import json +from typing import Optional + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import END +from google.adk.agents.graph import ErrorPolicy +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph import JoinStrategy +from google.adk.agents.graph import ParallelNodeGroup +from google.adk.agents.graph import rewind_to_node +from google.adk.agents.graph import START +from google.adk.agents.graph import StateReducer +from google.adk.checkpoints import CheckpointService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from google import genai + +# 
============================================================================ +# Test Agents for Demonstrations +# ============================================================================ + + +class DataFetchAgent(BaseAgent): + """Simulates fetching data from an API.""" + + def __init__( + self, name: str, data_source: str, delay_ms: int = 100, **kwargs + ): + super().__init__(name=name, **kwargs) + self._data_source = data_source + self._delay_ms = delay_ms + + async def _run_async_impl(self, ctx): + # Simulate API call delay + await asyncio.sleep(self._delay_ms / 1000.0) + + data = { + "source": self._data_source, + "records": [f"{self._data_source}_record_{i}" for i in range(3)], + "timestamp": "2026-02-08T12:00:00Z", + } + + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + f"βœ… Fetched {len(data['records'])} records from" + f" {self._data_source}" + ) + ) + ] + ), + ) + + +class ValidationAgent(BaseAgent): + """Validates fetched data.""" + + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) + + async def _run_async_impl(self, ctx): + # Simulate validation logic + await asyncio.sleep(0.05) + + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="βœ… Data validation passed")] + ), + ) + + +class TransformAgent(BaseAgent): + """Transforms data.""" + + def __init__(self, name: str, transformation: str, **kwargs): + super().__init__(name=name, **kwargs) + self._transformation = transformation + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.05) + + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=f"βœ… Applied transformation: {self._transformation}" + ) + ] + ), + ) + + +class AggregationAgent(BaseAgent): + """Aggregates results from multiple sources.""" + + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) + + async def _run_async_impl(self, ctx): + await 
asyncio.sleep(0.05) + + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="βœ… Aggregated results from all sources")] + ), + ) + + +# ============================================================================ +# Scenario 1: Basic Parallel Execution (WAIT_ALL) +# ============================================================================ + + +async def scenario_1_parallel_wait_all(): + """Demonstrate parallel execution with WAIT_ALL strategy. + + Workflow: + validate -> (fetch_users || fetch_products || fetch_orders) -> aggregate + + All three fetch operations run concurrently and we wait for all to complete. + """ + print("\n" + "=" * 80) + print("SCENARIO 1: Parallel Execution with WAIT_ALL") + print("=" * 80) + + # Create agents + validate = ValidationAgent(name="validate") + fetch_users = DataFetchAgent( + name="fetch_users", data_source="users_db", delay_ms=150 + ) + fetch_products = DataFetchAgent( + name="fetch_products", data_source="products_db", delay_ms=100 + ) + fetch_orders = DataFetchAgent( + name="fetch_orders", data_source="orders_db", delay_ms=200 + ) + aggregate = AggregationAgent(name="aggregate") + + # Build graph + graph = GraphAgent(name="parallel_workflow") + graph.add_node("validate", agent=validate) + graph.add_node("fetch_users", agent=fetch_users) + graph.add_node("fetch_products", agent=fetch_products) + graph.add_node("fetch_orders", agent=fetch_orders) + graph.add_node("aggregate", agent=aggregate) + + # Add parallel group for fetch operations + graph.add_parallel_group( + "fetch_group", + ParallelNodeGroup( + nodes=["fetch_users", "fetch_products", "fetch_orders"], + join_strategy=JoinStrategy.WAIT_ALL, # Wait for all to complete + error_policy=ErrorPolicy.FAIL_FAST, # Cancel all if one fails + ), + ) + + # Setup edges + graph.add_edge("validate", "fetch_users") + graph.add_edge("validate", "fetch_products") + graph.add_edge("validate", "fetch_orders") + graph.add_edge("fetch_users", "aggregate") + 
graph.add_edge("fetch_products", "aggregate") + graph.add_edge("fetch_orders", "aggregate") + + graph.set_start("validate") + graph.set_end("aggregate") + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="parallel_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ“Š Executing workflow with parallel fetch operations...") + print(" Strategy: WAIT_ALL (wait for all 3 fetches to complete)\n") + + events = [] + new_message = types.Content(parts=[types.Part(text="Start data pipeline")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + events.append(event) + + print(f"\nβœ… Scenario 1 complete: {len(events)} events emitted") + print(" Note: All 3 fetch operations ran concurrently!") + + +# ============================================================================ +# Scenario 2: Parallel Execution with WAIT_ANY (Race Condition) +# ============================================================================ + + +async def scenario_2_parallel_wait_any(): + """Demonstrate parallel execution with WAIT_ANY strategy. + + Workflow: + validate -> (fetch_cache || fetch_db || fetch_api) -> transform + + Three data sources race, and we proceed with whichever returns first. 
+ """ + print("\n" + "=" * 80) + print("SCENARIO 2: Parallel Execution with WAIT_ANY (Race)") + print("=" * 80) + + # Create agents with different speeds + validate = ValidationAgent(name="validate") + fetch_cache = DataFetchAgent( + name="fetch_cache", data_source="cache", delay_ms=50 + ) + fetch_db = DataFetchAgent( + name="fetch_db", data_source="database", delay_ms=150 + ) + fetch_api = DataFetchAgent( + name="fetch_api", data_source="external_api", delay_ms=300 + ) + transform = TransformAgent(name="transform", transformation="normalize") + + # Build graph + graph = GraphAgent(name="race_workflow") + graph.add_node("validate", agent=validate) + graph.add_node("fetch_cache", agent=fetch_cache) + graph.add_node("fetch_db", agent=fetch_db) + graph.add_node("fetch_api", agent=fetch_api) + graph.add_node("transform", agent=transform) + + # Add parallel group with WAIT_ANY + graph.add_parallel_group( + "race_group", + ParallelNodeGroup( + nodes=["fetch_cache", "fetch_db", "fetch_api"], + join_strategy=JoinStrategy.WAIT_ANY, # First to complete wins + ), + ) + + # Setup edges + graph.add_edge("validate", "fetch_cache") + graph.add_edge("validate", "fetch_db") + graph.add_edge("validate", "fetch_api") + graph.add_edge("fetch_cache", "transform") + graph.add_edge("fetch_db", "transform") + graph.add_edge("fetch_api", "transform") + + graph.set_start("validate") + graph.set_end("transform") + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="race_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\n🏁 Executing workflow with racing data sources...") + print(" Strategy: WAIT_ANY (first to complete wins)\n") + + new_message = types.Content(parts=[types.Part(text="Start race")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario2", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + 
print(f" {part.text}") + + print("\nβœ… Scenario 2 complete") + print(" Note: Cache was fastest! Other fetches were cancelled.") + + +# ============================================================================ +# Scenario 3: Parallel Execution with WAIT_N +# ============================================================================ + + +async def scenario_3_parallel_wait_n(): + """Demonstrate parallel execution with WAIT_N strategy. + + Workflow: + validate -> (ml_model_1 || ml_model_2 || ml_model_3) -> aggregate + + Three ML models run in parallel, we proceed when 2 out of 3 complete. + """ + print("\n" + "=" * 80) + print("SCENARIO 3: Parallel Execution with WAIT_N (2 out of 3)") + print("=" * 80) + + # Create agents with different speeds + validate = ValidationAgent(name="validate") + model1 = TransformAgent(name="model1", transformation="bert_inference") + model2 = TransformAgent(name="model2", transformation="gpt_inference") + model3 = TransformAgent(name="model3", transformation="t5_inference") + aggregate = AggregationAgent(name="aggregate") + + # Simulate different model speeds + model1._delay = 100 # Fast + model2._delay = 150 # Medium + model3._delay = 300 # Slow + + # Build graph + graph = GraphAgent(name="ml_ensemble_workflow") + graph.add_node("validate", agent=validate) + graph.add_node("model1", agent=model1) + graph.add_node("model2", agent=model2) + graph.add_node("model3", agent=model3) + graph.add_node("aggregate", agent=aggregate) + + # Add parallel group with WAIT_N (2 out of 3) + graph.add_parallel_group( + "ml_ensemble", + ParallelNodeGroup( + nodes=["model1", "model2", "model3"], + join_strategy=JoinStrategy.WAIT_N, + wait_n=2, # Wait for 2 out of 3 + ), + ) + + # Setup edges + graph.add_edge("validate", "model1") + graph.add_edge("validate", "model2") + graph.add_edge("validate", "model3") + graph.add_edge("model1", "aggregate") + graph.add_edge("model2", "aggregate") + graph.add_edge("model3", "aggregate") + + 
graph.set_start("validate") + graph.set_end("aggregate") + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="ml_ensemble_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ€– Executing ML ensemble workflow...") + print(" Strategy: WAIT_N (proceed when 2 out of 3 models complete)\n") + + new_message = types.Content(parts=[types.Part(text="Start inference")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario3", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\nβœ… Scenario 3 complete") + print(" Note: Proceeded after 2 models finished (model3 was cancelled)") + + +# ============================================================================ +# Scenario 4: Rewind Integration with Parallel Execution +# ============================================================================ + + +async def scenario_4_rewind_with_parallel(): + """Demonstrate rewind integration with parallel workflows. + + Tests rewinding to a node that triggers parallel execution. 
+ + Architectural Note: + - Rewind restores session state to before a specific invocation + - Each parallel branch has isolated state during execution + - Rewinding to the start of a parallel group will re-execute all branches + """ + print("\n" + "=" * 80) + print("SCENARIO 4: Rewind Integration with Parallel Execution") + print("=" * 80) + + # Create agents + validate = ValidationAgent(name="validate") + fetch_users = DataFetchAgent( + name="fetch_users", data_source="users", delay_ms=50 + ) + fetch_products = DataFetchAgent( + name="fetch_products", data_source="products", delay_ms=50 + ) + aggregate = AggregationAgent(name="aggregate") + + # Build graph + graph = GraphAgent(name="rewind_parallel_workflow") + graph.add_node("validate", agent=validate) + graph.add_node("fetch_users", agent=fetch_users) + graph.add_node("fetch_products", agent=fetch_products) + graph.add_node("aggregate", agent=aggregate) + + # Add parallel group + graph.add_parallel_group( + "fetch_group", + ParallelNodeGroup(nodes=["fetch_users", "fetch_products"]), + ) + + # Setup edges + graph.add_edge("validate", "fetch_users") + graph.add_edge("validate", "fetch_products") + graph.add_edge("fetch_users", "aggregate") + graph.add_edge("fetch_products", "aggregate") + + graph.set_start("validate") + graph.set_end("aggregate") + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="rewind_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ“ First execution...") + new_message = types.Content(parts=[types.Part(text="Start pipeline")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario4", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Get session to inspect invocations + session = await session_service.get_session( + app_name="rewind_demo", user_id="demo_user", 
session_id="scenario4"
+  )
+  node_invocations = session.state.get("node_invocations", {})
+
+  print(f"\nπŸ“Š Invocation tracking:")
+  for node_name, invocations in node_invocations.items():
+    print(f"   - {node_name}: {len(invocations)} invocation(s)")
+
+  # Rewind to the 'fetch_users' node (will re-execute the parallel group)
+  print("\nβͺ Rewinding to 'fetch_users' node...")
+  await rewind_to_node(
+      graph,
+      session_service,
+      app_name="rewind_demo",
+      user_id="demo_user",
+      session_id="scenario4",
+      node_name="fetch_users",
+      invocation_index=-1,  # Last invocation
+  )
+
+  print("   βœ… Rewind successful! Session state restored.")
+
+  # Re-execute from rewind point
+  print("\nπŸ“ Re-execution after rewind...")
+  async for event in runner.run_async(
+      user_id="demo_user", session_id="scenario4", new_message=new_message
+  ):
+    if event.content and event.content.parts:
+      for part in event.content.parts:
+        if part.text:
+          print(f"  {part.text}")
+
+  print("\nβœ… Scenario 4 complete")
+  print(
+      "   Note: Rewind works with parallel groups - all branches re-executed!"
+  )
+
+
+# ============================================================================
+# Scenario 5: Checkpointing with Parallel Execution
+# ============================================================================
+
+
+async def scenario_5_checkpointing_with_parallel():
+  """Demonstrate checkpointing with parallel workflows. 
+ + Architectural Note: + - Checkpoints capture session state at specific points + - Parallel branches have isolated state during execution + - After parallel group completes, state is merged back to main session + - Checkpoint created after merge includes all parallel results + """ + print("\n" + "=" * 80) + print("SCENARIO 5: Checkpointing with Parallel Execution") + print("=" * 80) + + # Create agents + validate = ValidationAgent(name="validate") + fetch_users = DataFetchAgent( + name="fetch_users", data_source="users", delay_ms=50 + ) + fetch_products = DataFetchAgent( + name="fetch_products", data_source="products", delay_ms=50 + ) + aggregate = AggregationAgent(name="aggregate") + + # Setup checkpoint service + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service) + + # Build graph with checkpointing enabled + graph = GraphAgent(name="checkpoint_parallel_workflow", checkpointing=True) + graph.add_node("validate", agent=validate) + graph.add_node("fetch_users", agent=fetch_users) + graph.add_node("fetch_products", agent=fetch_products) + graph.add_node("aggregate", agent=aggregate) + + # Add parallel group + graph.add_parallel_group( + "fetch_group", + ParallelNodeGroup(nodes=["fetch_users", "fetch_products"]), + ) + + # Setup edges + graph.add_edge("validate", "fetch_users") + graph.add_edge("validate", "fetch_products") + graph.add_edge("fetch_users", "aggregate") + graph.add_edge("fetch_products", "aggregate") + + graph.set_start("validate") + graph.set_end("aggregate") + + # Execute with checkpointing + runner = Runner( + app_name="checkpoint_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ“ Executing workflow with checkpointing enabled...") + print(" Checkpoints created at each node (including parallel branches)\n") + + new_message = types.Content(parts=[types.Part(text="Start pipeline")]) + async for event in runner.run_async( + user_id="demo_user", 
session_id="scenario5", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Get session + session = await session_service.get_session( + app_name="checkpoint_demo", user_id="demo_user", session_id="scenario5" + ) + + # Check checkpoint data in session state + checkpoint_data = session.state.get("graph_checkpoint", {}) + print(f"\nπŸ“Š Checkpoint data:") + print(f" - Last checkpoint at node: {checkpoint_data.get('node', 'N/A')}") + print(f" - Iteration: {checkpoint_data.get('iteration', 'N/A')}") + + print("\nβœ… Scenario 5 complete") + print(" Note: Checkpoints capture state after parallel branches complete!") + + +# ============================================================================ +# Scenario 6: Edge Case - Interrupts During Parallel Execution +# ============================================================================ + + +async def scenario_6_interrupts_during_parallel(): + """Demonstrate interrupt behavior during parallel execution. + + Architectural Consideration: + - Interrupts can be sent while parallel nodes are executing + - InterruptService can cancel all parallel branches immediately + - State is preserved up to the point of cancellation + - Useful for user-initiated abort or timeout scenarios + + This scenario shows what happens, but doesn't actually send interrupts + during execution (would require manual intervention). + """ + print("\n" + "=" * 80) + print("SCENARIO 6: Interrupt Considerations with Parallel Execution") + print("=" * 80) + + print("\nπŸ“ Architectural Notes:") + print(" 1. Interrupts CAN be sent during parallel execution") + print(" 2. GraphAgent checks for cancellation between events") + print(" 3. Immediate cancellation (ESC-like) stops all parallel branches") + print(" 4. 
Partial state is saved for potential resume") + print("\n Example interrupt flow:") + print(" - User sends interrupt while parallel nodes execute") + print(" - InterruptService marks session as cancelled") + print(" - GraphAgent detects cancellation, stops all branches") + print(" - State saved: {graph_cancelled: true, cancelled_at_node: ...}") + + # Create a simple parallel workflow + fetch1 = DataFetchAgent(name="fetch1", data_source="source1", delay_ms=500) + fetch2 = DataFetchAgent(name="fetch2", data_source="source2", delay_ms=500) + aggregate = AggregationAgent(name="aggregate") + + graph = GraphAgent(name="interrupt_aware_workflow") + graph.add_node("fetch1", agent=fetch1) + graph.add_node("fetch2", agent=fetch2) + graph.add_node("aggregate", agent=aggregate) + + graph.add_parallel_group( + "fetch_group", + ParallelNodeGroup(nodes=["fetch1", "fetch2"]), + ) + + graph.add_edge("fetch1", "aggregate") + graph.add_edge("fetch2", "aggregate") + graph.set_start("fetch1") + graph.set_end("aggregate") + + # Execute (without actually sending interrupts) + session_service = InMemorySessionService() + runner = Runner( + app_name="interrupt_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ“ Executing workflow (no interrupt sent in this demo)...\n") + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario6", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\nβœ… Scenario 6 complete") + print( + " Note: In production, you could send interrupts via InterruptService" + ) + + +# ============================================================================ +# Scenario 7: Edge Case - State Isolation in Parallel Branches +# ============================================================================ + + +async def 
scenario_7_state_isolation(): + """Demonstrate state isolation in parallel branches. + + Architectural Detail: + - Each parallel branch gets an ISOLATED copy of the state + - Changes in one branch don't affect others during execution + - After all branches complete, results can be merged + - This prevents race conditions and ensures deterministic behavior + """ + print("\n" + "=" * 80) + print("SCENARIO 7: State Isolation in Parallel Branches") + print("=" * 80) + + class StateModifyingAgent(BaseAgent): + """Agent that modifies state.""" + + def __init__(self, name: str, key: str, value: str, **kwargs): + super().__init__(name=name, **kwargs) + self._key = key + self._value = value + + async def _run_async_impl(self, ctx): + # Modify session state + ctx.session.state[self._key] = self._value + + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Set {self._key}={self._value}")] + ), + ) + + # Create agents that modify state + branch1 = StateModifyingAgent(name="branch1", key="counter", value="100") + branch2 = StateModifyingAgent(name="branch2", key="counter", value="200") + check = ValidationAgent(name="check") + + # Build graph + graph = GraphAgent(name="state_isolation_workflow") + graph.add_node("branch1", agent=branch1) + graph.add_node("branch2", agent=branch2) + graph.add_node("check", agent=check) + + graph.add_parallel_group( + "parallel_modifiers", + ParallelNodeGroup(nodes=["branch1", "branch2"]), + ) + + graph.add_edge("branch1", "check") + graph.add_edge("branch2", "check") + graph.set_start("branch1") + graph.set_end("check") + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="isolation_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("\nπŸ“ Executing workflow with parallel state modifications...") + print( + " Both branches try to set 'counter' to different values (isolated)\n" + ) + + new_message = 
types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="demo_user", session_id="scenario7", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Check final state + session = await session_service.get_session( + app_name="isolation_demo", user_id="demo_user", session_id="scenario7" + ) + + print(f"\nπŸ“Š Final session state:") + print(f" - counter value: {session.state.get('counter', 'NOT SET')}") + + print("\nβœ… Scenario 7 complete") + print(" Note: Parallel branches have isolated state - no race conditions!") + + +# ============================================================================ +# Main Demo Runner +# ============================================================================ + + +async def main(): + """Run all scenarios.""" + print("\n") + print("β•”" + "═" * 78 + "β•—") + print("β•‘" + " " * 78 + "β•‘") + print( + "β•‘" + + " GraphAgent Parallel Execution & Rewind Features - Comprehensive Demo" + .center(78) + + "β•‘" + ) + print("β•‘" + " " * 78 + "β•‘") + print("β•š" + "═" * 78 + "╝") + + try: + # Run all scenarios + await scenario_1_parallel_wait_all() + await scenario_2_parallel_wait_any() + await scenario_3_parallel_wait_n() + await scenario_4_rewind_with_parallel() + await scenario_5_checkpointing_with_parallel() + await scenario_6_interrupts_during_parallel() + await scenario_7_state_isolation() + + print("\n" + "=" * 80) + print("βœ… ALL SCENARIOS COMPLETE") + print("=" * 80) + print("\nKey Takeaways:") + print( + "1. Parallel execution works with WAIT_ALL, WAIT_ANY, WAIT_N strategies" + ) + print( + "2. Rewind integration works - can rewind to nodes that trigger" + " parallel groups" + ) + print("3. Checkpointing captures state after parallel branches complete") + print("4. Interrupts can cancel parallel execution (state preserved)") + print("5. 
Parallel branches have isolated state (no race conditions)") + print("\nArchitectural Answers:") + print("- Q: Can rewind work with parallel execution?") + print( + " A: YES! Rewind restores to before node execution, re-runs parallel" + " group" + ) + print("- Q: What about session state communication?") + print( + " A: Branches are isolated during execution, merged after completion" + ) + print("- Q: What if we interrupt during parallel execution?") + print(" A: All branches cancelled, partial state saved for resume") + + except Exception as e: + print(f"\n❌ Error running scenarios: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/README.md b/contributing/samples/graph_agent_pattern_dynamic_node/README.md new file mode 100644 index 0000000000..5ff6a6a5fd --- /dev/null +++ b/contributing/samples/graph_agent_pattern_dynamic_node/README.md @@ -0,0 +1,37 @@ +# GraphAgent Pattern β€” DynamicNode (Mixture of Experts) + +This example implements **runtime agent selection** using `DynamicNode`. A classifier labels the +incoming task as SIMPLE or COMPLEX, then `DynamicNode` routes to a cheap fast model or a thorough +capable model accordingly β€” a sparse mixture-of-experts dispatch optimizing cost vs. quality. 
+ +## When to Use This Pattern + +- Cost optimisation: route easy tasks to cheaper models, hard tasks to capable models +- Capability dispatch: pick a specialist agent based on detected task domain +- Fallback chains: try a fast agent first, escalate to a powerful agent on failure + +## How to Run + +```bash +adk run contributing/samples/graph_agent_pattern_dynamic_node +``` + +## Graph Structure + +``` +classify ──▢ respond (DynamicNode) + β”œβ”€β”€ simple_agent (when classify output contains "SIMPLE") + └── detailed_agent (otherwise) +``` + +## Key Code Walkthrough + +- **`DynamicNode(name="respond", agent_selector=select_responder)`** β€” the selector callable + receives `GraphState` and returns the `BaseAgent` to invoke +- **`select_responder(state)`** β€” reads `state.data["classify"]` and returns the matching agent +- **`fallback_agent`** parameter β€” used when the selector returns `None` +- **Transparent to the graph** β€” downstream edges see `respond`'s output regardless of which + agent was chosen +- **No graph-level changes needed** β€” swap agents by changing `select_responder`, not the graph + topology + diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/__init__.py b/contributing/samples/graph_agent_pattern_dynamic_node/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_pattern_dynamic_node/agent.py b/contributing/samples/graph_agent_pattern_dynamic_node/agent.py new file mode 100644 index 0000000000..cddbee8a8f --- /dev/null +++ b/contributing/samples/graph_agent_pattern_dynamic_node/agent.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""DynamicNode Pattern: Runtime Agent Selection + +Motivation (Mixture-of-Experts) +-------------------------------- +Shazeer et al. (2017) "Outrageously Large Neural Networks: The Sparsely-Gated +Mixture-of-Experts Layer" showed that routing inputs to specialised experts +beats a single monolithic model while keeping per-token compute fixed. 
+ +The same principle applies to agentic workflows: a *router* classifies the +complexity/type of each task, then a *gating function* selects the cheapest +adequate specialistβ€”a fast flash-model for simple tasks, a slower pro-model +for hard tasks. + +Pattern: DynamicNode +-------------------- +DynamicNode is the first-class API for this pattern. The `agent_selector` +callable runs at runtime, reads the current GraphState, and returns the +appropriate BaseAgent. + +Compare to the function-node alternative +----------------------------------------- +Without DynamicNode you need a function node that manually dispatches: + + async def dispatch(state, ctx): + agent = complex_agent if "hard" in state.data else simple_agent + node_ctx = ctx.model_copy(update={...}) + output = "" + async for event in agent.run_async(node_ctx): + if event.content and event.content.parts: + output = event.content.parts[0].text or "" + return output + +DynamicNode gives you: + βœ… Metadata auto-tracking: which agent was selected (observability) + βœ… Built-in fallback_agent when selector returns None + βœ… Selection logic decoupled from execution boilerplate + +Architecture +------------ + classify ──► route (DynamicNode) ──► end + β”‚ + β”œβ”€ selector returns simple_agent (flash, cheap) + └─ selector returns detailed_agent (pro, thorough) +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import DynamicNode +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Step 1: Classifier β€” assigns complexity label from the user's request +# --------------------------------------------------------------------------- +classifier = 
LlmAgent( + name="classifier", + model=_MODEL, + instruction=""" +You are a task complexity classifier. + +Read the user's request and reply with EXACTLY one word: + SIMPLE – if the task is a quick factual lookup or short question + COMPLEX – if the task requires multi-step reasoning, analysis, or code + +Reply with only the word, nothing else. +""", +) + +# --------------------------------------------------------------------------- +# Step 2: Specialists β€” cheap flash model vs thorough pro model +# --------------------------------------------------------------------------- +simple_agent = LlmAgent( + name="simple_responder", + model=_MODEL, + instruction=""" +You are a concise assistant. Answer the user's question briefly (1-3 sentences). +""", +) + +detailed_agent = LlmAgent( + name="detailed_responder", + model=_MODEL, + instruction=""" +You are a thorough analyst. Work through the problem step by step, show your +reasoning, and provide a complete, well-structured answer. +""", +) + + +# --------------------------------------------------------------------------- +# Step 3: Agent selector β€” called at runtime with current GraphState +# --------------------------------------------------------------------------- +def select_responder(state: GraphState) -> LlmAgent: + """Route to simple_agent for SIMPLE tasks, detailed_agent otherwise. + + The classifier stored its output in state.data["classify"] via the + default output_mapper (OVERWRITE reducer, key = node name). 
+ """ + classification = state.data.get("classify", "").upper() + if "SIMPLE" in classification: + return simple_agent + return detailed_agent + + +# --------------------------------------------------------------------------- +# Build the graph +# --------------------------------------------------------------------------- +def build_graph() -> GraphAgent: + graph = GraphAgent( + name="dynamic_routing", + description="Routes each query to the cheapest adequate specialist", + ) + + # Node 1: classify complexity + graph.add_node("classify", agent=classifier) + + # Node 2: DynamicNode selects the right specialist at runtime + graph.add_node( + DynamicNode( + name="respond", + agent_selector=select_responder, + fallback_agent=simple_agent, # safety net if selector returns None + ) + ) + + graph.add_edge("classify", "respond") + graph.set_start("classify") + graph.set_end("respond") + return graph + + +# --------------------------------------------------------------------------- +# Runner helper +# --------------------------------------------------------------------------- +_graph = build_graph() + + +async def run(question: str) -> str: + graph = _graph + svc = InMemorySessionService() + runner = Runner( + app_name="dynamic_node_example", agent=graph, session_service=svc + ) + await svc.create_session( + app_name="dynamic_node_example", user_id="user", session_id="s1" + ) + final = "" + async for event in runner.run_async( + user_id="user", + session_id="s1", + new_message=types.Content(role="user", parts=[types.Part(text=question)]), + ): + if event.content and event.content.parts: + text = event.content.parts[0].text or "" + if text and not text.startswith("[GraphMetadata]"): + final += text + return final + + +# --------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main(): + questions = [ + "What is the capital of France?", # SIMPLE β†’ flash 
model + ( # COMPLEX β†’ pro model + "Explain how transformer attention scales with sequence length " + "and what architectural changes help address this." + ), + ] + for q in questions: + print(f"\nQ: {q}") + answer = await run(q) + print(f"A: {answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_pattern_nested_graph/README.md b/contributing/samples/graph_agent_pattern_nested_graph/README.md new file mode 100644 index 0000000000..846807e2e3 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_nested_graph/README.md @@ -0,0 +1,42 @@ +# GraphAgent Pattern β€” NestedGraphNode (Hierarchical Composition) + +This example demonstrates **hierarchical workflow decomposition** using `NestedGraphNode`. A +coordinator planner produces a focused query, a three-step research sub-graph (search β†’ extract β†’ +summarise) handles it as a single reusable unit, and a synthesiser produces the final answer. +The sub-graph is entirely encapsulated and independently testable. 
+ +## When to Use This Pattern + +- Large workflows that benefit from breaking into independently developed sub-pipelines +- Reusable sub-workflows (the same research graph could be called from multiple parent graphs) +- Team boundaries: different teams own the outer orchestration and inner sub-workflows +- Recursive depth: sub-graphs can themselves contain `NestedGraphNode`s + +## How to Run + +```bash +adk run contributing/samples/graph_agent_pattern_nested_graph +``` + +## Graph Structure + +``` +Outer: plan ──▢ research (NestedGraphNode) ──▢ synthesise + +Inner (research sub-graph): + search ──▢ extract ──▢ summarise +``` + +## Key Code Walkthrough + +- **`NestedGraphNode(name="research", graph_agent=build_research_subgraph())`** β€” wraps an entire + `GraphAgent` as a single node in the parent graph +- **`inherit_session=True`** β€” the sub-graph shares the parent session's state, so outputs + written inside are visible to the parent's synthesiser +- **`build_research_subgraph()`** β€” factory function that constructs and returns the inner + `GraphAgent`; call it multiple times for independent instances +- **State bridging** β€” the sub-graph's final state is merged back; use `output_mapper` on the + `NestedGraphNode` to control which keys are exposed to the outer graph +- **Telemetry and checkpointing** β€” propagate automatically into the sub-graph when enabled on + the parent + diff --git a/contributing/samples/graph_agent_pattern_nested_graph/__init__.py b/contributing/samples/graph_agent_pattern_nested_graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contributing/samples/graph_agent_pattern_nested_graph/agent.py b/contributing/samples/graph_agent_pattern_nested_graph/agent.py new file mode 100644 index 0000000000..d569b2aa9c --- /dev/null +++ b/contributing/samples/graph_agent_pattern_nested_graph/agent.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +"""NestedGraphNode Pattern: Hierarchical Workflow Composition + +Motivation 
(Hierarchical Planning) +------------------------------------ +Hierarchical planning research (Sutton et al. 1999, "Between MDPs and +semi-MDPs") and modern LLM orchestration papers like ORION / LLM Compiler +(Kim et al. 2023) show that breaking complex tasks into nested sub-plans +improves both reasoning quality and task-level reuse. + +In agentic terms: a *coordinator* decomposes the top-level goal into +sub-problems, each solved by a *sub-workflow* (a full GraphAgent) that can +be developed, tested, and reused independently. + +Pattern: NestedGraphNode +------------------------ +NestedGraphNode runs an entire GraphAgent as a single step inside a parent +graph. The parent graph sees only the sub-workflow's final output β€” internal +steps are transparent. + +Compare to the function-node alternative +----------------------------------------- +Without NestedGraphNode you must manually plumb the nested graph: + + async def run_research_step(state, ctx): + sub_ctx = ctx.model_copy(update={...}) + result = "" + async for event in research_graph._run_async_impl(sub_ctx): + if event.content and event.content.parts: + result = event.content.parts[0].text or "" + return result + +NestedGraphNode gives you: + βœ… Session-inheritance control (inherit_session=True/False) + βœ… Automatic metadata: sub-graph iteration count + execution path + βœ… No manual context plumbing + +Architecture (two-level hierarchy) +------------------------------------ +Outer graph: + plan ──► [research_step (NestedGraphNode)] ──► synthesize + +Inner (research_step) sub-graph: + search ──► extract ──► summarise +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import NestedGraphNode +from google.adk.agents.graph.callbacks import create_nested_observability_callback +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + 
+_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Inner sub-graph: a three-step research pipeline +# (search β†’ extract key claims β†’ summarise for the parent) +# --------------------------------------------------------------------------- +_searcher = LlmAgent( + name="searcher", + model=_MODEL, + instruction=""" +You are a research agent. Given a topic, produce 3-5 bullet-point facts +that are accurate and concise. Start each bullet with "β€’". +""", +) + +_extractor = LlmAgent( + name="extractor", + model=_MODEL, + instruction=""" +You receive bullet-point facts. Extract the 2 most important claims and +rewrite them as complete sentences. Label them CLAIM-1 and CLAIM-2. +""", +) + +_summariser = LlmAgent( + name="summariser", + model=_MODEL, + instruction=""" +You receive two claims. Write a single-paragraph research summary (≀4 +sentences) that a domain expert would find useful. +""", +) + + +def build_research_subgraph() -> GraphAgent: + """Returns a reusable three-step research sub-workflow.""" + g = GraphAgent( + name="research_pipeline", + before_node_callback=create_nested_observability_callback(), + ) + g.add_node("search", agent=_searcher) + g.add_node("extract", agent=_extractor) + g.add_node("summarise", agent=_summariser) + g.add_edge("search", "extract") + g.add_edge("extract", "summarise") + g.set_start("search") + g.set_end("summarise") + return g + + +# --------------------------------------------------------------------------- +# Outer graph: plan β†’ nested research β†’ final synthesis +# --------------------------------------------------------------------------- +_planner = LlmAgent( + name="planner", + model=_MODEL, + instruction=""" +You receive a broad research question. Restate it as a focused single-topic +query suitable for a research assistant (one sentence, no preamble). 
+""", +) + +_synthesiser = LlmAgent( + name="synthesiser", + model=_MODEL, + instruction=""" +You receive a research summary. Write a concise final answer (2-3 sentences) +to the original user question, citing the key findings. +""", +) + + +def build_graph() -> GraphAgent: + outer = GraphAgent( + name="hierarchical_research", + description="Coordinator β†’ research sub-workflow β†’ synthesis", + before_node_callback=create_nested_observability_callback(), + ) + + outer.add_node("plan", agent=_planner) + + # The entire inner research pipeline runs as a single node + outer.add_node( + NestedGraphNode( + name="research", + graph_agent=build_research_subgraph(), + inherit_session=True, # share parent session β†’ state visible to outer + ) + ) + + outer.add_node("synthesise", agent=_synthesiser) + + outer.add_edge("plan", "research") + outer.add_edge("research", "synthesise") + outer.set_start("plan") + outer.set_end("synthesise") + return outer + + +# --------------------------------------------------------------------------- +# Runner helper +# --------------------------------------------------------------------------- +async def run(question: str) -> str: + graph = build_graph() + svc = InMemorySessionService() + runner = Runner( + app_name="nested_graph_example", agent=graph, session_service=svc + ) + await svc.create_session( + app_name="nested_graph_example", user_id="user", session_id="s1" + ) + final = "" + async for event in runner.run_async( + user_id="user", + session_id="s1", + new_message=types.Content(role="user", parts=[types.Part(text=question)]), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text: + continue + if event.author == "observability": + # Show node execution trace (from create_nested_observability_callback) + print(f" β†’ {text}") + elif not text.startswith("[GraphMetadata]"): + final += text + return final + + +# 
--------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main(): + q = "What are the key developments in quantum computing hardware?" + print(f"Question: {q}\n") + answer = await run(q) + print(f"Answer:\n{answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_pattern_parallel_group/README.md b/contributing/samples/graph_agent_pattern_parallel_group/README.md new file mode 100644 index 0000000000..56417049e3 --- /dev/null +++ b/contributing/samples/graph_agent_pattern_parallel_group/README.md @@ -0,0 +1,41 @@ +# GraphAgent Pattern β€” DynamicParallelGroup (Tree of Thoughts) + +This example implements the **Tree of Thoughts** pattern using `DynamicParallelGroup`. N independent +reasoning paths are generated concurrently, then an evaluator scores them and a selector picks the +best. The number of parallel branches is determined at runtime from the graph state, and concurrency +is capped by `max_parallelism` to prevent resource exhaustion. 
+
+## When to Use This Pattern
+
+- Diverge-and-converge workflows: generate many candidates in parallel, then select the best
+- Tree of Thoughts / beam search style reasoning
+- Ensemble approaches: run N agents independently and aggregate their outputs
+- Any case where the number of parallel branches is data-dependent (unknown at graph-build time)
+
+## How to Run
+
+```bash
+python -m contributing.samples.graph_agent_pattern_parallel_group.agent
+```
+
+## Graph Structure
+
+```
+config (function) ──▢ generate (DynamicParallelGroup) ──▢ evaluate ──▢ select
+                          β”œβ”€β”€ thought_agent_0
+                          β”œβ”€β”€ thought_agent_1
+                          └── thought_agent_N (N from state.data["num_thoughts"])
+```
+
+## Key Code Walkthrough
+
+- **`DynamicParallelGroup(name="generate", agent_generator=generate_thought_agents)`** β€” the
+  generator callable receives `GraphState` at runtime and returns a list of `BaseAgent` instances
+- **`max_parallelism=5`** β€” caps concurrent agent executions via an `asyncio.Semaphore`; prevents
+  overloading the model API with too many simultaneous requests
+- **`aggregator=aggregate_thoughts`** β€” combines the N results into a single string (using
+  `=== Thought N ===` separators) before passing to the evaluator
+- **`config` function node** β€” parses `[num_thoughts=N]` from the user message and writes
+  `state.data["num_thoughts"]`; shows how function nodes can pre-process inputs
+- **`selector`** β€” final selector reads evaluator scores and returns the winning thought
+
diff --git a/contributing/samples/graph_agent_pattern_parallel_group/__init__.py b/contributing/samples/graph_agent_pattern_parallel_group/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/contributing/samples/graph_agent_pattern_parallel_group/agent.py b/contributing/samples/graph_agent_pattern_parallel_group/agent.py
new file mode 100644
index 0000000000..901dee72b7
--- /dev/null
+++ b/contributing/samples/graph_agent_pattern_parallel_group/agent.py
@@ -0,0 +1,227 
@@ +#!/usr/bin/env python3 +"""DynamicParallelGroup Pattern: Tree of Thoughts / Self-Consistency + +Motivation (Tree of Thoughts & Self-Consistency) +-------------------------------------------------- +Wang et al. (2022) "Self-Consistency Improves Chain of Thought Reasoning": +sampling multiple independent reasoning paths and taking the majority answer +outperforms single chain-of-thought on arithmetic and commonsense benchmarks. + +Yao et al. (2023) "Tree of Thoughts: Deliberate Problem Solving with LLMs": +exploring multiple partial-solution branches in parallel, then scoring and +selecting the best one, significantly improves performance on complex tasks. + +Both techniques require spawning N concurrent agents whose count is determined +at runtime β€” exactly what DynamicParallelGroup provides. + +Pattern: DynamicParallelGroup +------------------------------ +DynamicParallelGroup generates the agent list at runtime from a callable, runs +them concurrently (with optional max_parallelism throttle), then feeds all +results to an aggregator function. + +Compare to the static-parallelism alternative (ParallelNodeGroup) +------------------------------------------------------------------ +ParallelNodeGroup is compiled at graph-construction time: + + group = ParallelNodeGroup( + name="parallel", + nodes=[node1, node2, node3], # fixed list β€” must know N upfront + join_strategy=JoinStrategy.WAIT_ALL, + ) + +DynamicParallelGroup gives you: + βœ… Runtime-determined N (read from state.data β€” e.g. user param) + βœ… State-driven generation (e.g. 
one agent per item in a list) + βœ… Custom aggregation logic (majority vote, best-score selection, …) + βœ… Concurrency cap via max_parallelism (back-pressure / rate limiting) + +Architecture (Tree of Thoughts) +--------------------------------- + generate_thoughts (DynamicParallelGroup) + β”‚ N independent "thought" agents run concurrently + β–Ό + evaluate ← LlmAgent scores all thoughts + β”‚ + β–Ό + select ← LlmAgent picks the winning thought +""" + +import asyncio +import os + +from google.adk.agents import LlmAgent +from google.adk.agents.graph import DynamicParallelGroup +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph.callbacks import create_nested_observability_callback +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + + +# --------------------------------------------------------------------------- +# Thought generator β€” each instance produces one independent solution path +# --------------------------------------------------------------------------- +def make_thought_agent(idx: int) -> LlmAgent: + return LlmAgent( + name=f"thought_{idx}", + model=_MODEL, + instruction=f""" +You are creative problem-solver #{idx + 1}. Given a problem, propose one +original, concrete solution approach. Be specific (≀3 sentences). +Do NOT repeat approaches from other agents; bring a fresh angle. 
+""", + ) + + +# --------------------------------------------------------------------------- +# Runtime generator: reads state.data["num_thoughts"] or defaults to 3 +# --------------------------------------------------------------------------- +def generate_thought_agents(state: GraphState): + n = int(state.data.get("num_thoughts", "3")) + return [make_thought_agent(i) for i in range(n)] + + +# --------------------------------------------------------------------------- +# Aggregator: concatenate all thoughts with separators for the evaluator +# --------------------------------------------------------------------------- +def aggregate_thoughts(results: list, state: GraphState) -> str: + lines = [] + for i, r in enumerate(results, 1): + lines.append(f"=== Thought {i} ===\n{r.strip()}") + return "\n\n".join(lines) + + +# --------------------------------------------------------------------------- +# Evaluator and selector +# --------------------------------------------------------------------------- +evaluator = LlmAgent( + name="evaluator", + model=_MODEL, + instruction=""" +You receive several numbered solution approaches. For EACH one, write a +single line: "Thought N: – ". +""", +) + +selector = LlmAgent( + name="selector", + model=_MODEL, + instruction=""" +You receive evaluations of solution approaches. Choose the best-scoring one +and explain why it is the strongest approach (2-3 sentences). +""", +) + + +# --------------------------------------------------------------------------- +# Pre-processing: extract [num_thoughts=N] header from user input +# --------------------------------------------------------------------------- +def parse_config(state: GraphState, ctx) -> str: + """Extract num_thoughts from the input message and store in state. + + Input format: "[num_thoughts=N] " + Falls back to default 3 if the header is absent. 
+ """ + import re + + raw = state.data.get("input", "") + m = re.match(r"\[num_thoughts=(\d+)\]\s*(.*)", raw, re.DOTALL) + if m: + state.data["num_thoughts"] = m.group(1) + return m.group(2).strip() + return raw + + +# --------------------------------------------------------------------------- +# Build graph +# --------------------------------------------------------------------------- +def build_graph() -> GraphAgent: + graph = GraphAgent( + name="tree_of_thoughts", + description="Parallel thought generation β†’ evaluation β†’ selection", + before_node_callback=create_nested_observability_callback(), + ) + + # Extract num_thoughts config and clean the input text + graph.add_node("config", function=parse_config) + + graph.add_node( + DynamicParallelGroup( + name="generate", + agent_generator=generate_thought_agents, + aggregator=aggregate_thoughts, + max_parallelism=5, # cap concurrent LLM calls + ) + ) + graph.add_node("evaluate", agent=evaluator) + graph.add_node("select", agent=selector) + + graph.add_edge("config", "generate") + graph.add_edge("generate", "evaluate") + graph.add_edge("evaluate", "select") + graph.set_start("config") + graph.set_end("select") + return graph + + +# --------------------------------------------------------------------------- +# Runner helper +# --------------------------------------------------------------------------- +async def run(problem: str, num_thoughts: int = 3) -> str: + # Encode num_thoughts in the message so state.data["input"] carries it, + # and also pre-seed state via a thin wrapper node if needed. + # The simplest approach: embed num_thoughts in a header the agent ignores. 
+ full_input = f"[num_thoughts={num_thoughts}] {problem}" + + graph = build_graph() + svc = InMemorySessionService() + runner = Runner( + app_name="parallel_group_example", agent=graph, session_service=svc + ) + await svc.create_session( + app_name="parallel_group_example", user_id="user", session_id="s1" + ) + + final = "" + async for event in runner.run_async( + user_id="user", + session_id="s1", + new_message=types.Content( + role="user", parts=[types.Part(text=full_input)] + ), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + if not text: + continue + if event.author == "observability": + # Show node execution trace (from create_nested_observability_callback) + print(f" β†’ {text}") + elif not text.startswith("[GraphMetadata]"): + final += text + return final + + +# --------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main(): + problem = ( + "How can a small startup with limited budget compete with large " + "incumbents in the enterprise software market?" + ) + print(f"Problem: {problem}\n") + print("Generating 4 parallel thought paths...\n") + result = await run(problem, num_thoughts=4) + print(f"Selected best approach:\n{result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_react_pattern/README.md b/contributing/samples/graph_agent_react_pattern/README.md new file mode 100644 index 0000000000..2f8e3be75f --- /dev/null +++ b/contributing/samples/graph_agent_react_pattern/README.md @@ -0,0 +1,67 @@ +# GraphAgent ReAct Pattern + +Demonstrates the **ReAct (Reasoning + Acting)** loop using GraphAgent. + +## Pattern + +``` +reason β†’ act β†’ observe + ↑ | CONTINUE + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + | COMPLETE + ↓ + END +``` + +Each iteration: +1. **reason** β€” analyse task + previous observation, decide next action +2. 
**act** β€” execute the chosen action, produce a result +3. **observe** β€” evaluate result; output `COMPLETE:` or `CONTINUE:` + +GraphAgent routes `observe β†’ reason` (continue) or exits (complete) based on +the `observation` state key β€” a conditional edge only GraphAgent can express. + +## When to Use + +- Multi-step problem solving where the number of iterations is unknown +- Tool-augmented agents (search β†’ reason β†’ act β†’ observe loop) +- Any workflow where routing depends on the *content* of an intermediate output + +## Comparison with Other Workflow Agents + +| Capability | SequentialAgent | LoopAgent | **GraphAgent** | +|------------|----------------|-----------|----------------| +| Execute nodes in order | βœ… | βœ… | βœ… | +| Loop (repeat execution) | βœ— | βœ… | βœ… | +| Route based on state content | βœ— | βœ— (escalate only) | βœ… | +| Conditional exit mid-loop | βœ— | via escalate | βœ… | +| Route to *different* next node | βœ— | βœ— | βœ… | + +**LoopAgent** can repeat, but its only conditional exit is via `escalate` β€” it +cannot inspect the `observation` field and route back to `reason` vs. exit. +**SequentialAgent** cannot loop at all. 
+
+## Key Code
+
+```python
+# Conditional back-edge: loop if not complete
+graph.add_edge("observe", "reason", condition=_should_continue)
+
+# No forward edge needed: set_end() exits when observe has no matching edge
+graph.set_end("observe")
+```
+
+## How to Run
+
+```bash
+cd /path/to/adk-python
+source venv/bin/activate
+export GOOGLE_API_KEY=<your-api-key>
+python -m contributing.samples.graph_agent_react_pattern.agent
+```
+
+## Related Examples
+
+- `contributing/samples/graph_examples/02_conditional_routing` β€” basic conditional edges
+- `contributing/samples/graph_examples/03_cyclic_execution` β€” cyclic loop without LLM
+- `contributing/samples/graph_agent_advanced` β€” full research workflow
diff --git a/contributing/samples/graph_agent_react_pattern/__init__.py b/contributing/samples/graph_agent_react_pattern/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/contributing/samples/graph_agent_react_pattern/agent.py b/contributing/samples/graph_agent_react_pattern/agent.py
new file mode 100644
index 0000000000..384cc729eb
--- /dev/null
+++ b/contributing/samples/graph_agent_react_pattern/agent.py
@@ -0,0 +1,215 @@
+"""GraphAgent ReAct Pattern example.
+
+Demonstrates the Reasoning + Acting (ReAct) loop using GraphAgent:
+    reason β†’ act β†’ observe
+    observe loops back to reason if "CONTINUE"
+    observe ends if "COMPLETE"
+
+Why GraphAgent (not LoopAgent/SequentialAgent)?
+- SequentialAgent: cannot loop; fixed linear path
+- LoopAgent: loops unconditionally or escalates; cannot inspect observation
+  content to decide direction (reason vs.
exit) +- GraphAgent: conditional edges read state β†’ route to any node or exit + +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_react_pattern.agent +""" + +import asyncio +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import StateReducer +from google.adk.agents.llm_agent import LlmAgent +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel +from pydantic import ValidationError + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class ObservationResult(BaseModel): + """Structured observation output from observer agent.""" + + status: str # "continue" or "complete" + reasoning: str # Why continue or why complete + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + +reasoner = LlmAgent( + name="reasoner", + model=_MODEL, + instruction=( + "You are a reasoning agent. Analyse the task and any previous " + "observations, then decide what action to take next. " + "Write your reasoning in 1-3 sentences." + ), + output_key="reasoning", +) + +actor = LlmAgent( + name="actor", + model=_MODEL, + instruction=( + "You are an action agent. Based on the reasoning provided, " + "answer the question or perform the requested analysis using your " + "knowledge. Do NOT write code or tool calls β€” just provide the " + "factual answer or calculation result directly." + ), + output_key="action_result", +) + +observer = LlmAgent( + name="observer", + model=_MODEL, + instruction=( + "You are an observation agent. 
Evaluate whether the action result " + "fully answers the original task. " + 'Return {"status": "complete", "reasoning": "..."} if task is done, ' + 'or {"status": "continue", "reasoning": "what is missing..."} if not.' + ), + output_schema=ObservationResult, # Structured output + # output_key auto-defaults to "observer" (agent name) +) + + +# --------------------------------------------------------------------------- +# Routing predicates +# --------------------------------------------------------------------------- + + +def _should_continue(state: GraphState) -> bool: + """Check if ReAct loop should continue using structured output.""" + obs = state.get_parsed("observer", ObservationResult) + return obs.status.lower() == "continue" if obs else False + + + +# --------------------------------------------------------------------------- +# Graph +# --------------------------------------------------------------------------- + + +def build_react_graph() -> GraphAgent: + graph = GraphAgent( + name="react_agent", + description="ReAct pattern: Reasoning + Acting loop", + max_iterations=10, + ) + + graph.add_node( + "reason", + agent=reasoner, + input_mapper=lambda s: ( + f"Task: {s.data.get('task', '')}\n" + f"Previous observation: {s.data.get('observation', 'none')}" + ), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "act", + agent=actor, + input_mapper=lambda s: s.data.get("reasoning", ""), + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "observe", + agent=observer, + input_mapper=lambda s: ( + f"Task: {s.data.get('task', '')}\n" + f"Action result: {s.data.get('action_result', '')}" + ), + reducer=StateReducer.OVERWRITE, + ) + + graph.set_start("reason") + graph.add_edge("reason", "act") + graph.add_edge("act", "observe") + + # Loop back if not yet complete + graph.add_edge("observe", "reason", condition=_should_continue) + + # Exit when complete + graph.set_end("observe") + + return graph + + +# 
--------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + session_service = InMemorySessionService() + graph = build_react_graph() + + session = await session_service.create_session( + app_name="react_agent", user_id="user1" + ) + + task = "What are the key features of the Google Agent Development Kit (ADK)?" + print(f"Task: {task}\n") + + # Seed the task into initial state (BEFORE calling runner) + session.state["task"] = task + + # Use Runner instead of manual invocation context + runner = Runner( + app_name="react_agent", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + iteration = 0 + async for event in runner.run_async( + user_id="user1", + session_id=session.id, + new_message=types.Content(parts=[types.Part(text=task)]), + ): + if event.content and event.content.parts: + author = event.author + text = event.content.parts[0].text or "" + if author == "observer": + iteration += 1 + # Parse from event text (JSON string from output_schema) + try: + obs = ObservationResult.model_validate_json(text.strip()) + status = obs.status.upper() + except ValidationError: + status = "UNKNOWN (parse error)" + print(f"[iteration {iteration}] Observer: {status}") + elif author in ("reasoner", "actor"): + print(f" [{author}]: {text[:120]}...") + + # Re-fetch fresh session state (create_session returns a deepcopy) + fresh_session = await session_service.get_session( + app_name="react_agent", user_id="user1", session_id=session.id + ) + if fresh_session is None: + print( + "WARNING: session_service.get_session returned None, using stale copy" + ) + fresh_session = session + final_data = fresh_session.state.get("graph_data", {}) + final_state = GraphState(data=final_data) + + print("\nFinal observation:") + obs = final_state.get_parsed("observer", ObservationResult) + print(f"Status: {obs.status if obs else 
'none'}") + print(f"Reasoning: {obs.reasoning if obs else 'none'}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_agent_todo_queue/README.md b/contributing/samples/graph_agent_todo_queue/README.md new file mode 100644 index 0000000000..78b532d1b9 --- /dev/null +++ b/contributing/samples/graph_agent_todo_queue/README.md @@ -0,0 +1,36 @@ +# GraphAgent TODO Queue β€” Conditional Loop with Type-Based Routing + +This example demonstrates queue-based orchestration where items are fetched one +at a time, classified, routed to a type-specific processor, and then looped back +to fetch the next item β€” with a checkpoint after each completion so the queue is +resume-safe after an interruption. + +## When to Use This Pattern + +- Processing a heterogeneous queue where each item type needs a different handler +- Long-running batch jobs that must survive process restarts mid-queue +- Any loop that requires branching inside each iteration + +## How to Run + +```bash +GOOGLE_API_KEY=your_key python -m contributing.samples.graph_agent_todo_queue.agent +``` + +## Graph Structure + +``` +fetcher ──▢ classifier ──(data)──────▢ processor_data ──┐ + ──(notification)β–Ά processor_notification ─── (has_more=True) ──▢ fetcher + ──(cleanup)────▢ processor_cleanup β”€β”€β”€β”€β”€β”€β”˜ + (has_more=False) ──▢ END +``` + +## Key Code Walkthrough + +- **`add_edge("classifier", "processor_data", condition=_is_data_task)`** β€” three conditional branches route each item to the correct processor based on `todo_type` in state +- **Loop edges** β€” each processor connects back to `fetcher` when `has_more=True`; GraphAgent handles cycles that LoopAgent cannot route conditionally +- **`graph.set_end()` on all three processors** β€” any processor can be the terminal node when the queue drains +- **`GraphCheckpointCallback(checkpoint_nodes={"processor_data","processor_notification","processor_cleanup"})`** β€” checkpoints only after a full item completes, 
not mid-classification +- **`StateReducer.OVERWRITE`** β€” each loop iteration overwrites `current_todo` and `last_processed` rather than accumulating all iterations + diff --git a/contributing/samples/graph_agent_todo_queue/__init__.py b/contributing/samples/graph_agent_todo_queue/__init__.py new file mode 100644 index 0000000000..15aa1a6086 --- /dev/null +++ b/contributing/samples/graph_agent_todo_queue/__init__.py @@ -0,0 +1 @@ +"""GraphAgent TODO queue orchestrator sample.""" diff --git a/contributing/samples/graph_agent_todo_queue/agent.py b/contributing/samples/graph_agent_todo_queue/agent.py new file mode 100644 index 0000000000..0d6932339d --- /dev/null +++ b/contributing/samples/graph_agent_todo_queue/agent.py @@ -0,0 +1,325 @@ +"""GraphAgent TODO queue orchestrator with checkpointing. + +Demonstrates queue-based orchestration where items are processed one at a +time with a checkpoint after each completion. If execution is interrupted +(e.g., process crash), the queue resumes from the last checkpoint. + +Features: +- Process a queue of TODO items sequentially +- Checkpoint after each item completion (resume-safe) +- Dynamic routing based on TODO type (data/notification/cleanup) +- Loop control: continues until queue is empty +- Selective checkpointing (only after processors, not fetcher/classifier) + +Flow: + fetcher β†’ classifier β†’ [processor_data | processor_notification | processor_cleanup] + ↑ ↓ + └──────── (has_more=True) β”€β”€β”€β”€β”€β”˜ + ↓ + (has_more=False) β†’ END + +Why GraphAgent (not LoopAgent)? 
+- LoopAgent: unconditional loop, cannot route to different processors +- GraphAgent: conditional routing + state-driven loop control + +Run (requires GOOGLE_API_KEY env var): + python -m contributing.samples.graph_agent_todo_queue.agent +""" + +import asyncio +import json +import os + +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.agents.graph import StateReducer +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.graph.checkpoint_callback import GraphCheckpointCallback +from google.adk.checkpoints import CheckpointService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel + +_MODEL = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +# --------------------------------------------------------------------------- +# Output Schemas +# --------------------------------------------------------------------------- + + +class TodoClassification(BaseModel): + """Structured TODO classification from classifier agent.""" + + todo_id: str # ID of the current TODO item + todo_type: str # "data_processing" | "notification" | "cleanup" + priority: int # 1 (highest) to 5 (lowest) + has_more: bool # Are there more items remaining in the queue? + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + +fetcher = LlmAgent( + name="fetcher", + model=_MODEL, + instruction=( + "You are a TODO queue manager. Given the current queue state, fetch" + " the next unprocessed TODO item. Describe the item briefly." + " If the queue is empty, say 'Queue is empty'." + ), + output_key="current_todo", +) + +classifier = LlmAgent( + name="classifier", + model=_MODEL, + instruction=( + "You are a TODO classifier. Classify the current TODO item." 
+ ' Return {"todo_id": "item-N", "todo_type":' + ' "data_processing|notification|cleanup", "priority": 1-5,' + ' "has_more": true/false}.' + " has_more=true if there are more unprocessed items after this one." + ), + output_schema=TodoClassification, + # output_key auto-defaults to "classifier" (agent name) +) + +processor_data = LlmAgent( + name="processor_data", + model=_MODEL, + instruction=( + "You process data_processing TODO items. Describe what data" + " transformation was performed. Be concise (1 sentence)." + ), + output_key="last_processed", +) + +processor_notification = LlmAgent( + name="processor_notification", + model=_MODEL, + instruction=( + "You process notification TODO items. Describe what notification" + " was sent. Be concise (1 sentence)." + ), + output_key="last_processed", +) + +processor_cleanup = LlmAgent( + name="processor_cleanup", + model=_MODEL, + instruction=( + "You process cleanup TODO items. Describe what was cleaned up." + " Be concise (1 sentence)." + ), + output_key="last_processed", +) + + +# --------------------------------------------------------------------------- +# Routing predicates +# --------------------------------------------------------------------------- + + +def _get_classifier(state: GraphState) -> dict: + """Parse classifier output (may be a JSON string or a dict).""" + val = state.data.get("classifier", {}) + if isinstance(val, str): + try: + return json.loads(val) + except (json.JSONDecodeError, TypeError): + return {} + return val if isinstance(val, dict) else {} + + +def _is_data_task(state: GraphState) -> bool: + return _get_classifier(state).get("todo_type") == "data_processing" + + +def _is_notification_task(state: GraphState) -> bool: + return _get_classifier(state).get("todo_type") == "notification" + + +def _is_cleanup_task(state: GraphState) -> bool: + return _get_classifier(state).get("todo_type") == "cleanup" + + +def _has_more_items(state: GraphState) -> bool: + return 
_get_classifier(state).get("has_more", False) is True + + +# --------------------------------------------------------------------------- +# Graph +# --------------------------------------------------------------------------- + + +def build_todo_queue_graph( + session_service: InMemorySessionService, +) -> GraphAgent: + """Build TODO queue orchestrator with selective checkpointing.""" + checkpoint_service = CheckpointService(session_service=session_service) + + # Only checkpoint after each processor completion (not fetcher/classifier) + # This ensures we can resume from the last COMPLETED item, not mid-classification + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_before=False, + checkpoint_after=True, + checkpoint_nodes={ + "processor_data", + "processor_notification", + "processor_cleanup", + }, + ) + + graph = GraphAgent( + name="todo_queue", + description="Queue-based TODO processing with resume-safe checkpointing", + max_iterations=50, # Process up to 50 TODO items + after_node_callback=checkpoint_callback.after_node, + ) + + # Build graph structure + graph.add_node("fetcher", agent=fetcher, reducer=StateReducer.OVERWRITE) + graph.add_node("classifier", agent=classifier, reducer=StateReducer.OVERWRITE) + graph.add_node( + "processor_data", agent=processor_data, reducer=StateReducer.OVERWRITE + ) + graph.add_node( + "processor_notification", + agent=processor_notification, + reducer=StateReducer.OVERWRITE, + ) + graph.add_node( + "processor_cleanup", + agent=processor_cleanup, + reducer=StateReducer.OVERWRITE, + ) + + graph.set_start("fetcher") + graph.add_edge("fetcher", "classifier") + + # Route to appropriate processor based on TODO type + graph.add_edge("classifier", "processor_data", condition=_is_data_task) + graph.add_edge( + "classifier", "processor_notification", condition=_is_notification_task + ) + graph.add_edge("classifier", "processor_cleanup", condition=_is_cleanup_task) + + # Loop back to fetch next item if 
queue not empty + graph.add_edge("processor_data", "fetcher", condition=_has_more_items) + graph.add_edge("processor_notification", "fetcher", condition=_has_more_items) + graph.add_edge("processor_cleanup", "fetcher", condition=_has_more_items) + + # End at any processor when queue is empty + graph.set_end("processor_data") + graph.set_end("processor_notification") + graph.set_end("processor_cleanup") + + return graph + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + print("=== TODO Queue Orchestrator with Checkpointing ===\n") + + session_service = InMemorySessionService() + graph = build_todo_queue_graph(session_service) + + session = await session_service.create_session( + app_name="todo_queue", user_id="user1" + ) + + # Define the TODO queue as initial state + todo_queue = [ + { + "id": "todo-1", + "type": "data_processing", + "task": "Transform user CSV export to JSON format", + }, + { + "id": "todo-2", + "type": "notification", + "task": "Send weekly summary email to team", + }, + { + "id": "todo-3", + "type": "cleanup", + "task": "Delete temporary files from /tmp/exports", + }, + { + "id": "todo-4", + "type": "data_processing", + "task": "Aggregate daily metrics into monthly report", + }, + { + "id": "todo-5", + "type": "notification", + "task": "Notify ops team of deployment completion", + }, + ] + + queue_json = json.dumps(todo_queue, indent=2) + print(f"Queue contains {len(todo_queue)} items:") + for item in todo_queue: + print(f" [{item['id']}] ({item['type']}) {item['task']}") + print() + + runner = Runner( + app_name="todo_queue", + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + processed_count = 0 + + async for event in runner.run_async( + user_id="user1", + session_id=session.id, + new_message=types.Content( + parts=[types.Part(text=f"Process this TODO 
queue:\n{queue_json}")] + ), + ): + if not event.content or not event.content.parts: + continue + text = event.content.parts[0].text or "" + author = event.author + + if author in ( + "processor_data", + "processor_notification", + "processor_cleanup", + ): + processed_count += 1 + proc_type = author.replace("processor_", "") + print(f"[{proc_type.upper()}] {text[:120]}") + + print(f"\nQueue processing complete. Items processed: {processed_count}") + + # Re-fetch session: InMemorySessionService returns deepcopies, so local + # `session` is stale. The runner's internal copy has the checkpoint updates. + fresh_session = await session_service.get_session( + app_name="todo_queue", user_id="user1", session_id=session.id + ) + if fresh_session is None: + print( + f"WARNING: session_service.create_session returned None, using stale" + f" copy" + ) + fresh_session = session + checkpoints = fresh_session.state.get("_checkpoint_index", {}) + print(f"Checkpoints saved: {len(checkpoints)}") + print( + "Note: If interrupted, resume by restoring latest checkpoint and" + " re-running." 
+ ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/01_basic/__init__.py b/contributing/samples/graph_examples/01_basic/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/01_basic/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/01_basic/agent.py b/contributing/samples/graph_examples/01_basic/agent.py new file mode 100644 index 0000000000..ebda8a0379 --- /dev/null +++ b/contributing/samples/graph_examples/01_basic/agent.py @@ -0,0 +1,138 @@ +"""Example 1: Basic GraphAgent Workflow + +Demonstrates: +- Creating a simple directed graph +- Adding nodes (agents) +- Adding edges (transitions) +- Setting start and end nodes +- Executing the workflow + +Run modes: +- Default: python -m contributing.samples.graph_examples.01_basic.agent +- LLM: python -m contributing.samples.graph_examples.01_basic.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.01_basic.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class SimpleAgent(BaseAgent): + """A simple agent that outputs a message.""" + + def __init__(self, name: str, message: str, **kwargs): + super().__init__(name=name, **kwargs) + self._message = message + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._message)]), + ) + + +# 
=========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (validate, process, complete) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + validate = create_llm_agent( + name="validate", + instruction=( + "You are a validation agent. Respond with 'βœ… Validation passed' to" + " confirm the workflow started successfully." + ), + ) + process = create_llm_agent( + name="process", + instruction=( + "You are a processing agent. Respond with 'βš™οΈ Processing data' to" + " indicate you're processing the workflow." + ), + ) + complete = create_llm_agent( + name="complete", + instruction=( + "You are a completion agent. Respond with 'βœ… Workflow complete' to" + " signal successful workflow completion." + ), + ) + + return validate, process, complete + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + validate = SimpleAgent(name="validate", message="βœ… Validation passed") + process = SimpleAgent(name="process", message="βš™οΈ Processing data") + complete = SimpleAgent(name="complete", message="βœ… Workflow complete") + + return validate, process, complete + + +async def main(): + print("\n" + "=" * 60) + print("Example 1: Basic GraphAgent Workflow") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + validate, process, complete = create_agents() + + # Build graph using convenience API (fluent pattern) + graph = ( + GraphAgent(name="basic_workflow") + .add_node("validate", agent=validate) + .add_node("process", agent=process) + .add_node("complete", agent=complete) + .add_edge("validate", "process") + .add_edge("process", "complete") + .set_start("validate") + .set_end("complete") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="basic_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + 
) + + print("πŸš€ Executing workflow: validate β†’ process β†’ complete\n") + + new_message = types.Content(parts=[types.Part(text="Start workflow")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\nβœ… Example complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/02_conditional_routing/__init__.py b/contributing/samples/graph_examples/02_conditional_routing/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/02_conditional_routing/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/02_conditional_routing/agent.py b/contributing/samples/graph_examples/02_conditional_routing/agent.py new file mode 100644 index 0000000000..1c0fd5394d --- /dev/null +++ b/contributing/samples/graph_examples/02_conditional_routing/agent.py @@ -0,0 +1,201 @@ +"""Example 2: Conditional Routing + +Demonstrates: +- Conditional edges based on state +- Multiple routing paths +- State-based decision making + +Run modes: +- Default: python -m contributing.samples.graph_examples.02_conditional_routing.agent +- LLM: python -m contributing.samples.graph_examples.02_conditional_routing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.02_conditional_routing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from 
contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class ValidatorAgent(BaseAgent): + """Validates input and sets quality score.""" + + def __init__(self, name: str, score: int, **kwargs): + super().__init__(name=name, **kwargs) + self._score = score + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=f"βœ… Validation complete (score: {self._score})" + ) + ] + ), + ) + + +class ProcessAgent(BaseAgent): + """Process based on quality.""" + + def __init__(self, name: str, quality: str, **kwargs): + super().__init__(name=name, **kwargs) + self._quality = quality + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"βš™οΈ {self._quality} quality processing")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(test_score: int): + """Create agents based on USE_LLM mode. + + Args: + test_score: The score to use for validation + + Returns: + tuple: (validate, high_quality, medium_quality, low_quality) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + validate = create_llm_agent( + name="validate", + instruction=( + "You are a validation agent. Respond with 'Validation complete" + f" (score: {test_score})' exactly." + ), + ) + high_quality = create_llm_agent( + name="high_quality", + instruction=( + "You are a high quality processor. Respond with 'HIGH quality" + " processing' exactly." + ), + ) + medium_quality = create_llm_agent( + name="medium_quality", + instruction=( + "You are a medium quality processor. Respond with 'MEDIUM quality" + " processing' exactly." + ), + ) + low_quality = create_llm_agent( + name="low_quality", + instruction=( + "You are a low quality processor. 
Respond with 'LOW quality" + " processing' exactly." + ), + ) + + return validate, high_quality, medium_quality, low_quality + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + validate = ValidatorAgent(name="validate", score=test_score) + high_quality = ProcessAgent(name="high_quality", quality="HIGH") + medium_quality = ProcessAgent(name="medium_quality", quality="MEDIUM") + low_quality = ProcessAgent(name="low_quality", quality="LOW") + + return validate, high_quality, medium_quality, low_quality + + +async def main(): + print("\n" + "=" * 60) + print("Example 2: Conditional Routing") + print("=" * 60 + "\n") + + # Test with different scores + for test_score in [95, 75, 45]: + print(f"🎯 Testing with score: {test_score}") + + # Create agents (deterministic or LLM based on USE_LLM flag) + validate, high_quality, medium_quality, low_quality = create_agents( + test_score + ) + + # Build graph with conditional routing + graph = ( + GraphAgent(name="conditional_workflow") + .add_node( + "validate", + agent=validate, + output_mapper=lambda output, state: GraphState( + data={**state.data, "score": test_score}, + ), + ) + .add_node("high_quality", agent=high_quality) + .add_node("medium_quality", agent=medium_quality) + .add_node("low_quality", agent=low_quality) + # Conditional edges based on score + .add_edge( + "validate", + "high_quality", + condition=lambda s: s.data.get("score", 0) >= 80, + ) + .add_edge( + "validate", + "medium_quality", + condition=lambda s: 50 <= s.data.get("score", 0) < 80, + ) + .add_edge( + "validate", + "low_quality", + condition=lambda s: s.data.get("score", 0) < 50, + ) + .set_start("validate") + .set_end("high_quality") + .set_end("medium_quality") + .set_end("low_quality") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="routing_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + new_message = 
types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", + session_id=f"session_{test_score}", + new_message=new_message, + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print() + + print("βœ… Example complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/03_cyclic_execution/__init__.py b/contributing/samples/graph_examples/03_cyclic_execution/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/03_cyclic_execution/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/03_cyclic_execution/agent.py b/contributing/samples/graph_examples/03_cyclic_execution/agent.py new file mode 100644 index 0000000000..f5e5c1e608 --- /dev/null +++ b/contributing/samples/graph_examples/03_cyclic_execution/agent.py @@ -0,0 +1,306 @@ +"""Example 3: Cyclic Graph Execution + +Demonstrates: +- Cyclic graphs with conditional back-edges (A -> B -> C -> A loop) +- Two routing patterns depending on execution mode: + - Default mode: state_delta writes (ADK-standard for deterministic routing) + - LLM mode: LlmAgent with include_contents='none' + dynamic instructions + for clean context per iteration (Ralph Loop pattern), with + output_mappers writing structured state for edge conditions +- Edge conditions reading from GraphState.data +- max_iterations guard to prevent infinite loops + +Key design choice (LLM mode): + Cyclic nodes (step, check) use include_contents='none' to prevent + session history accumulation across iterations. Without this, the LLM + sees all previous loop outputs and gets biased toward repeating earlier + responses (context rot). 
This is the Ralph Loop pattern applied within + ADK: each iteration gets clean context, state lives in session.state + (synced from GraphState.data), not in conversation history. + + Dynamic instructions (callables) read the current counter from + session.state, providing fresh context per iteration without any + conversation history leakage. + +Run modes: +- Default: python -m contributing.samples.graph_examples.03_cyclic_execution.agent +- LLM: python -m contributing.samples.graph_examples.03_cyclic_execution.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.03_cyclic_execution.agent +""" + +import asyncio +import json +import re + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph.graph_state import GraphState +from google.adk.events.event import Event +from google.adk.events.event import EventActions +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +MAX_CYCLES = 3 + + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class StartAgent(BaseAgent): + """Initializes the counter.""" + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text="Workflow started")]), + actions=EventActions(state_delta={"counter": 0}), + ) + + +class StepAgent(BaseAgent): + """Increments the counter and persists it via state_delta.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + 1 + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Step executed (counter={counter})")] + ), + actions=EventActions(state_delta={"counter": counter}), + ) + + +class 
CheckAgent(BaseAgent): + """Reads counter and writes routing signal via state_delta.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + status = "CONTINUE" if counter < MAX_CYCLES else "DONE" + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part(text=f"Check: counter={counter}, status={status}") + ] + ), + actions=EventActions(state_delta={"status": status}), + ) + + +class EndAgent(BaseAgent): + """Signals workflow completion.""" + + async def _run_async_impl(self, ctx): + counter = ctx.session.state.get("counter", 0) + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part(text=f"Workflow complete after {counter} cycle(s)") + ] + ), + ) + + +# =========================== +# LLM Dynamic Instructions +# =========================== +# Callables that read current state per iteration β€” clean context each time. +# Used with include_contents='none' (Ralph Loop pattern). + + +def step_instruction(ctx): + """Dynamic instruction: reads counter from session.state each iteration.""" + counter = ctx.state.get("counter", 0) + return ( + f"The current counter value is {counter}. " + f"Increment it by 1 and respond with ONLY the new number, " + f"nothing else. Just the number." + ) + + +def check_instruction(ctx): + """Dynamic instruction: reads counter from session.state each iteration.""" + counter = ctx.state.get("counter", 0) + return ( + f"The current counter value is {counter} and the threshold is" + f" {MAX_CYCLES}. If {counter} < {MAX_CYCLES}, respond with exactly:" + f" CONTINUE. If {counter} >= {MAX_CYCLES}, respond with exactly:" + f" DONE. One word only." + ) + + +# =========================== +# LLM Output Mappers +# =========================== +# Parse LLM text output into structured state keys for edge conditions. 
+ + +def start_output_mapper(output: str, state: GraphState) -> GraphState: + """Initialize counter=0 in state (LLM agent can't write state_delta).""" + new_state = GraphState(data=state.data.copy()) + new_state.data["counter"] = 0 + new_state.data["start"] = output.strip() + return new_state + + +def step_output_mapper(output: str, state: GraphState) -> GraphState: + """Parse LLM number output -> counter in state.""" + new_state = GraphState(data=state.data.copy()) + text = str(output).strip() + match = re.search(r"\d+", text) + if match: + counter = int(match.group()) + else: + counter = state.data.get("counter", 0) + 1 + new_state.data["counter"] = counter + new_state.data["step"] = f"Step executed (counter={counter})" + return new_state + + +def check_output_mapper(output: str, state: GraphState) -> GraphState: + """Parse LLM CONTINUE/DONE -> status routing signal in state.""" + new_state = GraphState(data=state.data.copy()) + text = str(output).strip().upper() + status = "DONE" if "DONE" in text else "CONTINUE" + counter = state.data.get("counter", 0) + new_state.data["status"] = status + new_state.data["check"] = f"Check: counter={counter}, status={status}" + return new_state + + +# =========================== +# Graph Construction +# =========================== + + +async def main(): + print("\n" + "=" * 60) + print("Example 3: Cyclic Graph Execution") + print("=" * 60 + "\n") + + llm_mode = use_llm_mode() + + # Build graph + graph = GraphAgent(name="cyclic_workflow", max_iterations=10) + + if llm_mode: + print("πŸ€– Using LLM agents (gemini-2.5-flash)\n") + + start = create_llm_agent( + name="start", + instruction="Respond with exactly: 'Workflow started'", + ) + # Cyclic nodes: include_contents='none' prevents session history + # accumulation. Dynamic instructions read current counter from + # session.state each iteration (Ralph Loop pattern). 
+ step = create_llm_agent( + name="step", + instruction=step_instruction, + include_contents="none", + ) + check = create_llm_agent( + name="check", + instruction=check_instruction, + include_contents="none", + ) + end = create_llm_agent( + name="end_node", + instruction=( + "A cyclic workflow just completed. Summarize: the workflow ran" + f" for {MAX_CYCLES} cycles. Respond in one sentence." + ), + ) + + graph.add_node("start", agent=start, output_mapper=start_output_mapper) + graph.add_node("step", agent=step, output_mapper=step_output_mapper) + graph.add_node("check", agent=check, output_mapper=check_output_mapper) + graph.add_node("end_node", agent=end) + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + start = StartAgent(name="start") + step = StepAgent(name="step") + check = CheckAgent(name="check") + end = EndAgent(name="end_node") + + graph.add_node("start", agent=start) + graph.add_node("step", agent=step) + graph.add_node("check", agent=check) + graph.add_node("end_node", agent=end) + + # Edges are identical in both modes β€” they read from state.data + ( + graph.add_edge("start", "step") + .add_edge("step", "check") + .add_edge( + "check", + "step", + condition=lambda s: s.data.get("status") == "CONTINUE", + ) + .add_edge( + "check", + "end_node", + condition=lambda s: s.data.get("status") == "DONE", + ) + .set_start("start") + .set_end("end_node") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="cyclic_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print( + f"Executing cyclic workflow (max_cycles={MAX_CYCLES}, max_iterations=10)" + ) + print("Graph: start -> step -> check -> (loop back or exit)\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: 
+ if part.text and "#metadata" not in event.author: + print(f" [{event.author}] {part.text}") + + session = await session_service.get_session( + app_name="cyclic_demo", user_id="user1", session_id="session1" + ) + final_counter = session.state.get("counter") + if final_counter is None: + graph_data_raw = session.state.get("graph_data") + if graph_data_raw: + try: + data = ( + json.loads(graph_data_raw) + if isinstance(graph_data_raw, str) + else graph_data_raw + ) + final_counter = data.get("counter", 0) + except (json.JSONDecodeError, TypeError): + final_counter = 0 + + if final_counter is None: + final_counter = 0 + print(f"\n Final counter value: {final_counter}") + print(f" Completed {final_counter} cycle(s) before exiting loop") + + print("\nExample complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/04_checkpointing/__init__.py b/contributing/samples/graph_examples/04_checkpointing/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/04_checkpointing/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/04_checkpointing/agent.py b/contributing/samples/graph_examples/04_checkpointing/agent.py new file mode 100644 index 0000000000..8678324be8 --- /dev/null +++ b/contributing/samples/graph_examples/04_checkpointing/agent.py @@ -0,0 +1,149 @@ +"""Example 4: Checkpointing & Resume + +Demonstrates: +- Automatic checkpointing at each node +- Listing checkpoints +- State persistence +- Checkpoint metadata + +Run modes: +- Default: python -m contributing.samples.graph_examples.04_checkpointing.agent +- LLM: python -m contributing.samples.graph_examples.04_checkpointing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.04_checkpointing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent 
+from google.adk.checkpoints import CheckpointService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class StepAgent(BaseAgent): + """Agent that represents a workflow step.""" + + def __init__(self, name: str, step_num: int, **kwargs): + super().__init__(name=name, **kwargs) + self._step_num = step_num + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"βœ… Completed step {self._step_num}")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. 
+ + Returns: + tuple: (step1, step2, step3) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + step1 = create_llm_agent( + name="step1", + instruction="Respond with 'Completed step 1' exactly.", + ) + step2 = create_llm_agent( + name="step2", + instruction="Respond with 'Completed step 2' exactly.", + ) + step3 = create_llm_agent( + name="step3", + instruction="Respond with 'Completed step 3' exactly.", + ) + + return step1, step2, step3 + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + step1 = StepAgent(name="step1", step_num=1) + step2 = StepAgent(name="step2", step_num=2) + step3 = StepAgent(name="step3", step_num=3) + + return step1, step2, step3 + + +async def main(): + print("\n" + "=" * 60) + print("Example 4: Checkpointing & Resume") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + step1, step2, step3 = create_agents() + + # Setup checkpoint service + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service) + + # Build graph with checkpointing enabled + graph = ( + GraphAgent(name="checkpoint_workflow", checkpointing=True) + .add_node("step1", agent=step1) + .add_node("step2", agent=step2) + .add_node("step3", agent=step3) + .add_edge("step1", "step2") + .add_edge("step2", "step3") + .set_start("step1") + .set_end("step3") + ) + + # Execute with checkpointing + runner = Runner( + app_name="checkpoint_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("πŸš€ Executing workflow with checkpointing enabled...\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Get session and check checkpoint data + session = await 
session_service.get_session( + app_name="checkpoint_demo", user_id="user1", session_id="session1" + ) + + checkpoint_data = session.state.get("graph_checkpoint", {}) + print(f"\nπŸ“Š Checkpoint Information:") + print(f" Last checkpoint at: {checkpoint_data.get('node', 'N/A')}") + print(f" Iteration: {checkpoint_data.get('iteration', 'N/A')}") + + # Show execution path + path = session.state.get("graph_path", []) + print(f" Execution path: {' β†’ '.join(path)}") + + print("\nβœ… Example complete!") + print(" Note: Checkpoints created at each node for state persistence\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/05_interrupts_basic/__init__.py b/contributing/samples/graph_examples/05_interrupts_basic/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/05_interrupts_basic/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/05_interrupts_basic/agent.py b/contributing/samples/graph_examples/05_interrupts_basic/agent.py new file mode 100644 index 0000000000..da98cf167d --- /dev/null +++ b/contributing/samples/graph_examples/05_interrupts_basic/agent.py @@ -0,0 +1,633 @@ +"""Example 5: All Interrupt Actions β€” Concurrent Injection + +Demonstrates all 8 interrupt action types via asyncio.create_task injection: + + 1. continue β€” interrupt logged, workflow proceeds normally + 2. rerun β€” current node re-executes before continuing + 3. pause + resume β€” execution blocks until resume() is called + 4. skip β€” BEFORE-mode: node execution is skipped entirely + 5. go_back β€” rewinds N steps in the execution path + 6. defer β€” interrupt saved as a todo in session.state; continues + 7. update_state β€” injects key/value pairs into GraphState.data + 8. 
change_condition β€” stores overrides in agent_state.conditions + +Timing pattern (scenarios 1–3, 5): + β€’ SlowNode runs for 2 seconds (2 Γ— 1s sub-steps) + β€’ Interrupt injected at t=0.8s via asyncio.create_task + β€’ AFTER-interrupt check fires when node completes at t=2s + β†’ The interrupt was queued mid-execution but processed at the AFTER checkpoint. + +Note: Between sub-steps the GraphAgent only checks `is_active()` (cancellation), +NOT the message queue. The message queue is consumed once at the AFTER checkpoint. + +Run modes: +- Default: python -m contributing.samples.graph_examples.05_interrupts_basic.agent +- LLM: python -m contributing.samples.graph_examples.05_interrupts_basic.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.05_interrupts_basic.agent +""" + +import asyncio +import time + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph.interrupt_service import InterruptService +from google.adk.events.event import Event +from google.adk.events.event import EventActions +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +APP_NAME = "interrupt_demo" +USER_ID = "user1" + + +# --------------------------------------------------------------------------- +# Deterministic Agents (BaseAgent) +# --------------------------------------------------------------------------- + + +class SlowNode(BaseAgent): + """2 sub-steps Γ— 1s each = 2s total. 
Shows interrupt queued mid-execution.""" + + def __init__(self, name: str, label: str = "", **kwargs): + super().__init__(name=name, **kwargs) + self._label = label or name + + async def _run_async_impl(self, ctx): + run_count = ctx.session.state.get(f"{self.name}_runs", 0) + 1 + for step in range(1, 3): + await asyncio.sleep(1.0) + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + f"[{self._label}] sub-step {step}/2" + + (f" (run #{run_count})" if run_count > 1 else "") + ) + ) + ] + ), + actions=EventActions(state_delta={f"{self.name}_runs": run_count}), + ) + + +class QuickNode(BaseAgent): + """Instant node β€” used for surrounding workflow nodes.""" + + def __init__(self, name: str, label: str = "", **kwargs): + super().__init__(name=name, **kwargs) + self._label = label or name + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=f"[{self._label}] done")]), + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_T0: float = 0.0 + + +def _ts() -> str: + return f"t={time.time() - _T0:.1f}s" + + +async def _run( + runner: Runner, + session_id: str, + *, + user_id: str = USER_ID, +) -> None: + msg = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id=user_id, session_id=session_id, new_message=msg + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text and "#metadata" not in event.author: + prefix = " πŸ›‘" if "INTERRUPT" in part.text else " " + print(f" {prefix} [{_ts()}] [{event.author}] {part.text}") + + +async def _make_session( + session_service: InMemorySessionService, + interrupt_service: InterruptService, + session_id: str, +) -> None: + await session_service.create_session( + app_name=APP_NAME, user_id=USER_ID, session_id=session_id + 
) + interrupt_service.register_session(session_id) + + +# --------------------------------------------------------------------------- +# Scenario 1: continue +# --------------------------------------------------------------------------- + + +async def scenario_continue() -> None: + print("\n" + "-" * 55) + print("Scenario 1: continue β€” interrupt logged, execution proceeds") + print("-" * 55) + print(" Interrupt injected at t=0.8s (node runs until t=2s)") + print(" AFTER check fires at t=2s β€” message consumed, continue\n") + + sid = "s1_continue" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_continue", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=SlowNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + async def _inject(): + await asyncio.sleep(0.8) + print(f" >>> [{_ts()}] injecting 'continue' interrupt") + await interrupt_service.send_message( + sid, "Looks good β€” continue", action="continue" + ) + + global _T0 + _T0 = time.time() + task = asyncio.create_task(_inject()) + await _run(runner, sid) + await task + print(" Result: workflow completed normally after interrupt logged\n") + + +# --------------------------------------------------------------------------- +# Scenario 2: rerun +# --------------------------------------------------------------------------- + + +async def scenario_rerun() -> None: + print("-" * 55) + print("Scenario 2: rerun β€” current node re-executes") + print("-" * 55) + print(" Interrupt injected at t=0.8s; AFTER check at 
t=2s β†’ rerun") + print(" Draft runs again (run #2), then finalize\n") + + sid = "s2_rerun" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_rerun", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=SlowNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + async def _inject(): + await asyncio.sleep(0.8) + print(f" >>> [{_ts()}] injecting 'rerun' interrupt") + await interrupt_service.send_message( + sid, + "Revise β€” rerun", + action="rerun", + metadata={"guidance": "Add more detail"}, + ) + + global _T0 + _T0 = time.time() + task = asyncio.create_task(_inject()) + await _run(runner, sid) + await task + + session = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id=sid + ) + runs = session.state.get("draft_runs", 0) + print(f" Result: draft_runs={runs} (rerun worked)\n") + + +# --------------------------------------------------------------------------- +# Scenario 3: pause + resume +# --------------------------------------------------------------------------- + + +async def scenario_pause_resume() -> None: + print("-" * 55) + print("Scenario 3: pause + resume β€” execution blocks for human review") + print("-" * 55) + print(" Node completes at t=2s; 'pause' queued at t=0.8s β†’ blocks") + print(" resume() called at t=3.5s β†’ continues to finalize\n") + + sid = "s3_pause" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = 
( + GraphAgent( + name="g_pause", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=SlowNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + async def _inject(): + await asyncio.sleep(0.8) + print( + f" >>> [{_ts()}] injecting 'pause' interrupt (queued; processed at" + " t=2s AFTER check)" + ) + await interrupt_service.send_message( + sid, "Hold β€” human review", action="pause" + ) + # Simulate human review delay β€” resume after 1.5s of "paused" state + await asyncio.sleep(1.5) + print(f" >>> [{_ts()}] calling resume() β€” human review complete") + await interrupt_service.resume(sid) + + global _T0 + _T0 = time.time() + task = asyncio.create_task(_inject()) + await _run(runner, sid) + await task + print(" Result: workflow paused for human review then resumed\n") + + +# --------------------------------------------------------------------------- +# Scenario 4: skip (BEFORE mode) +# --------------------------------------------------------------------------- + + +async def scenario_skip() -> None: + print("-" * 55) + print("Scenario 4: skip β€” node execution skipped entirely (BEFORE mode)") + print("-" * 55) + print(" 'review' node skipped via BEFORE interrupt (pre-queued)") + print(" Graph: draft -> review -> finalize (review skipped)\n") + + sid = "s4_skip" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_skip", + interrupt_service=interrupt_service, + # BEFORE mode β€” interrupt check happens BEFORE 'review' runs + interrupt_config=InterruptConfig( + 
mode=InterruptMode.BEFORE, nodes=["review"] + ), + ) + .add_node("draft", agent=QuickNode(name="draft", label="draft")) + .add_node("review", agent=QuickNode(name="review", label="review")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "review") + .add_edge("review", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + # Pre-queue the skip interrupt (BEFORE check β†’ message must be in queue before node starts) + await interrupt_service.send_message( + sid, "Skip review β€” auto-approved", action="skip" + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + global _T0 + _T0 = time.time() + await _run(runner, sid) + print(" Result: 'review' node was skipped, finalize ran directly\n") + + +# --------------------------------------------------------------------------- +# Scenario 5: go_back +# --------------------------------------------------------------------------- + + +async def scenario_go_back() -> None: + print("-" * 55) + print("Scenario 5: go_back β€” rewind N steps in execution path") + print("-" * 55) + print(" AFTER 'review': go_back 1 step β†’ reruns 'draft' then 'review'") + print(" Interrupt injected at t=0.8s while review node runs (2s)\n") + + sid = "s5_goback" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_goback", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["review"] + ), + ) + .add_node("draft", agent=QuickNode(name="draft", label="draft")) + .add_node("review", agent=SlowNode(name="review", label="review")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "review") + .add_edge("review", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + runner = 
Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + _injected = False + + async def _inject(): + nonlocal _injected + await asyncio.sleep(0.8) + print(f" >>> [{_ts()}] injecting 'go_back' (steps=1) while review runs") + await interrupt_service.send_message( + sid, + "Needs rework β€” go back to draft", + action="go_back", + metadata={"steps": 1}, + ) + _injected = True + + global _T0 + _T0 = time.time() + task = asyncio.create_task(_inject()) + await _run(runner, sid) + await task + print( + " Result: execution rewound to 'draft', then re-ran" + " draftβ†’reviewβ†’finalize\n" + ) + + +# --------------------------------------------------------------------------- +# Scenario 6: defer +# --------------------------------------------------------------------------- + + +async def scenario_defer() -> None: + print("-" * 55) + print("Scenario 6: defer β€” interrupt saved as todo, execution continues") + print("-" * 55) + print(" 'defer' adds the message to session.state['_interrupt_todos']\n") + + sid = "s6_defer" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_defer", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=QuickNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + # Pre-queue defer interrupt + await interrupt_service.send_message( + sid, + "Non-urgent: add metadata later", + action="defer", + metadata={"message": "Add citation metadata", "priority": "low"}, + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + global _T0 + _T0 = time.time() 
+ await _run(runner, sid) + + session = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id=sid + ) + todos = session.state.get("_interrupt_todos", []) + print(f" Result: todos in session.state = {[t['message'] for t in todos]}\n") + + +# --------------------------------------------------------------------------- +# Scenario 7: update_state +# --------------------------------------------------------------------------- + + +async def scenario_update_state() -> None: + print("-" * 55) + print("Scenario 7: update_state β€” inject key/value into GraphState.data") + print("-" * 55) + print(" Injects priority='high' into GraphState.data via interrupt\n") + + sid = "s7_update" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_update", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=QuickNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + # Pre-queue update_state interrupt: injects priority into GraphState.data + await interrupt_service.send_message( + sid, + "Boost priority", + action="update_state", + metadata={"priority": "high", "escalated": True}, + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + global _T0 + _T0 = time.time() + await _run(runner, sid) + print( + " Result: priority='high' and escalated=True injected into" + " GraphState.data\n" + ) + print( + " (Edge conditions on subsequent nodes can use s.data.get('priority'))\n" + ) + + +# --------------------------------------------------------------------------- +# Scenario 8: change_condition +# 
--------------------------------------------------------------------------- + + +async def scenario_change_condition() -> None: + print("-" * 55) + print("Scenario 8: change_condition β€” store condition overrides in metadata") + print("-" * 55) + print(" Stores named condition overrides in agent_state.conditions\n") + + sid = "s8_cond" + interrupt_service = InterruptService() + session_service = InMemorySessionService() + await _make_session(session_service, interrupt_service, sid) + + graph = ( + GraphAgent( + name="g_cond", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, nodes=["draft"] + ), + ) + .add_node("draft", agent=QuickNode(name="draft", label="draft")) + .add_node("finalize", agent=QuickNode(name="finalize", label="finalize")) + .add_edge("draft", "finalize") + .set_start("draft") + .set_end("finalize") + ) + + # Pre-queue change_condition: stores override flags in agent_state.conditions + await interrupt_service.send_message( + sid, + "Override routing conditions", + action="change_condition", + metadata={"allow_fast_track": True, "require_approval": False}, + ) + + runner = Runner( + app_name=APP_NAME, + agent=graph, + session_service=session_service, + auto_create_session=False, + ) + + global _T0 + _T0 = time.time() + await _run(runner, sid) + print(" Result: condition overrides stored in agent_state.conditions") + print(" Edge conditions can read: data.get('_conditions', {})\n") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main(): + print("\n" + "=" * 55) + print("Example 5: All Interrupt Actions β€” Concurrent Injection") + print("=" * 55) + print() + print("Interrupt timing model:") + print(" β€’ AFTER mode: interrupt checked ONCE after all node events complete") + print(" β€’ BEFORE mode: interrupt checked ONCE before node starts") + print( + " β€’ 
Mid-execution injection: message queued during node, consumed at" + " checkpoint" + ) + print() + + await scenario_continue() + await scenario_rerun() + await scenario_pause_resume() + await scenario_skip() + await scenario_go_back() + await scenario_defer() + await scenario_update_state() + await scenario_change_condition() + + print("=" * 55) + print("All 8 interrupt actions demonstrated.") + print("=" * 55 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/06_interrupts_reasoning/__init__.py b/contributing/samples/graph_examples/06_interrupts_reasoning/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/06_interrupts_reasoning/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/06_interrupts_reasoning/agent.py b/contributing/samples/graph_examples/06_interrupts_reasoning/agent.py new file mode 100644 index 0000000000..f08fbe82d6 --- /dev/null +++ b/contributing/samples/graph_examples/06_interrupts_reasoning/agent.py @@ -0,0 +1,195 @@ +"""Example 6: Interrupt with Condition-Based Action Selection + +Demonstrates: +- Condition-based interrupt routing without an LLM +- Using InterruptService.pause() and send_message() to inject decisions +- Checking queued messages with list_queued_messages() +- Resuming or cancelling based on message content +- Note: InterruptReasoner requires an LLM β€” this example shows deterministic routing + +Run modes: +- Default: python -m contributing.samples.graph_examples.06_interrupts_reasoning.agent +- LLM: python -m contributing.samples.graph_examples.06_interrupts_reasoning.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.06_interrupts_reasoning.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import InterruptConfig +from 
google.adk.agents.graph import InterruptMode +from google.adk.agents.graph.interrupt_service import InterruptService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class DraftAgent(BaseAgent): + """Generates a draft for review.""" + + def __init__(self, name: str, content: str, **kwargs): + super().__init__(name=name, **kwargs) + self._content = content + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Draft created: {self._content}")] + ), + ) + + +class ReviewAgent(BaseAgent): + """Processes the approved draft.""" + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="Draft approved β€” publishing review")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(content: str): + """Create agents based on USE_LLM mode. + + Args: + content: Draft content to generate + + Returns: + tuple: (draft, review) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + draft = create_llm_agent( + name="draft", + instruction=f"Respond with 'Draft created: {content}' exactly.", + ) + review = create_llm_agent( + name="review", + instruction=( + "Respond with 'Draft approved β€” publishing review' exactly." 
+ ), + ) + + return draft, review + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + draft = DraftAgent(name="draft", content=content) + review = ReviewAgent(name="review") + + return draft, review + + +async def run_scenario(scenario_name: str, decision_message: str) -> None: + """Run a single interrupt scenario with a given decision message.""" + print(f"\n Scenario: {scenario_name}") + print(f" Decision message: '{decision_message}'") + + interrupt_service = InterruptService() + + # Create agents (deterministic or LLM based on USE_LLM flag) + draft, review = create_agents("First draft of the document") + + graph = ( + GraphAgent( + name="interrupt_routing_workflow", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, + nodes=["draft"], + ), + ) + .add_node("draft", agent=draft) + .add_node("review", agent=review) + .add_edge("draft", "review") + .set_start("draft") + .set_end("review") + ) + + session_service = InMemorySessionService() + runner = Runner( + app_name="interrupt_routing_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + session_id = f"session_{scenario_name.lower().replace(' ', '_')}" + + # Register session so we can interact with it before the graph does + interrupt_service.register_session(session_id) + + # Queue the decision message (simulates external human input) + await interrupt_service.pause(session_id) + await interrupt_service.send_message( + session_id, decision_message, action="route" + ) + + # Peek at the queued messages to determine action + queued = interrupt_service.list_queued_messages(session_id) + if queued: + msg_text = queued[0].text + if "APPROVE" in msg_text: + print(" Condition met: APPROVE found β€” resuming workflow") + await interrupt_service.resume(session_id) + else: + print(" Condition not met: APPROVE absent β€” cancelling workflow") + await interrupt_service.cancel(session_id) + + events_received = [] + 
new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id=session_id, new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + events_received.append(f"[{event.author}] {part.text}") + + for line in events_received: + print(f" {line}") + + interrupted = not any("review" in e.lower() for e in events_received) + print(f" Review reached: {not interrupted}") + + +async def main(): + print("\n" + "=" * 60) + print("Example 6: Interrupt with Condition-Based Action Selection") + print("=" * 60 + "\n") + print("Note: InterruptReasoner requires an LLM. This example uses") + print(" deterministic condition-based routing instead.\n") + print("Graph: draft -> [INTERRUPT] -> review") + + # Scenario 1: approve message -> resume -> review node runs + await run_scenario("Approve", "APPROVE: content looks good") + + # Scenario 2: reject message -> cancel -> review node skipped + await run_scenario("Reject", "REJECT: needs revision") + + print("\nExample complete!\n") + print(" In production, replace condition check with InterruptReasoner") + print(" (requires LLM) for natural language action selection.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/07_callbacks/__init__.py b/contributing/samples/graph_examples/07_callbacks/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/07_callbacks/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/07_callbacks/agent.py b/contributing/samples/graph_examples/07_callbacks/agent.py new file mode 100644 index 0000000000..a54325379b --- /dev/null +++ b/contributing/samples/graph_examples/07_callbacks/agent.py @@ -0,0 +1,198 @@ +"""Example 7: Node Callbacks (before_node_callback / after_node_callback) + +Demonstrates: +- Registering 
before_node_callback and after_node_callback on a GraphAgent +- Measuring per-node execution time with time.perf_counter() +- Callbacks store start times in ctx.metadata and compute elapsed on exit +- Timing results are printed in async def main() after the run completes + +Run modes: +- Default: python -m contributing.samples.graph_examples.07_callbacks.agent +- LLM: python -m contributing.samples.graph_examples.07_callbacks.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.07_callbacks.agent +""" + +import asyncio +import time + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph.callbacks import NodeCallbackContext +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# Shared dict to accumulate timing results from callbacks +_timings: dict[str, float] = {} + + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class FetchAgent(BaseAgent): + """Simulates a data fetch step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.02) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="Data fetched from source")] + ), + ) + + +class ProcessAgent(BaseAgent): + """Simulates a data processing step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.05) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text="Data processed and transformed")] + ), + ) + + +class SaveAgent(BaseAgent): + """Simulates a data persistence step.""" + + async def _run_async_impl(self, ctx): + await asyncio.sleep(0.01) + yield Event( + author=self.name, + 
content=types.Content(parts=[types.Part(text="Data saved to storage")]), + ) + + +async def before_cb(ctx: NodeCallbackContext) -> None: + """Record start time in shared timings dict keyed by node name.""" + _timings[f"_start_{ctx.node.name}"] = time.perf_counter() + return None + + +async def after_cb(ctx: NodeCallbackContext) -> None: + """Compute elapsed time and store in shared timings dict.""" + start_key = f"_start_{ctx.node.name}" + start = _timings.get(start_key) + if start is not None: + elapsed_ms = (time.perf_counter() - start) * 1000.0 + _timings[ctx.node.name] = elapsed_ms + return None + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (fetch, process, save) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + fetch = create_llm_agent( + name="fetch", + instruction=( + "Respond with 'Data fetched from source' exactly. Respond quickly" + " without delays." + ), + ) + process = create_llm_agent( + name="process", + instruction=( + "Respond with 'Data processed and transformed' exactly. Respond" + " quickly without delays." + ), + ) + save = create_llm_agent( + name="save", + instruction=( + "Respond with 'Data saved to storage' exactly. Respond quickly" + " without delays." 
+ ), + ) + + return fetch, process, save + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + fetch = FetchAgent(name="fetch") + process = ProcessAgent(name="process") + save = SaveAgent(name="save") + + return fetch, process, save + + +async def main(): + print("\n" + "=" * 60) + print("Example 7: Node Callbacks") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + fetch, process, save = create_agents() + + # Build graph with before/after callbacks + graph = ( + GraphAgent( + name="callback_workflow", + before_node_callback=before_cb, + after_node_callback=after_cb, + ) + .add_node("fetch", agent=fetch) + .add_node("process", agent=process) + .add_node("save", agent=save) + .add_edge("fetch", "process") + .add_edge("process", "save") + .set_start("fetch") + .set_end("save") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="callback_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("Executing workflow: fetch -> process -> save") + print("Callbacks will record timing for each node\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" [{event.author}] {part.text}") + + # Print timing results collected by callbacks + print("\n Node execution times (measured by callbacks):") + for node_name in ["fetch", "process", "save"]: + elapsed_ms = _timings.get(node_name) + if elapsed_ms is not None: + print(f" [{node_name}] {elapsed_ms:.1f}ms") + else: + print(f" [{node_name}] timing not recorded") + + print("\nExample complete!\n") + print(" before_node_callback: stores perf_counter start per node") + print(" after_node_callback: computes elapsed ms and stores result") + + +if __name__ == "__main__": + 
asyncio.run(main()) diff --git a/contributing/samples/graph_examples/08_rewind/__init__.py b/contributing/samples/graph_examples/08_rewind/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/08_rewind/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/08_rewind/agent.py b/contributing/samples/graph_examples/08_rewind/agent.py new file mode 100644 index 0000000000..e2f0206bcc --- /dev/null +++ b/contributing/samples/graph_examples/08_rewind/agent.py @@ -0,0 +1,181 @@ +"""Example 8: Rewind Integration + +Demonstrates: +- Invocation tracking per node +- Rewinding to specific node execution +- Re-execution after rewind +- State restoration + +Run modes: +- Default: python -m contributing.samples.graph_examples.08_rewind.agent +- LLM: python -m contributing.samples.graph_examples.08_rewind.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.08_rewind.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import rewind_to_node +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class CounterAgent(BaseAgent): + """Agent that tracks execution count.""" + + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) + self._count = 0 + + async def _run_async_impl(self, ctx): + self._count += 1 + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=f"βœ… {self.name} executed (count: {self._count})" + 
) + ] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (step1, step2, step3) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + step1 = create_llm_agent( + name="step1", + instruction=( + "Respond with 'step1 executed (count: X)' where X is the execution" + " count. Track this in your context." + ), + ) + step2 = create_llm_agent( + name="step2", + instruction=( + "Respond with 'step2 executed (count: X)' where X is the execution" + " count. Track this in your context." + ), + ) + step3 = create_llm_agent( + name="step3", + instruction=( + "Respond with 'step3 executed (count: X)' where X is the execution" + " count. Track this in your context." + ), + ) + + return step1, step2, step3 + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + step1 = CounterAgent(name="step1") + step2 = CounterAgent(name="step2") + step3 = CounterAgent(name="step3") + + return step1, step2, step3 + + +async def main(): + print("\n" + "=" * 60) + print("Example 8: Rewind Integration") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + step1, step2, step3 = create_agents() + + # Build graph + graph = ( + GraphAgent(name="rewind_workflow") + .add_node("step1", agent=step1) + .add_node("step2", agent=step2) + .add_node("step3", agent=step3) + .add_edge("step1", "step2") + .add_edge("step2", "step3") + .set_start("step1") + .set_end("step3") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="rewind_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("πŸš€ First execution...\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and 
event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Check invocations + session = await session_service.get_session( + app_name="rewind_demo", user_id="user1", session_id="session1" + ) + node_invocations = session.state.get("node_invocations", {}) + + print(f"\nπŸ“Š Invocation Tracking:") + for node_name, invocations in node_invocations.items(): + print(f" {node_name}: {len(invocations)} invocation(s)") + + # Rewind to step2 + print(f"\nβͺ Rewinding to 'step2'...") + await rewind_to_node( + graph, + session_service, + app_name="rewind_demo", + user_id="user1", + session_id="session1", + node_name="step2", + invocation_index=-1, # Last invocation + ) + + print(" βœ… Rewind successful! State restored to before step2") + + # Re-execute from rewind point + print("\nπŸš€ Re-execution after rewind...\n") + + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\nβœ… Example complete!") + print(" Note: step1 count stays at 1, step2 & step3 executed again\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/09_parallel_wait_all/__init__.py b/contributing/samples/graph_examples/09_parallel_wait_all/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/09_parallel_wait_all/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/09_parallel_wait_all/agent.py b/contributing/samples/graph_examples/09_parallel_wait_all/agent.py new file mode 100644 index 0000000000..f4c4e49b39 --- /dev/null +++ b/contributing/samples/graph_examples/09_parallel_wait_all/agent.py @@ -0,0 +1,177 @@ +"""Example 9: Parallel Execution - WAIT_ALL + +Demonstrates: +- Concurrent node execution +- WAIT_ALL 
join strategy +- State isolation in parallel branches +- Event streaming from parallel nodes + +Run modes: +- Default: python -m contributing.samples.graph_examples.09_parallel_wait_all.agent +- LLM: python -m contributing.samples.graph_examples.09_parallel_wait_all.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.09_parallel_wait_all.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import JoinStrategy +from google.adk.agents.graph import ParallelNodeGroup +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class FetchAgent(BaseAgent): + """Simulates fetching data from a source.""" + + def __init__(self, name: str, source: str, delay_ms: int, **kwargs): + super().__init__(name=name, **kwargs) + self._source = source + self._delay_ms = delay_ms + + async def _run_async_impl(self, ctx): + # Simulate async I/O + await asyncio.sleep(self._delay_ms / 1000.0) + + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + f"βœ… Fetched data from {self._source}" + f" ({self._delay_ms}ms)" + ) + ) + ] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. 
+ + Returns: + tuple: (fetch_users, fetch_products, fetch_orders) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + fetch_users = create_llm_agent( + name="fetch_users", + instruction=( + "Respond with 'Fetched data from users_db (150ms)' exactly. Respond" + " quickly without delays." + ), + ) + fetch_products = create_llm_agent( + name="fetch_products", + instruction=( + "Respond with 'Fetched data from products_db (100ms)' exactly." + " Respond quickly without delays." + ), + ) + fetch_orders = create_llm_agent( + name="fetch_orders", + instruction=( + "Respond with 'Fetched data from orders_db (200ms)' exactly." + " Respond quickly without delays." + ), + ) + + return fetch_users, fetch_products, fetch_orders + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + fetch_users = FetchAgent( + name="fetch_users", source="users_db", delay_ms=150 + ) + fetch_products = FetchAgent( + name="fetch_products", source="products_db", delay_ms=100 + ) + fetch_orders = FetchAgent( + name="fetch_orders", source="orders_db", delay_ms=200 + ) + + return fetch_users, fetch_products, fetch_orders + + +async def main(): + print("\n" + "=" * 60) + print("Example 9: Parallel Execution - WAIT_ALL") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + fetch_users, fetch_products, fetch_orders = create_agents() + + # Build graph + graph = ( + GraphAgent(name="parallel_workflow") + .add_node("fetch_users", agent=fetch_users) + .add_node("fetch_products", agent=fetch_products) + .add_node("fetch_orders", agent=fetch_orders) + # Add parallel group with WAIT_ALL strategy + .add_parallel_group( + "fetch_all", + ParallelNodeGroup( + nodes=["fetch_users", "fetch_products", "fetch_orders"], + join_strategy=JoinStrategy.WAIT_ALL, # Wait for ALL to complete + ), + ) + # Set start (any node in parallel group triggers all) + .set_start("fetch_users") + .set_end("fetch_users") + ) + + # Execute + 
session_service = InMemorySessionService() + runner = Runner( + app_name="parallel_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("πŸš€ Executing parallel workflow...") + print(" Strategy: WAIT_ALL (wait for all 3 fetches)") + print(" Expected: ~200ms (max latency, not sum)\n") + + import time + + start_time = time.time() + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + elapsed = int((time.time() - start_time) * 1000) + print(f" [{elapsed:3d}ms] {part.text}") + + total_time = int((time.time() - start_time) * 1000) + + print(f"\nβœ… Example complete in {total_time}ms!") + print(f" Sequential would take: 450ms (150+100+200)") + print(f" Parallel took: ~200ms (max of 3)") + print(f" Speedup: ~2.25x\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/10_parallel_wait_any/__init__.py b/contributing/samples/graph_examples/10_parallel_wait_any/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/10_parallel_wait_any/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/10_parallel_wait_any/agent.py b/contributing/samples/graph_examples/10_parallel_wait_any/agent.py new file mode 100644 index 0000000000..68305daec4 --- /dev/null +++ b/contributing/samples/graph_examples/10_parallel_wait_any/agent.py @@ -0,0 +1,178 @@ +"""Example 10: Parallel Execution - WAIT_ANY (Race) + +Demonstrates: +- Racing multiple data sources +- WAIT_ANY join strategy +- First-to-complete wins +- Automatic cancellation of slower nodes + +Run modes: +- Default: python -m contributing.samples.graph_examples.10_parallel_wait_any.agent +- LLM: python -m 
class DataSourceAgent(BaseAgent):
  """Deterministic data source with a configurable response latency.

  Used to race several sources against each other: each instance sleeps
  for its latency, then emits one event naming the source it represents.
  """

  def __init__(self, name: str, source_type: str, latency_ms: int, **kwargs):
    super().__init__(name=name, **kwargs)
    self._source_type = source_type
    self._latency_ms = latency_ms

  async def _run_async_impl(self, ctx):
    # Simulate this source's response time.
    await asyncio.sleep(self._latency_ms / 1000.0)

    report = f"βœ… Data from {self._source_type} ({self._latency_ms}ms)"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=report)]),
    )
Respond quickly" + " without delays." + ), + ) + from_api = create_llm_agent( + name="from_api", + instruction=( + "Respond with 'Data from API (300ms)' exactly. Respond quickly" + " without delays." + ), + ) + + return from_cache, from_database, from_api + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + from_cache = DataSourceAgent( + name="from_cache", source_type="CACHE", latency_ms=50 + ) + from_database = DataSourceAgent( + name="from_database", source_type="DATABASE", latency_ms=150 + ) + from_api = DataSourceAgent( + name="from_api", source_type="API", latency_ms=300 + ) + + return from_cache, from_database, from_api + + +async def main(): + print("\n" + "=" * 60) + print("Example 10: Parallel Execution - WAIT_ANY (Race)") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + from_cache, from_database, from_api = create_agents() + + # Build graph + graph = ( + GraphAgent(name="race_workflow") + .add_node("from_cache", agent=from_cache) + .add_node("from_database", agent=from_database) + .add_node("from_api", agent=from_api) + # Add parallel group with WAIT_ANY strategy (race!) + .add_parallel_group( + "data_race", + ParallelNodeGroup( + nodes=["from_cache", "from_database", "from_api"], + join_strategy=JoinStrategy.WAIT_ANY, # First to finish wins! 
+ ), + ) + .set_start("from_cache") + .set_end("from_cache") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="race_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("🏁 Starting data source race...") + print(" Competitors:") + print(" - Cache: 50ms") + print(" - Database: 150ms") + print(" - API: 300ms") + print(" Strategy: WAIT_ANY (first to complete)\n") + + import time + + start_time = time.time() + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + elapsed = int((time.time() - start_time) * 1000) + print(f" [{elapsed:3d}ms] {part.text}") + + total_time = int((time.time() - start_time) * 1000) + + print(f"\nβœ… Race complete in ~{total_time}ms!") + print(" Winner: Cache (fastest source)") + print(" Slower sources: Cancelled automatically") + print(" Use case: Cache-DB-API fallback strategy\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/11_parallel_wait_n/__init__.py b/contributing/samples/graph_examples/11_parallel_wait_n/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/11_parallel_wait_n/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/11_parallel_wait_n/agent.py b/contributing/samples/graph_examples/11_parallel_wait_n/agent.py new file mode 100644 index 0000000000..32a307aa82 --- /dev/null +++ b/contributing/samples/graph_examples/11_parallel_wait_n/agent.py @@ -0,0 +1,228 @@ +"""Example 11: Parallel Execution - WAIT_N (continue after N of M complete) + +Demonstrates: +- WAIT_N join strategy: proceed after N out of M branches complete +- Timing: faster than WAIT_ALL (doesn't 
class LatencyAgent(BaseAgent):
  """Parallel-branch agent with a configurable completion latency."""

  def __init__(self, name: str, label: str, delay_ms: int, **kwargs):
    super().__init__(name=name, **kwargs)
    self._label = label
    self._delay_ms = delay_ms

  async def _run_async_impl(self, ctx):
    # Sleep first so the WAIT_N join observes branches finishing in
    # latency order (fast, then medium, then slow).
    await asyncio.sleep(self._delay_ms / 1000.0)
    done_msg = f"Branch '{self._label}' completed ({self._delay_ms}ms)"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=done_msg)]),
    )


class SetupAgent(BaseAgent):
  """Entry node: announces that the parallel branches are being launched."""

  async def _run_async_impl(self, ctx):
    setup_msg = "Setup complete β€” launching parallel branches"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=setup_msg)]),
    )
branches.""" + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + "Merge complete: collected results from " + f"{WAIT_N}/3 branches (WAIT_N strategy)" + ) + ) + ] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents(): + """Create agents based on USE_LLM mode. + + Returns: + tuple: (setup, fast, medium, slow, merge) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + setup = create_llm_agent( + name="setup", + instruction=( + "Respond with 'Setup complete β€” launching parallel branches'" + " exactly." + ), + ) + fast = create_llm_agent( + name="fast", + instruction=( + "Respond with \"Branch 'fast' completed (10ms)\" exactly. Respond" + " quickly without delays." + ), + ) + medium = create_llm_agent( + name="medium", + instruction=( + "Respond with \"Branch 'medium' completed (30ms)\" exactly. Respond" + " quickly without delays." + ), + ) + slow = create_llm_agent( + name="slow", + instruction=( + "Respond with \"Branch 'slow' completed (100ms)\" exactly. Respond" + " quickly without delays." + ), + ) + merge = create_llm_agent( + name="merge", + instruction=( + f"Respond with 'Merge complete: collected results from {WAIT_N}/3" + " branches (WAIT_N strategy)' exactly." 
+ ), + ) + + return setup, fast, medium, slow, merge + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + setup = SetupAgent(name="setup") + fast = LatencyAgent(name="fast", label="fast", delay_ms=10) + medium = LatencyAgent(name="medium", label="medium", delay_ms=30) + slow = LatencyAgent(name="slow", label="slow", delay_ms=100) + merge = MergeAgent(name="merge") + + return setup, fast, medium, slow, merge + + +async def main(): + print("\n" + "=" * 60) + print("Example 11: Parallel Execution - WAIT_N") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + setup, fast, medium, slow, merge = create_agents() + + # Build graph + graph = ( + GraphAgent(name="wait_n_workflow") + .add_node("setup", agent=setup) + .add_node("fast", agent=fast) + .add_node("medium", agent=medium) + .add_node("slow", agent=slow) + .add_node("merge", agent=merge) + # Parallel group with WAIT_N strategy + .add_parallel_group( + "branches", + ParallelNodeGroup( + nodes=["fast", "medium", "slow"], + join_strategy=JoinStrategy.WAIT_N, + wait_n=WAIT_N, + ), + ) + # Linear edges: setup -> (parallel group) -> merge + .add_edge("setup", "fast") + .add_edge("fast", "merge") + .set_start("setup") + .set_end("merge") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="wait_n_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print(f"Executing WAIT_N workflow (N={WAIT_N} of 3 branches)") + print(" Branch latencies: fast=10ms, medium=30ms, slow=100ms") + print( + f" Expected: ~30ms (waits for {WAIT_N} fastest, ignores slow=100ms)\n" + ) + + start_time = time.time() + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + elapsed = int((time.time() - 
start_time) * 1000) + print(f" [{elapsed:3d}ms] {part.text}") + + total_time = int((time.time() - start_time) * 1000) + + print(f"\n Total time: ~{total_time}ms") + print(f" WAIT_ALL would take: ~100ms (slowest branch)") + print(f" WAIT_N={WAIT_N} took: ~30ms (2nd fastest branch)") + print(f" Speedup vs WAIT_ALL: ~{100 // max(total_time, 1)}x") + + print("\nExample complete!\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/12_parallel_checkpointing/__init__.py b/contributing/samples/graph_examples/12_parallel_checkpointing/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/12_parallel_checkpointing/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/12_parallel_checkpointing/agent.py b/contributing/samples/graph_examples/12_parallel_checkpointing/agent.py new file mode 100644 index 0000000000..effc31657b --- /dev/null +++ b/contributing/samples/graph_examples/12_parallel_checkpointing/agent.py @@ -0,0 +1,222 @@ +"""Example 12: Parallel Execution with Checkpointing + +Demonstrates: +- Enabling checkpointing alongside parallel branch execution +- Two parallel workers each writing independent results to session state +- Using CheckpointService to inspect checkpoint index after execution +- Checkpoint data stored under "_checkpoint_index" in session state + +Run modes: +- Default: python -m contributing.samples.graph_examples.12_parallel_checkpointing.agent +- LLM: python -m contributing.samples.graph_examples.12_parallel_checkpointing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.12_parallel_checkpointing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import JoinStrategy +from google.adk.agents.graph import ParallelNodeGroup +from google.adk.checkpoints 
class WorkerAgent(BaseAgent):
  """Parallel worker that records its result under a session-state key."""

  def __init__(self, name: str, result_key: str, result_value: str, **kwargs):
    super().__init__(name=name, **kwargs)
    self._result_key = result_key
    self._result_value = result_value

  async def _run_async_impl(self, ctx):
    # Write the result before emitting the event so downstream nodes
    # (and any checkpoint taken afterwards) observe it in session state.
    ctx.session.state[self._result_key] = self._result_value
    summary = (
        f"Worker '{self.name}': {self._result_key}={self._result_value!r}"
    )
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=summary)]),
    )


class CollectAgent(BaseAgent):
  """Join node: reads both workers' results back out of session state."""

  async def _run_async_impl(self, ctx):
    # Missing keys fall back to "N/A" so the report never raises.
    state = ctx.session.state
    result_a = state.get("result_a", "N/A")
    result_b = state.get("result_b", "N/A")
    summary = f"Collected: result_a={result_a!r}, result_b={result_b!r}"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=summary)]),
    )
+ ), + ) + worker_b = create_llm_agent( + name="worker_b", + instruction=( + "Respond with \"Worker 'worker_b': result_b='processed_by_b'\"" + " exactly." + ), + ) + collect = create_llm_agent( + name="collect", + instruction=( + "Respond with \"Collected: result_a='processed_by_a'," + " result_b='processed_by_b'\" exactly." + ), + ) + + return worker_a, worker_b, collect + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + worker_a = WorkerAgent( + name="worker_a", result_key="result_a", result_value="processed_by_a" + ) + worker_b = WorkerAgent( + name="worker_b", result_key="result_b", result_value="processed_by_b" + ) + collect = CollectAgent(name="collect") + + return worker_a, worker_b, collect + + +async def main(): + print("\n" + "=" * 60) + print("Example 12: Parallel Execution with Checkpointing") + print("=" * 60 + "\n") + + # Setup services + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service=session_service) + + # Create agents (deterministic or LLM based on USE_LLM flag) + worker_a, worker_b, collect = create_agents() + + # Build graph with checkpointing enabled + graph = ( + GraphAgent(name="parallel_checkpoint_workflow", checkpointing=True) + .add_node("worker_a", agent=worker_a) + .add_node("worker_b", agent=worker_b) + .add_node("collect", agent=collect) + # Parallel group: both workers run concurrently with WAIT_ALL + .add_parallel_group( + "workers", + ParallelNodeGroup( + nodes=["worker_a", "worker_b"], + join_strategy=JoinStrategy.WAIT_ALL, + ), + ) + # Edges: worker_a (entry for parallel group) -> collect + .add_edge("worker_a", "collect") + .set_start("worker_a") + .set_end("collect") + ) + + # Execute + runner = Runner( + app_name="parallel_checkpoint_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("Executing parallel workflow with checkpointing=True") + print(" Graph: [worker_a || worker_b] -> collect\n") + + new_message = 
types.Content(parts=[types.Part(text="Start")]) + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" [{event.author}] {part.text}") + + # Inspect checkpoint data from session state + session = await session_service.get_session( + app_name="parallel_checkpoint_demo", + user_id="user1", + session_id="session1", + ) + + checkpoint_index = session.state.get("_checkpoint_index", {}) + print(f"\n Checkpoint index entries: {len(checkpoint_index)}") + + if checkpoint_index: + for cp_id, cp_info in checkpoint_index.items(): + agent = cp_info.get("agent", "unknown") + print(f" Checkpoint '{cp_id[:24]}...': agent={agent}") + else: + print(" (No checkpoints recorded β€” checkpointing requires") + print(" CheckpointCallback for automatic per-node checkpoints)") + + # Show worker results persisted in session state + print(f"\n Session state after execution:") + print(f" result_a = {session.state.get('result_a', 'N/A')!r}") + print(f" result_b = {session.state.get('result_b', 'N/A')!r}") + + # Use checkpoint_service to manually create a checkpoint post-execution + cp_metadata = await checkpoint_service.create_checkpoint( + session=session, + description="Post parallel execution snapshot", + agent_name="parallel_checkpoint_workflow", + ) + print(f"\n Manual checkpoint created: {cp_metadata.checkpoint_id}") + print(f" Checkpoint state keys: {list(cp_metadata.state_snapshot.keys())}") + + print("\nExample complete!\n") + print(" Use CheckpointCallback for automatic per-node checkpointing") + print(" Use checkpoint_service.restore_checkpoint() for resume") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/13_parallel_interrupts/__init__.py b/contributing/samples/graph_examples/13_parallel_interrupts/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- 
/dev/null +++ b/contributing/samples/graph_examples/13_parallel_interrupts/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/13_parallel_interrupts/agent.py b/contributing/samples/graph_examples/13_parallel_interrupts/agent.py new file mode 100644 index 0000000000..d34bfe25dd --- /dev/null +++ b/contributing/samples/graph_examples/13_parallel_interrupts/agent.py @@ -0,0 +1,210 @@ +"""Example 13: Interrupt Inside a Parallel Group Branch + +Demonstrates: +- Configuring an interrupt on one branch of a parallel group +- Using interrupt_service.is_active() to detect a pending interrupt +- Calling interrupt_service.resume() to continue after the interrupt fires +- branch_a runs normally; branch_b triggers the AFTER interrupt + +Run modes: +- Default: python -m contributing.samples.graph_examples.13_parallel_interrupts.agent +- LLM: python -m contributing.samples.graph_examples.13_parallel_interrupts.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.13_parallel_interrupts.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import InterruptConfig +from google.adk.agents.graph import InterruptMode +from google.adk.agents.graph import JoinStrategy +from google.adk.agents.graph import ParallelNodeGroup +from google.adk.agents.graph.interrupt_service import InterruptService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +SESSION_ID = "session1" + + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class BranchAgent(BaseAgent): + """A simple agent representing 
class JoinAgent(BaseAgent):
  """Join node that runs once both parallel branches have completed."""

  async def _run_async_impl(self, ctx):
    joined_msg = "Both branches joined β€” workflow resumed"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=joined_msg)]),
    )
+ ), + ) + + return branch_a, branch_b, join + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + branch_a = BranchAgent(name="branch_a", label="A (normal)") + branch_b = BranchAgent(name="branch_b", label="B (interrupted)") + join = JoinAgent(name="join") + + return branch_a, branch_b, join + + +async def main(): + print("\n" + "=" * 60) + print("Example 13: Interrupt Inside a Parallel Group Branch") + print("=" * 60 + "\n") + + interrupt_service = InterruptService() + + # Create agents (deterministic or LLM based on USE_LLM flag) + branch_a, branch_b, join = create_agents() + + # Build graph with interrupt configured AFTER branch_b + graph = ( + GraphAgent( + name="parallel_interrupt_workflow", + interrupt_service=interrupt_service, + interrupt_config=InterruptConfig( + mode=InterruptMode.AFTER, + nodes=["branch_b"], + ), + ) + .add_node("branch_a", agent=branch_a) + .add_node("branch_b", agent=branch_b) + .add_node("join", agent=join) + # Parallel group: both branches run concurrently + .add_parallel_group( + "parallel_branches", + ParallelNodeGroup( + nodes=["branch_a", "branch_b"], + join_strategy=JoinStrategy.WAIT_ALL, + ), + ) + # Edges: branch_a (parallel group entry) -> join + .add_edge("branch_a", "join") + .set_start("branch_a") + .set_end("join") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="parallel_interrupt_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("Graph: [branch_a || branch_b] -> join") + print("Interrupt configured AFTER branch_b\n") + + # Pre-register session so we can interact with interrupt service + interrupt_service.register_session(SESSION_ID) + + # Collect events while running; concurrently check for interrupt and resume + events_received = [] + + async def resume_after_interrupt() -> None: + """Poll until interrupt is active, then resume.""" + for _ in range(50): + await asyncio.sleep(0.05) + if 
interrupt_service.is_paused(SESSION_ID): + print( + " Interrupt detected on 'branch_b' " + "(interrupt_service.is_paused() == True)" + ) + print(" Calling interrupt_service.resume() to continue...") + await interrupt_service.resume(SESSION_ID) + return + print(" (interrupt not triggered within poll window)") + + new_message = types.Content(parts=[types.Part(text="Start")]) + + resume_task = asyncio.create_task(resume_after_interrupt()) + + async for event in runner.run_async( + user_id="user1", session_id=SESSION_ID, new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + events_received.append((event.author, part.text)) + + await resume_task + + print("\n Events received during execution:") + for author, text in events_received: + print(f" [{author}] {text}") + + is_still_active = interrupt_service.is_active(SESSION_ID) + print(f"\n interrupt_service.is_active(session_id): {is_still_active}") + + join_reached = any("join" in author for author, _ in events_received) + print(f" Join node reached: {join_reached}") + + print("\nExample complete!\n") + print(" branch_b triggered an AFTER interrupt during parallel execution") + print(" interrupt_service.resume() allowed the workflow to continue") + print(" Use interrupt_service.cancel() instead to abort the workflow") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/14_parallel_rewind/__init__.py b/contributing/samples/graph_examples/14_parallel_rewind/__init__.py new file mode 100644 index 0000000000..b3a053b0ec --- /dev/null +++ b/contributing/samples/graph_examples/14_parallel_rewind/__init__.py @@ -0,0 +1 @@ +"""GraphAgent example.""" diff --git a/contributing/samples/graph_examples/14_parallel_rewind/agent.py b/contributing/samples/graph_examples/14_parallel_rewind/agent.py new file mode 100644 index 0000000000..78eb352e82 --- /dev/null +++ 
class TaskAgent(BaseAgent):
  """Task node that counts how many times it has been executed.

  The counter is an instance attribute (not session state), so rewinding
  the session does not reset it β€” re-execution after a rewind shows up as
  a higher execution number in the emitted text.
  """

  def __init__(self, name: str, task_name: str, **kwargs):
    super().__init__(name=name, **kwargs)
    self._task_name = task_name
    self._count = 0

  async def _run_async_impl(self, ctx):
    # Bump the per-instance counter on every invocation.
    self._count += 1
    status = f"βœ… {self._task_name} completed (execution #{self._count})"
    yield Event(
        author=self.name,
        content=types.Content(parts=[types.Part(text=status)]),
    )
+ + Returns: + tuple: (task1, task2, merge) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + task1 = create_llm_agent( + name="task1", + instruction=( + "Respond with 'Data fetch completed (execution #X)' where X is the" + " execution count. Track this in your context." + ), + ) + task2 = create_llm_agent( + name="task2", + instruction=( + "Respond with 'Data transform completed (execution #X)' where X is" + " the execution count. Track this in your context." + ), + ) + merge = create_llm_agent( + name="merge", + instruction=( + "Respond with 'Merge results completed (execution #X)' where X is" + " the execution count. Track this in your context." + ), + ) + + return task1, task2, merge + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + task1 = TaskAgent(name="task1", task_name="Data fetch") + task2 = TaskAgent(name="task2", task_name="Data transform") + merge = TaskAgent(name="merge", task_name="Merge results") + + return task1, task2, merge + + +async def main(): + print("\n" + "=" * 60) + print("Example 14: Parallel Execution + Rewind") + print("=" * 60 + "\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + task1, task2, merge = create_agents() + + # Build graph + graph = ( + GraphAgent(name="parallel_rewind_workflow") + .add_node("task1", agent=task1) + .add_node("task2", agent=task2) + .add_node("merge", agent=merge) + # Add parallel group + .add_parallel_group( + "parallel_tasks", ParallelNodeGroup(nodes=["task1", "task2"]) + ) + .add_edge("task1", "merge") + .add_edge("task2", "merge") + .set_start("task1") + .set_end("merge") + ) + + # Execute + session_service = InMemorySessionService() + runner = Runner( + app_name="parallel_rewind_demo", + agent=graph, + session_service=session_service, + auto_create_session=True, + ) + + print("πŸš€ First execution (parallel tasks)...\n") + + new_message = types.Content(parts=[types.Part(text="Start")]) + async for event in 
runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + # Check invocations + session = await session_service.get_session( + app_name="parallel_rewind_demo", user_id="user1", session_id="session1" + ) + node_invocations = session.state.get("node_invocations", {}) + + print(f"\nπŸ“Š Invocation Tracking:") + for node_name, invocations in node_invocations.items(): + print(f" {node_name}: {len(invocations)} invocation(s)") + + # Rewind to task1 (part of parallel group) + print(f"\nβͺ Rewinding to 'task1' (parallel group node)...") + await rewind_to_node( + graph, + session_service, + app_name="parallel_rewind_demo", + user_id="user1", + session_id="session1", + node_name="task1", + invocation_index=-1, + ) + + print(" βœ… Rewind successful!") + + # Re-execute from rewind point + print("\nπŸš€ Re-execution after rewind (parallel group re-runs)...\n") + + async for event in runner.run_async( + user_id="user1", session_id="session1", new_message=new_message + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f" {part.text}") + + print("\nβœ… Example complete!") + print("\n Key Points:") + print(" - Rewind works with parallel nodes") + print(" - Entire parallel group re-executes") + print(" - Invocations tracked per node") + print(" - Execution counts show: task1=#1, task2=#2, merge=#2\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/15_enhanced_routing/__init__.py b/contributing/samples/graph_examples/15_enhanced_routing/__init__.py new file mode 100644 index 0000000000..491d6cc8cf --- /dev/null +++ b/contributing/samples/graph_examples/15_enhanced_routing/__init__.py @@ -0,0 +1 @@ +"""Example 15: Enhanced Routing (Priority, Weight, Fallback)."""
a/contributing/samples/graph_examples/15_enhanced_routing/agent.py b/contributing/samples/graph_examples/15_enhanced_routing/agent.py new file mode 100644 index 0000000000..6dc04718ca --- /dev/null +++ b/contributing/samples/graph_examples/15_enhanced_routing/agent.py @@ -0,0 +1,411 @@ +"""Example 15: Enhanced Routing (Priority, Weight, Fallback) + +Demonstrates: +- Priority-based routing (higher priority evaluated first) +- Weighted random selection (probabilistic routing) +- Fallback edges (priority=0 always matches) + +Run modes: +- Default: python -m contributing.samples.graph_examples.15_enhanced_routing.agent +- LLM: python -m contributing.samples.graph_examples.15_enhanced_routing.agent --use-llm + or: USE_LLM=1 python -m contributing.samples.graph_examples.15_enhanced_routing.agent +""" + +import asyncio + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.graph import EdgeCondition +from google.adk.agents.graph import GraphAgent +from google.adk.agents.graph import GraphState +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from contributing.samples.graph_examples.example_utils import create_llm_agent +from contributing.samples.graph_examples.example_utils import use_llm_mode + +# =========================== +# Deterministic Agents (BaseAgent) +# =========================== + + +class SimpleAgent(BaseAgent): + """Agent that outputs a message.""" + + def __init__(self, name: str, message: str, **kwargs): + super().__init__(name=name, **kwargs) + self._message = message + + async def _run_async_impl(self, ctx): + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=self._message)]), + ) + + +class ScoreAgent(BaseAgent): + """Agent that sets a risk score.""" + + def __init__(self, name: str, score: float, **kwargs): + super().__init__(name=name, **kwargs) + self._score = score + + async 
def _run_async_impl(self, ctx): + ctx.session.state["risk_score"] = self._score + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Risk score: {self._score}")] + ), + ) + + +# =========================== +# Agent Factory +# =========================== + + +def create_agents_priority(score: float): + """Create agents for priority routing example. + + Returns: + tuple: (analyze, critical, warning, normal) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + analyze = create_llm_agent( + name="analyze", + instruction=f"Respond with 'Risk score: {score}' exactly.", + ) + critical = create_llm_agent( + name="critical", + instruction=( + "Respond with 'CRITICAL: Immediate action required' exactly." + ), + ) + warning = create_llm_agent( + name="warning", + instruction="Respond with 'WARNING: Review recommended' exactly.", + ) + normal = create_llm_agent( + name="normal", + instruction="Respond with 'NORMAL: No action needed' exactly.", + ) + + return analyze, critical, warning, normal + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + analyze = ScoreAgent(name="analyze", score=score) + critical = SimpleAgent( + name="critical", message="🚨 CRITICAL: Immediate action required" + ) + warning = SimpleAgent( + name="warning", message="⚠️ WARNING: Review recommended" + ) + normal = SimpleAgent(name="normal", message="βœ… NORMAL: No action needed") + + return analyze, critical, warning, normal + + +def create_agents_weighted(): + """Create agents for weighted routing example. + + Returns: + tuple: (start, server_a, server_b, server_c) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + start = create_llm_agent( + name="start", + instruction="Respond with 'Starting load balancer...' 
exactly.", + ) + server_a = create_llm_agent( + name="server_a", + instruction="Respond with ' β†’ Routed to Server A' exactly.", + ) + server_b = create_llm_agent( + name="server_b", + instruction="Respond with ' β†’ Routed to Server B' exactly.", + ) + server_c = create_llm_agent( + name="server_c", + instruction="Respond with ' β†’ Routed to Server C' exactly.", + ) + + return start, server_a, server_b, server_c + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + start = SimpleAgent(name="start", message="Starting load balancer...") + server_a = SimpleAgent(name="server_a", message=" β†’ Routed to Server A") + server_b = SimpleAgent(name="server_b", message=" β†’ Routed to Server B") + server_c = SimpleAgent(name="server_c", message=" β†’ Routed to Server C") + + return start, server_a, server_b, server_c + + +def create_agents_fallback(score: float): + """Create agents for fallback routing example. + + Returns: + tuple: (validate, premium, standard, fallback) agents + """ + if use_llm_mode(): + print("πŸ€– Using LLM-powered agents (gemini-2.5-flash)\n") + + validate = create_llm_agent( + name="validate", + instruction=f"Respond with 'Risk score: {score}' exactly.", + ) + premium = create_llm_agent( + name="premium", + instruction="Respond with 'Premium path (VIP users)' exactly.", + ) + standard = create_llm_agent( + name="standard", + instruction="Respond with 'Standard path (regular users)' exactly.", + ) + fallback = create_llm_agent( + name="fallback", + instruction="Respond with 'Fallback path (default handler)' exactly.", + ) + + return validate, premium, standard, fallback + else: + print("🎭 Using deterministic agents (BaseAgent)\n") + + validate = ScoreAgent(name="validate", score=score) + premium = SimpleAgent(name="premium", message="🌟 Premium path (VIP users)") + standard = SimpleAgent( + name="standard", message="πŸ“¦ Standard path (regular users)" + ) + fallback = SimpleAgent( + name="fallback", message="πŸ”’ Fallback path (default 
handler)" + ) + + return validate, premium, standard, fallback + + +async def main(): + print("\n" + "=" * 60) + print("Example 15: Enhanced Routing") + print("=" * 60 + "\n") + + # ===== Example 1: Priority-based Routing ===== + print("πŸ“Š Example 1: Priority-based Routing\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + analyze, critical, warning, normal = create_agents_priority(0.85) + + graph1 = ( + GraphAgent(name="priority_routing") + .add_node("analyze", agent=analyze) + .add_node("critical", agent=critical) + .add_node("warning", agent=warning) + .add_node("normal", agent=normal) + ) + + # Set output mapper to persist risk_score in state + def store_score(output, state): + new_state = GraphState(data=state.data.copy()) + new_state.data["risk_score"] = 0.85 # Score from analyze agent + return new_state + + graph1.nodes["analyze"].output_mapper = store_score + + # Priority-based routing: higher priority evaluated first + graph1.nodes["analyze"].edges = [ + EdgeCondition( + target_node="critical", + condition=lambda s: s.data.get("risk_score", 0) > 0.9, + priority=10, # Highest priority + ), + EdgeCondition( + target_node="warning", + condition=lambda s: s.data.get("risk_score", 0) > 0.7, + priority=5, # Medium priority - THIS WILL MATCH + ), + EdgeCondition( + target_node="normal", + priority=0, # Fallback (priority=0 always matches if no other matched) + ), + ] + + graph1.set_start("analyze") + graph1.set_end("critical") + graph1.set_end("warning") + graph1.set_end("normal") + + session_service = InMemorySessionService() + runner = Runner( + app_name="routing_demo", + agent=graph1, + session_service=session_service, + auto_create_session=True, + ) + + async for event in runner.run_async( + user_id="user1", + session_id="session1", + new_message=types.Content(parts=[types.Part(text="Analyze")]), + ): + if event.content and event.content.parts and event.content.parts[0].text: + print(f" {event.content.parts[0].text}") + + print("\n πŸ’‘ 
Score was 0.85 β†’ matched 'warning' (priority=5)") + print(" πŸ’‘ 'critical' didn't match (0.85 < 0.9)") + print(" πŸ’‘ 'normal' fallback not needed (higher priority matched)\n") + + # ===== Example 2: Weighted Random Routing ===== + print("🎲 Example 2: Weighted Random Routing\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + start, server_a, server_b, server_c = create_agents_weighted() + + graph2 = ( + GraphAgent(name="weighted_routing") + .add_node("start", agent=start) + .add_node("server_a", agent=server_a) + .add_node("server_b", agent=server_b) + .add_node("server_c", agent=server_c) + ) + + # Weighted routing: all at same priority, different weights + graph2.nodes["start"].edges = [ + EdgeCondition( + target_node="server_a", + condition=lambda s: True, # All match + priority=1, # Same priority + weight=0.5, # 50% probability + ), + EdgeCondition( + target_node="server_b", + condition=lambda s: True, + priority=1, # Same priority + weight=0.3, # 30% probability + ), + EdgeCondition( + target_node="server_c", + condition=lambda s: True, + priority=1, # Same priority + weight=0.2, # 20% probability + ), + ] + + graph2.set_start("start") + graph2.set_end("server_a") + graph2.set_end("server_b") + graph2.set_end("server_c") + + runner2 = Runner( + app_name="weighted_demo", + agent=graph2, + session_service=session_service, + auto_create_session=True, + ) + + # Run multiple times to show distribution + counts = {"server_a": 0, "server_b": 0, "server_c": 0} + trials = 20 + + print(f" Running {trials} trials with weights (A:50%, B:30%, C:20%):\n") + + for i in range(trials): + async for event in runner2.run_async( + user_id="user1", + session_id=f"session_weighted_{i}", + new_message=types.Content(parts=[types.Part(text="Route")]), + ): + if event.content and event.content.parts and event.author in counts: + text = event.content.parts[0].text + counts[event.author] += 1 + print(f" Trial {i+1:2d}: {text}") + + print(f"\n πŸ“Š Distribution after 
{trials} trials:") + print( + f" Server A: {counts['server_a']:2d}/{trials}" + f" ({counts['server_a']/trials*100:.0f}%)" + ) + print( + f" Server B: {counts['server_b']:2d}/{trials}" + f" ({counts['server_b']/trials*100:.0f}%)" + ) + print( + f" Server C: {counts['server_c']:2d}/{trials}" + f" ({counts['server_c']/trials*100:.0f}%)\n" + ) + + # ===== Example 3: Fallback Edge ===== + print("πŸ›‘οΈ Example 3: Fallback Edge (priority=0)\n") + + # Create agents (deterministic or LLM based on USE_LLM flag) + validate, premium, standard, fallback = create_agents_fallback(0.5) + + graph3 = ( + GraphAgent(name="fallback_routing") + .add_node("validate", agent=validate) + .add_node("premium", agent=premium) + .add_node("standard", agent=standard) + .add_node("fallback", agent=fallback) + ) + + def store_score_fallback(output, state): + new_state = GraphState(data=state.data.copy()) + new_state.data["risk_score"] = 0.5 + # Don't set is_vip or is_standard - will fall through to fallback + return new_state + + graph3.nodes["validate"].output_mapper = store_score_fallback + + graph3.nodes["validate"].edges = [ + EdgeCondition( + target_node="premium", + condition=lambda s: s.data.get("is_vip", False), + priority=10, # High priority - won't match + ), + EdgeCondition( + target_node="standard", + condition=lambda s: s.data.get("is_standard", False), + priority=5, # Medium priority - won't match + ), + EdgeCondition( + target_node="fallback", + priority=0, # FALLBACK - always matches if reached + ), + ] + + graph3.set_start("validate") + graph3.set_end("premium") + graph3.set_end("standard") + graph3.set_end("fallback") + + runner3 = Runner( + app_name="fallback_demo", + agent=graph3, + session_service=session_service, + auto_create_session=True, + ) + + async for event in runner3.run_async( + user_id="user1", + session_id="session_fallback", + new_message=types.Content(parts=[types.Part(text="Validate")]), + ): + if event.content and event.content.parts and 
event.content.parts[0].text: + print(f" {event.content.parts[0].text}") + + print("\n πŸ’‘ No is_vip or is_standard flag set") + print(" πŸ’‘ All higher priority edges failed to match") + print(" πŸ’‘ Fallback (priority=0) caught it!\n") + + print("=" * 60) + print("βœ… Enhanced Routing Complete!") + print("=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/graph_examples/README.md b/contributing/samples/graph_examples/README.md new file mode 100644 index 0000000000..7eb7e7ee89 --- /dev/null +++ b/contributing/samples/graph_examples/README.md @@ -0,0 +1,594 @@ +# GraphAgent Examples - All Features + +Comprehensive collection of small, focused examples demonstrating every GraphAgent feature. + +--- + +## API Overview + +GraphAgent supports both **explicit** and **convenience** APIs for building workflows: + +### Convenience API (Recommended) + +```python +# Fluent chaining pattern +graph = ( + GraphAgent(name="workflow") + .add_node("validate", agent=validator) + .add_node("process", agent=processor) + .add_edge("validate", "process") # Positional args + .add_edge("process", "output", priority=10) # With priority + .set_start("validate") + .set_end("output") +) + +# Or step-by-step +graph = GraphAgent(name="workflow") +graph.add_node("step1", agent=agent1) +graph.add_node("step2", function=custom_function) +graph.add_edge("step1", "step2") +``` + +### Explicit API (Also Supported) + +```python +# Using GraphNode and EdgeCondition explicitly +graph.add_node(GraphNode(name="step1", agent=agent1)) +graph.add_edge("source", EdgeCondition( + target_node="target", + priority=10, + condition=lambda s: s.data.get("valid") +)) +``` + +### Key API Features + +- **add_node()**: Convenience syntax `add_node("name", agent=...)` or explicit `add_node(GraphNode(...))` +- **add_edge()**: Positional `add_edge("source", "target")` or keyword `add_edge(source_node="a", target_node="b")` +- **EdgeCondition**: Support for 
`add_edge("src", EdgeCondition(target_node="tgt", condition=...))` +- **checkpoint_service**: Optional parameter, not required +- **Fluent chaining**: All builder methods return self for method chaining + +--- + +## Quick Start + +Run any example with: +```bash +cd /path/to/adk-python +source venv/bin/activate +python -m contributing.samples.graph_examples.<example_name>.agent +``` + +### Deterministic vs LLM Modes + +Examples support two execution modes: + +**Deterministic Mode (default)** - Uses BaseAgent subclasses with deterministic outputs. No API keys required. +```bash +python -m contributing.samples.graph_examples.01_basic.agent +``` + +**LLM Mode (optional)** - Uses real Gemini LLM endpoints. Requires API credentials. +```bash +# Via command-line flag +python -m contributing.samples.graph_examples.01_basic.agent --use-llm + +# Via environment variable +USE_LLM=1 python -m contributing.samples.graph_examples.01_basic.agent +``` + +**Note:** LLM mode is only available for simple examples (01_basic, 02_conditional_routing, etc.). Examples that require precise state management (05_interrupts, parallel execution) use deterministic agents to demonstrate graph mechanics reliably. + +--- + +## Examples Overview + +### 🟒 Core Features + +#### **01_basic** - Basic GraphAgent Workflow +Simple directed graph with nodes and edges. +```bash +python -m contributing.samples.graph_examples.01_basic.agent +``` +**Demonstrates:** +- Creating a graph with fluent API +- Adding nodes with convenience syntax: `add_node("name", agent=...)` +- Adding edges with positional syntax: `add_edge("source", "target")` +- Executing workflow + +--- + +#### **02_conditional_routing** - Conditional Routing +State-based routing decisions. 
+```bash +python -m contributing.samples.graph_examples.02_conditional_routing.agent +``` +**Demonstrates:** +- Conditional edges with `condition` parameter +- EdgeCondition support: `add_edge("src", EdgeCondition(target_node="tgt", condition=...))` +- State-based decisions +- Multiple routing paths +- Priority-based routing with `priority` parameter + +--- + +#### **03_cyclic_execution** - Cyclic Graph Execution +Loops and iteration control. +```bash +python -m contributing.samples.graph_examples.03_cyclic_execution.agent +``` +**Demonstrates:** +- Cyclic graphs with back-edges +- Writing routing signals via state_delta (ADK pattern) +- GraphState.data auto-sync from session.state +- max_iterations guard + +**Key Pattern:** +Agents write routing signals via `Event.actions.state_delta`. GraphAgent automatically syncs `session.state` into `GraphState.data` before edge evaluation β€” no `output_mapper` needed. + +--- + +#### **04_checkpointing** - Checkpointing & Resume +Automatic state persistence. +```bash +python -m contributing.samples.graph_examples.04_checkpointing.agent +``` +**Demonstrates:** +- Automatic checkpointing with optional `checkpoint_service` parameter +- State persistence +- Checkpoint metadata +- Execution path tracking +- Resume from checkpoint capability + +--- + +#### **05_interrupts_basic** - Basic Interrupts +Human-in-the-loop interrupts. +```bash +python -m contributing.samples.graph_examples.05_interrupts_basic.agent +``` +**Demonstrates:** +- All 8 interrupt actions: continue, rerun, pause, go_back, skip, defer, update_state, change_condition +- Concurrent injection via asyncio.create_task +- SlowNode (2 sub-steps Γ— 1s) timing pattern +- AFTER interrupt check behavior (queued during execution, consumed after node completes) + +--- + +#### **06_interrupts_reasoning** - Interrupt with Reasoning +Condition-based action selection. 
+```bash +python -m contributing.samples.graph_examples.06_interrupts_reasoning.agent +``` +**Demonstrates:** +- Interrupt with condition evaluation +- Automated action selection based on state +- Draft-review workflow with interrupt points + +--- + +#### **07_callbacks** - Node Callbacks +Lifecycle hooks for nodes. +```bash +python -m contributing.samples.graph_examples.07_callbacks.agent +``` +**Demonstrates:** +- `before_node_callback` - executed before node runs +- `after_node_callback` - executed after node completes +- Timing and performance tracking +- Telemetry integration patterns + +--- + +#### **08_rewind** - Rewind Integration +Time-travel debugging. +```bash +python -m contributing.samples.graph_examples.08_rewind.agent +``` +**Demonstrates:** +- Invocation tracking +- Rewinding to specific node +- State restoration +- Re-execution after rewind + +--- + +### ⚑ Parallel Execution + +#### **09_parallel_wait_all** - Parallel Execution (WAIT_ALL) +Concurrent node execution, wait for all. +```bash +python -m contributing.samples.graph_examples.09_parallel_wait_all.agent +``` +**Demonstrates:** +- Parallel node execution +- WAIT_ALL join strategy +- Speedup vs sequential (2.25x) +- Event streaming from parallel nodes + +**Output (example):** +``` +βœ… Fetched data from products_db +βœ… Fetched data from users_db +βœ… Fetched data from orders_db + +All three fetched in parallel. +``` + +--- + +#### **10_parallel_wait_any** - Parallel Execution (WAIT_ANY) +Race condition, first-to-complete wins. +```bash +python -m contributing.samples.graph_examples.10_parallel_wait_any.agent +``` +**Demonstrates:** +- Racing multiple data sources +- WAIT_ANY join strategy +- Automatic cancellation of slower nodes +- Cache-DB-API fallback pattern + +**Output (example):** +``` +βœ… Data from CACHE + +Winner: Cache +Cancelled: Database, API +``` + +--- + +#### **11_parallel_wait_n** - Parallel Execution (WAIT_N) +Continue after N of M complete. 
+```bash +python -m contributing.samples.graph_examples.11_parallel_wait_n.agent +``` +**Demonstrates:** +- WAIT_N join strategy (e.g., 2 out of 3) +- ML model ensemble pattern +- Partial completion workflows +- Automatic cancellation of remaining nodes + +--- + +#### **12_parallel_checkpointing** - Parallel + Checkpointing +State persistence across parallel execution. +```bash +python -m contributing.samples.graph_examples.12_parallel_checkpointing.agent +``` +**Demonstrates:** +- Parallel execution with automatic checkpointing +- State recovery after interruption +- Checkpoint metadata tracking +- Resume from mid-parallel execution + +--- + +#### **13_parallel_interrupts** - Parallel + Interrupts +Interrupt handling inside parallel branches. +```bash +python -m contributing.samples.graph_examples.13_parallel_interrupts.agent +``` +**Demonstrates:** +- Interrupts within parallel node execution +- Branch-specific interrupt handling +- Pause/resume in parallel context +- Interrupt isolation across branches + +--- + +### πŸ”— Combined Features + +#### **14_parallel_rewind** - Parallel Execution + Rewind +Rewind works with parallel workflows! +```bash +python -m contributing.samples.graph_examples.14_parallel_rewind.agent +``` +**Demonstrates:** +- Parallel + Rewind integration +- Invocation tracking in parallel groups +- Re-execution of entire parallel group +- State consistency across rewind + +**Key Insight:** +- Rewind to parallel node β†’ entire parallel group re-executes +- All branches get new invocations +- Deterministic re-execution + +--- + +#### **15_enhanced_routing** - Enhanced Routing Patterns +Priority, weighted, and fallback routing. 
+```bash +python -m contributing.samples.graph_examples.15_enhanced_routing.agent +``` +**Demonstrates:** +- Priority-based routing (higher priority evaluated first) +- Weighted random selection (probabilistic routing) +- Fallback edges (priority=0 always matches) +- Three routing patterns in one example + +--- + +## Feature Matrix + +| Example | Parallel | Rewind | Checkpoints | Interrupts | Callbacks | Cyclic | Routing | +|---------|----------|--------|-------------|------------|-----------|--------|---------| +| 01_basic | - | - | - | - | - | - | Simple | +| 02_conditional_routing | - | - | - | - | - | - | Conditional | +| 03_cyclic_execution | - | - | - | - | - | βœ… | Conditional | +| 04_checkpointing | - | - | βœ… | - | - | - | Simple | +| 05_interrupts_basic | - | - | - | βœ… | - | - | Simple | +| 06_interrupts_reasoning | - | - | - | βœ… | - | - | Conditional | +| 07_callbacks | - | - | - | - | βœ… | - | Simple | +| 08_rewind | - | βœ… | - | - | - | - | Simple | +| 09_parallel_wait_all | βœ… | - | - | - | - | - | Parallel | +| 10_parallel_wait_any | βœ… | - | - | - | - | - | Parallel | +| 11_parallel_wait_n | βœ… | - | - | - | - | - | Parallel | +| 12_parallel_checkpointing | βœ… | - | βœ… | - | - | - | Parallel | +| 13_parallel_interrupts | βœ… | - | - | βœ… | - | - | Parallel | +| 14_parallel_rewind | βœ… | βœ… | - | - | - | - | Parallel | +| 15_enhanced_routing | - | - | - | - | - | - | Advanced | + +--- + +## Architectural Insights + +### Parallel Execution Architecture + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ validate β”‚ +β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ β”‚ β”‚ + β–Ό β–Ό β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ fetch_A β”‚ β”‚ fetch_B β”‚ β”‚ fetch_C β”‚ +β”‚ (isolated) β”‚ β”‚ (isolated) β”‚ β”‚ 
(isolated) β”‚ +β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ aggregate β”‚ + β”‚(merged state)β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +**Key Points:** +- Each branch has **isolated state** during execution +- No race conditions possible +- State **merged** after all branches complete +- Events **streamed** as branches complete (FIRST_COMPLETED) + +--- + +### Rewind with Parallel Execution + +``` +1. Initial Execution: + validate β†’ (fetch_A || fetch_B || fetch_C) β†’ aggregate + + Invocations created: + - validate: ["inv_1"] + - fetch_A: ["inv_2"] + - fetch_B: ["inv_3"] + - fetch_C: ["inv_4"] + - aggregate: ["inv_5"] + +2. Rewind to fetch_A (inv_2): + Session state restored to BEFORE inv_2 + +3. 
Re-execution: + (fetch_A || fetch_B || fetch_C) β†’ aggregate + + New invocations: + - fetch_A: ["inv_2", "inv_6"] + - fetch_B: ["inv_3", "inv_7"] + - fetch_C: ["inv_4", "inv_8"] + - aggregate: ["inv_5", "inv_9"] +``` + +**Key Points:** +- Rewind works seamlessly with parallel groups +- Entire parallel group re-executes +- New invocations created on re-execution +- Deterministic behavior guaranteed + +--- + +### State Isolation + +**Problem:** Multiple nodes modifying same state β†’ race conditions + +**Solution:** Isolated state copies per branch + +```python +# During parallel execution +for node in parallel_group.nodes: + # Each branch gets ISOLATED copy + branch_state = state.copy() + + # Modify branch state + execute_node(node, branch_state) + +# After all complete +merged_state = merge(all_branch_states) +``` + +**Benefits:** +- No race conditions +- Deterministic results +- Safe concurrent execution + +--- + +## Performance Comparison + +### Sequential vs Parallel (WAIT_ALL) + +**Scenario:** Fetch from 3 sources (100ms, 150ms, 200ms each) + +**Sequential:** +``` +Total time = 100 + 150 + 200 = 450ms +``` + +**Parallel (WAIT_ALL):** +``` +Total time = max(100, 150, 200) = 200ms +Speedup: 450ms / 200ms = 2.25x +``` + +**Parallel (WAIT_ANY):** +``` +Total time = min(100, 150, 200) = 100ms +Speedup: 450ms / 100ms = 4.5x +``` + +--- + +## Common Patterns + +### 1. Data Pipeline (WAIT_ALL) +Fetch data from multiple sources concurrently. +```python +ParallelNodeGroup( + nodes=["fetch_users", "fetch_products", "fetch_orders"], + join_strategy=JoinStrategy.WAIT_ALL +) +``` + +### 2. Cache-DB-API Fallback (WAIT_ANY) +Race multiple data sources, use fastest. +```python +ParallelNodeGroup( + nodes=["from_cache", "from_db", "from_api"], + join_strategy=JoinStrategy.WAIT_ANY +) +``` + +### 3. ML Model Ensemble (WAIT_N) +Run multiple models, proceed when N complete. 
+```python +ParallelNodeGroup( + nodes=["model1", "model2", "model3"], + join_strategy=JoinStrategy.WAIT_N, + wait_n=2 # 2 out of 3 +) +``` + +### 4. Interrupt-Driven Review +Human review after key nodes. +```python +InterruptConfig( + mode=InterruptMode.AFTER, + nodes=["draft", "review"] +) +``` + +### 5. Checkpoint-Resume Workflow +Long-running workflows with state persistence. +```python +GraphAgent( + name="workflow", + checkpoint_service=checkpoint_service # Optional parameter +) +``` + +--- + +## Error Handling + +### Parallel Error Policies + +#### FAIL_FAST (default) +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.FAIL_FAST +) +# One error β†’ cancel all β†’ raise exception +``` + +#### CONTINUE +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.CONTINUE +) +# One error β†’ continue others β†’ log error +``` + +#### COLLECT +```python +ParallelNodeGroup( + nodes=["task1", "task2", "task3"], + error_policy=ErrorPolicy.COLLECT +) +# All errors β†’ collect all β†’ raise at end +``` + +--- + +## Testing + +Run tests: +```bash +pytest tests/unittests/agents/test_graph_*.py -v +``` + +--- + +## Next Steps + +1. **Try the examples** - Run each one to see features in action +2. **Modify examples** - Change parameters, add nodes, experiment +3. **Combine features** - Mix parallel + rewind + checkpoints +4. 
**Build your workflow** - Use patterns for your use case + +--- + +## Related Samples: graph_agent_* (Complex Real-World Examples) + +In addition to the numbered graph_examples, there are advanced samples at `contributing/samples/graph_agent_*`: + +| Sample | Description | Pattern | +|--------|-------------|---------| +| **graph_agent_basic** | Basic research workflow | LLM-powered, `agents.py` + `agent.py` | +| **graph_agent_advanced** | Complex research paper workflow | Multi-phase with review loop | +| **graph_agent_react_pattern** | ReAct pattern (Reason + Act) | Thought-action-observation cycle | +| **graph_agent_multi_agent** | Multiple specialized agents | Delegation and collaboration | +| **graph_agent_dynamic_queue** | Dynamic node queueing | Runtime graph modification | +| **graph_agent_parallel_features** | Parallel feature demonstrations | Showcases parallel capabilities | +| **graph_agent_pattern_dynamic_node** | Dynamic node creation | Runtime node injection | +| **graph_agent_pattern_nested_graph** | Nested GraphAgents | Graph as node pattern | +| **graph_agent_pattern_parallel_group** | Parallel group pattern | Advanced parallel workflows | + +**Key Differences from graph_examples**: +- **LLM-Required**: Use LlmAgent, require API credentials (no deterministic mode) +- **Structure**: Single `agent.py` with inline agent definitions +- **Complexity**: Multi-phase workflows, real-world use cases +- **Purpose**: Demonstrate production patterns, not individual features +- **Note**: Some samples have `__init__.py` for package structure + +**Run Example**: +```bash +cd /path/to/adk-python +source venv/bin/activate +python -m contributing.samples.graph_agent_advanced.agent +``` + +--- + +## Support + +Questions? 
Check: +- Examples: `contributing/samples/graph_examples/` +- Advanced: `contributing/samples/graph_agent_*/` +- Tests: `tests/unittests/agents/test_graph_*.py` +- Source: `src/google/adk/agents/graph/` diff --git a/contributing/samples/graph_examples/__init__.py b/contributing/samples/graph_examples/__init__.py new file mode 100644 index 0000000000..425a6cca45 --- /dev/null +++ b/contributing/samples/graph_examples/__init__.py @@ -0,0 +1,20 @@ +"""GraphAgent Examples - All Features + +Small, focused examples demonstrating every GraphAgent feature. + +Examples: +- 01_basic: Basic workflow +- 02_conditional_routing: State-based routing +- 15_enhanced_routing: Priority, weight, and fallback routing +- 04_checkpointing: Automatic state persistence +- 05_interrupts_basic: Human-in-the-loop +- 08_rewind: Time-travel debugging +- 09_parallel_wait_all: Parallel execution (WAIT_ALL) +- 10_parallel_wait_any: Parallel execution (WAIT_ANY / race) +- 14_parallel_rewind: Combined parallel + rewind + +Run any example: + python -m contributing.samples.graph_examples.<example_name>.agent + +See README.md for complete documentation. +""" diff --git a/contributing/samples/graph_examples/example_utils.py b/contributing/samples/graph_examples/example_utils.py new file mode 100644 index 0000000000..3b48cd7359 --- /dev/null +++ b/contributing/samples/graph_examples/example_utils.py @@ -0,0 +1,62 @@ +"""Shared utilities for GraphAgent examples. + +Provides: +- USE_LLM flag to toggle between deterministic agents and real LLM endpoints +- Helper to create LLM-powered agents with consistent configuration +""" + +import os +import sys + + +def use_llm_mode() -> bool: + """Check if examples should use real LLM endpoints instead of deterministic agents. 
+ + Returns True if: + - Environment variable USE_LLM=1 or USE_LLM=true + - Command-line flag --use-llm is present + + Default: False (use deterministic BaseAgent implementations) + """ + # Check environment variable + env_use_llm = os.getenv("USE_LLM", "").lower() in ("1", "true", "yes") + + # Check command-line args + arg_use_llm = "--use-llm" in sys.argv + + return env_use_llm or arg_use_llm + + +def create_llm_agent( + name: str, + instruction=None, + model: str = "gemini-2.5-flash", + tools: list = None, + **kwargs, +): + """Create an LLM-powered agent. + + Args: + name: Agent name + instruction: System instruction (str or callable for dynamic instructions) + model: Model to use (default: gemini-2.5-flash) + tools: Optional list of tools + **kwargs: Additional Agent configuration (e.g., include_contents='none') + + Returns: + Agent instance configured with LLM + + Note: + Requires valid API credentials configured via: + - GOOGLE_GENAI_API_KEY environment variable, or + - gcloud auth application-default login + """ + from google.adk import Agent + + return Agent( + name=name, + model=model, + instruction=instruction, + tools=tools or [], + **kwargs, + ) diff --git a/contributing/samples/graph_examples/run_all_examples.sh b/contributing/samples/graph_examples/run_all_examples.sh new file mode 100755 index 0000000000..65272f5de9 --- /dev/null +++ b/contributing/samples/graph_examples/run_all_examples.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Run all GraphAgent examples + +set -e + +cd "$(dirname "$0")/../../.." 
+source venv/bin/activate + +echo "========================================" +echo "Running All GraphAgent Examples" +echo "========================================" +echo "" + +examples=( + "01_basic" + "02_conditional_routing" + "03_cyclic_execution" + "15_enhanced_routing" + "04_checkpointing" + "05_interrupts_basic" + "06_interrupts_reasoning" + "07_callbacks" + "08_rewind" + "09_parallel_wait_all" + "10_parallel_wait_any" + "11_parallel_wait_n" + "12_parallel_checkpointing" + "13_parallel_interrupts" + "14_parallel_rewind" +) + +for example in "${examples[@]}"; do + echo "----------------------------------------" + echo "Running: $example" + echo "----------------------------------------" + python -m "contributing.samples.graph_examples.${example}.agent" 2>&1 | grep -v "UserWarning" || true + echo "" +done + +echo "========================================" +echo "βœ… All Examples Complete!" +echo "========================================" diff --git a/contributing/samples/graph_examples/run_example.py b/contributing/samples/graph_examples/run_example.py new file mode 100755 index 0000000000..c07b2579d4 --- /dev/null +++ b/contributing/samples/graph_examples/run_example.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +"""Utility to run graph_examples with optional trace logging and LLM mode. 
+ +Usage: + # Default mode (deterministic agents) + python run_example.py 01_basic + + # LLM mode + python run_example.py 01_basic --use-llm + + # With trace logs + python run_example.py 01_basic --trace + + # Both + python run_example.py 01_basic --use-llm --trace + + # List all examples + python run_example.py --list +""" + +import argparse +import importlib +import logging +from pathlib import Path +import sys + + +def setup_logging(trace: bool = False): + """Configure logging based on trace flag.""" + level = logging.DEBUG if trace else logging.INFO + format_str = ( + "%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s" + if trace + else "%(message)s" + ) + + logging.basicConfig( + level=level, + format=format_str, + datefmt="%H:%M:%S", + ) + + # Enable ADK trace logging + if trace: + logging.getLogger("google_adk").setLevel(logging.DEBUG) + logging.getLogger("google.adk").setLevel(logging.DEBUG) + + +def list_examples(): + """List all available examples.""" + examples_dir = Path(__file__).parent + examples = sorted([ + d.name + for d in examples_dir.iterdir() + if d.is_dir() and (d / "agent.py").exists() and not d.name.startswith("_") + ]) + + print("\nπŸ“š Available graph_examples:\n") + for ex in examples: + agent_file = examples_dir / ex / "agent.py" + # Read first docstring line + with open(agent_file) as f: + lines = f.readlines() + desc = "" + for line in lines: + if line.strip().startswith('"""'): + desc = line.strip('"""').strip() + break + print(f" {ex:30s} - {desc}") + + print("\n") + + +def run_example(example_name: str, use_llm: bool = False, trace: bool = False): + """Run a specific example.""" + setup_logging(trace=trace) + + # Set USE_LLM env var if needed + if use_llm: + import os + + os.environ["USE_LLM"] = "1" + + # Verify example exists + example_dir = Path(__file__).parent / example_name + if not example_dir.exists() or not (example_dir / "agent.py").exists(): + print(f"❌ Example '{example_name}' not found") + print("\nRun with 
--list to see available examples") + sys.exit(1) + + # Run via subprocess to handle module names starting with numbers + import os + import subprocess + + env = os.environ.copy() + if use_llm: + env["USE_LLM"] = "1" + + # Run from adk-python root + adk_root = Path(__file__).parent.parent.parent.parent + module_path = f"contributing.samples.graph_examples.{example_name}.agent" + + print(f"\n{'='*70}") + print(f"Running: {example_name}") + print(f"Mode: {'πŸ€– LLM' if use_llm else '🎭 Deterministic'}") + print(f"Trace: {'βœ“ Enabled' if trace else 'βœ— Disabled'}") + print(f"{'='*70}\n") + + try: + result = subprocess.run( + [sys.executable, "-m", module_path], + cwd=str(adk_root), + env=env, + capture_output=False, + text=True, + ) + sys.exit(result.returncode) + except Exception as e: + print(f"\n❌ Error running example: {e}") + if trace: + import traceback + + traceback.print_exc() + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Run graph_examples with optional trace logging and LLM mode" + ) + parser.add_argument( + "example", nargs="?", help="Example name (e.g., 01_basic)" + ) + parser.add_argument( + "--use-llm", + action="store_true", + help="Use real LLM endpoints instead of deterministic agents", + ) + parser.add_argument( + "--trace", action="store_true", help="Enable detailed trace logging" + ) + parser.add_argument("--list", action="store_true", help="List all examples") + + args = parser.parse_args() + + if args.list: + list_examples() + return + + if not args.example: + parser.print_help() + print("\nRun with --list to see available examples") + sys.exit(1) + + run_example(args.example, use_llm=args.use_llm, trace=args.trace) + + +if __name__ == "__main__": + main() diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index fbd1808f3f..4735379bd0 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -14,6 +14,11 @@ from .base_agent import BaseAgent from 
.context import Context +from .graph import END +from .graph import GraphAgent +from .graph import GraphNode +from .graph import GraphState +from .graph import START from .invocation_context import InvocationContext from .live_request_queue import LiveRequest from .live_request_queue import LiveRequestQueue @@ -29,6 +34,11 @@ 'Agent', 'BaseAgent', 'Context', + 'GraphAgent', + 'GraphNode', + 'GraphState', + 'START', + 'END', 'LlmAgent', 'LoopAgent', 'McpInstructionProvider', diff --git a/src/google/adk/agents/graph/__init__.py b/src/google/adk/agents/graph/__init__.py new file mode 100644 index 0000000000..63328e2542 --- /dev/null +++ b/src/google/adk/agents/graph/__init__.py @@ -0,0 +1,124 @@ +"""Graph-based agent components. + +This module contains components for graph-based workflow orchestration: +- GraphAgent: Main graph workflow agent +- GraphState: Domain data container +- GraphAgentState: Execution tracking state (BaseAgentState) +- GraphNode: Node wrapper for agents and functions +- EdgeCondition: Conditional routing between nodes +- StateReducer: State merge strategies +- InterruptMode: Human-in-the-loop interrupt modes +- InterruptService: Dynamic runtime interrupts with queue bounds and metrics +- InterruptServiceConfig: Configuration for interrupt service +- InterruptMessage: Message from human to agent +- QueueStatus: Queue status information +- SessionMetrics: Per-session interrupt metrics +- GraphEvent: Typed events for streaming +- GraphEventType: Event type enumeration +- GraphStreamMode: Stream mode enumeration +- NodeCallbackContext: Context for node lifecycle callbacks +- EdgeCallbackContext: Context for edge condition callbacks +- NodeCallback: Type for node lifecycle callbacks +- EdgeCallback: Type for edge condition callbacks +- DynamicNode: Runtime agent selection based on state +- NestedGraphNode: Hierarchical workflow composition (graph within graph) +- DynamicParallelGroup: Dynamic concurrent execution with variable agent count +""" + 
+from __future__ import annotations + +from .checkpoint_callback import GraphCheckpointCallback +from .callbacks import create_nested_observability_callback +from .callbacks import EdgeCallback +from .callbacks import EdgeCallbackContext +from .callbacks import NodeCallback +from .callbacks import NodeCallbackContext +from .evaluation_metrics import graph_path_match +from .evaluation_metrics import node_execution_count +from .evaluation_metrics import state_contains_keys +from .graph_agent import GraphAgent +from .graph_agent_config import GraphAgentConfig +from .graph_agent_config import GraphEdgeConfig +from .graph_agent_config import GraphNodeConfig +from .graph_agent_state import GraphAgentState +from .graph_edge import EdgeCondition +from .graph_events import GraphEvent +from .graph_events import GraphEventType +from .graph_events import GraphStreamMode +from .graph_export import export_execution_timeline +from .graph_export import export_graph_structure +from .graph_export import export_graph_with_execution +from .graph_node import GraphNode +from .graph_rewind import rewind_to_node +from .graph_state import GraphState +from .graph_state import PydanticJSONEncoder +from .graph_state import StateReducer +from .interrupt import InterruptAction +from .interrupt import InterruptConfig +from .interrupt import InterruptMode +from .interrupt_reasoner import InterruptDecision +from .interrupt_reasoner import InterruptReasoner +from .interrupt_reasoner import InterruptReasonerConfig +from .interrupt_service import InterruptMessage +from .interrupt_service import InterruptService +from .interrupt_service import InterruptServiceConfig +from .interrupt_service import QueueStatus +from .interrupt_service import SessionMetrics +from .parallel import ErrorPolicy +from .parallel import JoinStrategy +from .parallel import ParallelNodeGroup +from .patterns import DynamicNode +from .patterns import DynamicParallelGroup +from .patterns import NestedGraphNode + +# Sentinel 
constants for graph boundaries +START = "__start__" +END = "__end__" + +__all__ = [ + "GraphAgent", + "GraphAgentConfig", + "GraphAgentState", + "GraphNodeConfig", + "GraphEdgeConfig", + "GraphState", + "GraphNode", + "EdgeCondition", + "StateReducer", + "PydanticJSONEncoder", + "InterruptMode", + "InterruptConfig", + "InterruptAction", + "InterruptDecision", + "InterruptReasoner", + "InterruptReasonerConfig", + "InterruptService", + "InterruptServiceConfig", + "InterruptMessage", + "QueueStatus", + "SessionMetrics", + "GraphEvent", + "GraphEventType", + "GraphStreamMode", + "NodeCallbackContext", + "EdgeCallbackContext", + "NodeCallback", + "EdgeCallback", + "create_nested_observability_callback", + "ParallelNodeGroup", + "JoinStrategy", + "ErrorPolicy", + "graph_path_match", + "state_contains_keys", + "node_execution_count", + "DynamicNode", + "NestedGraphNode", + "DynamicParallelGroup", + "export_graph_structure", + "export_graph_with_execution", + "export_execution_timeline", + "rewind_to_node", + "GraphCheckpointCallback", + "START", + "END", +] diff --git a/src/google/adk/agents/graph/callbacks.py b/src/google/adk/agents/graph/callbacks.py new file mode 100644 index 0000000000..6660932b36 --- /dev/null +++ b/src/google/adk/agents/graph/callbacks.py @@ -0,0 +1,218 @@ +"""Callback infrastructure for graph observability and extensibility. + +This module provides callback primitives for customizing graph behavior: +- NodeCallbackContext: Context passed to node lifecycle callbacks +- EdgeCallbackContext: Context passed to edge condition callbacks +- NodeCallback: Type for before/after node callbacks +- EdgeCallback: Type for edge condition callbacks + +Callbacks enable custom observability, logging, debugging, and control flow +without modifying the core GraphAgent implementation. 
+ +Example: + ```python + import json + from google.adk.agents.graph import GraphAgent + from google.adk.agents.graph.callbacks import NodeCallbackContext + from google import genai + + async def my_observability(ctx: NodeCallbackContext): + '''Custom observability callback.''' + # Automatic Pydantic serialization + return genai.types.Event( + author="observability", + content=genai.types.Content(parts=[ + genai.types.Part(text=f"β†’ Executing: {ctx.node.name}"), + genai.types.Part(text=f"State:\n{ctx.state.data_to_json()}"), + ]), + actions=genai.types.EventActions( + escalate=False, + state_delta={ + "observability_node": ctx.node.name, + "observability_iteration": ctx.iteration, + } + ) + ) + + graph = GraphAgent( + name="my_graph", + before_node_callback=my_observability, + ) + ``` +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Dict +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...events.event import Event + +from google import genai + +from .graph_state import GraphState + + +@dataclass +class NodeCallbackContext: + """Context passed to node lifecycle callbacks. + + Contains all information about the current node execution, + including the node itself, current state, iteration number, + and full invocation context. + + Attributes: + node: The GraphNode being executed + state: Current graph state (before or after execution) + iteration: Current iteration number in graph execution + invocation_context: Full ADK invocation context + metadata: Extensible metadata dictionary for custom use + """ + + node: Any # GraphNode (avoiding circular import) + state: GraphState + iteration: int + invocation_context: Any # InvocationContext (from genai.types) + metadata: Dict[str, Any] + + +@dataclass +class EdgeCallbackContext: + """Context passed to edge condition callbacks. 
+ + Contains all information about an edge condition evaluation, + including source/target nodes, the condition function, + evaluation result, and current state. + + Attributes: + source_node: Name of the source node + target_node: Name of the target node + condition: The condition function being evaluated (or None for unconditional) + condition_result: Boolean result of condition evaluation + state: Current graph state + invocation_context: Full ADK invocation context + metadata: Extensible metadata dictionary for custom use + """ + + source_node: str + target_node: str + condition: Optional[Callable[[GraphState], bool]] + condition_result: bool + state: GraphState + invocation_context: Any # InvocationContext (from genai.types) + metadata: Dict[str, Any] + + +# Type aliases for callbacks +NodeCallback = Callable[[NodeCallbackContext], Awaitable[Optional["Event"]]] +"""Callback function type for node lifecycle events. + +Receives NodeCallbackContext and optionally returns an Event to emit. +Returning None skips event emission. + +Example: + ```python + from google.adk.events.event import Event + + async def my_node_callback(ctx: NodeCallbackContext) -> Optional[Event]: + # Custom logic here + if should_emit: + return Event(...) + return None # Skip event + ``` +""" + +EdgeCallback = Callable[[EdgeCallbackContext], Awaitable[Optional["Event"]]] +"""Callback function type for edge condition events. + +Receives EdgeCallbackContext and optionally returns an Event to emit. +Returning None skips event emission. 
+ +Example: + ```python + from google.adk.events.event import Event + from google.genai.types import Content, Part + + async def my_edge_callback(ctx: EdgeCallbackContext) -> Optional[Event]: + # Log conditional routing decisions + if ctx.condition_result: + return Event( + content=Content(parts=[ + Part(text=f"Routing: {ctx.source_node} β†’ {ctx.target_node}") + ]) + ) + return None + ``` +""" + + +def create_nested_observability_callback() -> NodeCallback: + """Create a callback that shows nested graph hierarchy. + + Returns a NodeCallback that includes agent path information + in observability events, making nested graph execution visible. + + Example: + ```python + graph = GraphAgent( + name="my_graph", + before_node_callback=create_nested_observability_callback(), + ) + ``` + + Returns: + NodeCallback that emits events with nesting hierarchy + """ + + async def nested_callback(ctx: NodeCallbackContext) -> Optional["Event"]: + """Emit observability event with nested graph hierarchy.""" + from ...events.event import Event + from ...events.event_actions import EventActions + + # Get agent path from callback metadata + agent_path = ctx.metadata.get("agent_path", []) + hierarchy = " β†’ ".join(agent_path) if agent_path else ctx.node.name + + # Derive node type from class name + node_type_map = { + "GraphNode": "agent", + "DynamicNode": "dynamic", + "DynamicParallelGroup": "parallel", + "NestedGraphNode": "nested", + } + node_class_name = type(ctx.node).__name__ + node_type = node_type_map.get(node_class_name, "function") + + # Collect _debug_ keys relevant to current node + debug_prefix = f"_debug_{ctx.node.name}_" + debug_info = { + k: v for k, v in ctx.state.data.items() if k.startswith(debug_prefix) + } + + return Event( + author="observability", + content=genai.types.Content( + parts=[ + genai.types.Part(text=f"[{hierarchy}] β†’ {ctx.node.name}"), + genai.types.Part(text=f"Iteration: {ctx.iteration}"), + ] + ), + actions=EventActions( + escalate=False, + 
state_delta={ + "observability_hierarchy": hierarchy, + "observability_level": len(agent_path), + "observability_node": ctx.node.name, + "observability_node_type": node_type, + "observability_debug": debug_info, + }, + ), + ) + + return nested_callback diff --git a/src/google/adk/agents/graph/checkpoint_callback.py b/src/google/adk/agents/graph/checkpoint_callback.py new file mode 100644 index 0000000000..b6b1d65399 --- /dev/null +++ b/src/google/adk/agents/graph/checkpoint_callback.py @@ -0,0 +1,195 @@ +"""Graph-specific checkpoint callback for node-level checkpointing. + +Extends the generic CheckpointCallback with before_node/after_node methods +that integrate with GraphAgent's node callback system. +""" + +from __future__ import annotations + +import json +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + +from ...checkpoints.callback import CheckpointCallback +from ...checkpoints.checkpoint_service import CheckpointService + +if TYPE_CHECKING: + from .callbacks import NodeCallbackContext + + +class GraphCheckpointCallback(CheckpointCallback): + """Checkpoint callback with GraphAgent node-level support. + + Extends CheckpointCallback with before_node/after_node methods for + per-node checkpointing in GraphAgent workflows. + + Inherits agent-level callbacks (before_agent/after_agent) from + CheckpointCallback β€” works with any BaseAgent subclass. 
+ + Node-level example (selective checkpointing per node): + ```python + from google.adk.agents.graph import GraphCheckpointCallback + + checkpoint_service = CheckpointService(session_service=session_service) + # Checkpoint only after critical nodes, not every node + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_nodes={"analyze", "execute"}, # only these nodes + ) + + graph = GraphAgent( + name="workflow", + after_node_callback=checkpoint_callback.after_node, + ) + ``` + + Agent-proposed checkpoint example (LLM decides when to checkpoint): + ```python + from pydantic import BaseModel + + class AnalysisOutput(BaseModel): + finding: str + risk_level: str + checkpoint_requested: bool = False + + analyzer = LlmAgent( + name="analyzer", + output_schema=AnalysisOutput, + instruction="... Set checkpoint_requested=true if risk_level is 'high'.", + ) + + checkpoint_callback = GraphCheckpointCallback( + checkpoint_service, + checkpoint_after=False, # no automatic checkpoints + checkpoint_request_key="analyzer.checkpoint_requested", + ) + graph = GraphAgent( + name="workflow", + after_node_callback=checkpoint_callback.after_node, + ) + ``` + """ + + def __init__( + self, + checkpoint_service: CheckpointService, + checkpoint_before: bool = True, + checkpoint_after: bool = True, + checkpoint_nodes: Optional[Set[str]] = None, + checkpoint_request_key: Optional[str] = None, + ): + """Initialize graph checkpoint callback. + + Args: + checkpoint_service: CheckpointService instance to use + checkpoint_before: Create checkpoint before agent/node execution + checkpoint_after: Create checkpoint after agent/node execution + checkpoint_nodes: For node-level callbacks, only checkpoint these nodes. + None means checkpoint all nodes. Has no effect on agent-level callbacks. + checkpoint_request_key: Dotted path "state_key.bool_field" that an LLM + agent can set to propose a checkpoint (e.g. "analyzer.checkpoint_requested"). 
+ When the named node finishes and the field is truthy, an additional + checkpoint is created. Default None (disabled). Opt-in only. + """ + super().__init__( + checkpoint_service=checkpoint_service, + checkpoint_before=checkpoint_before, + checkpoint_after=checkpoint_after, + ) + self.checkpoint_nodes = checkpoint_nodes + # Pre-parse dotted key once to avoid repeated string splits at runtime + if checkpoint_request_key: + parts = checkpoint_request_key.split(".", 1) + self._req_state_key: Optional[str] = parts[0] + self._req_field: Optional[str] = ( + parts[1] if len(parts) > 1 else "checkpoint_requested" + ) + else: + self._req_state_key = None + self._req_field = None + + def _should_checkpoint_node(self, node_name: str) -> bool: + """Check if a specific node should be checkpointed.""" + if self.checkpoint_nodes is None: + return True + return node_name in self.checkpoint_nodes + + async def before_node(self, ctx: "NodeCallbackContext") -> None: + """Create checkpoint before a GraphAgent node executes. + + Used with GraphAgent's before_node_callback. Supports selective + checkpointing via checkpoint_nodes parameter. + + Args: + ctx: Node callback context from GraphAgent + + Returns: + None (checkpoint stored via session_service.append_event) + """ + if not self.checkpoint_before: + return None + + if not self._should_checkpoint_node(ctx.node.name): + return None + + session = ctx.invocation_context.session + checkpoint_id = f"{session.id}-{ctx.node.name}-{ctx.iteration}-before" + + await self.service.create_checkpoint( + session=session, + checkpoint_id=checkpoint_id, + description=f"Before node {ctx.node.name} (iteration {ctx.iteration})", + agent_name=ctx.node.name, + ) + + return None + + async def after_node(self, ctx: "NodeCallbackContext") -> None: + """Create checkpoint after a GraphAgent node completes. + + Used with GraphAgent's after_node_callback. 
Supports selective + checkpointing via checkpoint_nodes parameter, and agent-proposed + checkpointing via checkpoint_request_key. + + Args: + ctx: Node callback context from GraphAgent + + Returns: + None (checkpoint stored via session_service.append_event) + """ + session = ctx.invocation_context.session + + # Infrastructure-driven checkpoint (checkpoint_after + checkpoint_nodes filter) + if self.checkpoint_after and self._should_checkpoint_node(ctx.node.name): + checkpoint_id = f"{session.id}-{ctx.node.name}-{ctx.iteration}-after" + await self.service.create_checkpoint( + session=session, + checkpoint_id=checkpoint_id, + description=f"After node {ctx.node.name} (iteration {ctx.iteration})", + agent_name=ctx.node.name, + ) + + # Agent-proposed checkpoint: LLM sets a bool flag in its output schema + if self._req_state_key and ctx.node.name == self._req_state_key: + raw = ctx.state.data.get(self._req_state_key, {}) + if isinstance(raw, str): + try: + raw = json.loads(raw) + except (json.JSONDecodeError, TypeError): + raw = {} + if isinstance(raw, dict) and raw.get(self._req_field, False): + checkpoint_id = ( + f"{session.id}-{ctx.node.name}-{ctx.iteration}-requested" + ) + await self.service.create_checkpoint( + session=session, + checkpoint_id=checkpoint_id, + description=( + f"Agent-requested checkpoint at {ctx.node.name}" + f" (iteration {ctx.iteration})" + ), + agent_name=ctx.node.name, + ) + + return None diff --git a/src/google/adk/agents/graph/evaluation_metrics.py b/src/google/adk/agents/graph/evaluation_metrics.py new file mode 100644 index 0000000000..ffa9b8c6d5 --- /dev/null +++ b/src/google/adk/agents/graph/evaluation_metrics.py @@ -0,0 +1,382 @@ +"""Custom evaluation metrics for GraphAgent workflows. + +These metrics enable evaluating graph execution paths, state transitions, +and workflow behavior in ADK's evaluation framework. 
+ +Example usage: + ```python + from google.adk.evaluation import EvalMetric + from google.adk.agents.graph.evaluation_metrics import ( + graph_path_match, + state_contains_keys, + node_execution_count, + ) + + # In eval config: + metrics = [ + EvalMetric( + name="graph_path", + custom_function_path="google.adk.agents.graph.evaluation_metrics.graph_path_match", + ), + ] + ``` +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from ...evaluation.eval_case import ConversationScenario +from ...evaluation.eval_case import Invocation +from ...evaluation.eval_metrics import EvalMetric +from ...evaluation.eval_metrics import EvalStatus +from ...evaluation.evaluator import EvaluationResult +from ...evaluation.evaluator import PerInvocationResult + + +def graph_path_match( + eval_metric: EvalMetric, + actual_invocations: List[Invocation], + expected_invocations: Optional[List[Invocation]], + conversation_scenario: Optional[ConversationScenario] = None, +) -> EvaluationResult: + """Evaluate if graph execution path matches expected path. + + Checks if the sequence of nodes executed matches the expected path. + Looks for 'graph_path' in session state. 
+ + Args: + eval_metric: Metric configuration + actual_invocations: Actual agent invocations + expected_invocations: Expected invocations (unused, path comes from scenario) + conversation_scenario: Test scenario with expected path in metadata + + Returns: + EvaluationResult with scores based on path matching + + Expected format in conversation_scenario.metadata: + { + "expected_graph_path": ["node1", "node2", "node3"] + } + """ + results = [] + overall_score = 0.0 + overall_status = EvalStatus.PASSED + + # Get expected path from eval_metric custom fields (for testing) + # In production, would come from scenario or expected_invocations + expected_path = getattr(eval_metric, "expected_graph_path", None) + + for actual_inv in actual_invocations: + # Extract actual path from intermediate_data (production) + # or from eval_metric custom fields (for testing) + actual_path = getattr(eval_metric, "actual_graph_path", None) + + if actual_path is None and actual_inv.intermediate_data: + # Extract from InvocationEvents + from ...evaluation.eval_case import InvocationEvents + + if isinstance(actual_inv.intermediate_data, InvocationEvents): + # Parse graph metadata from intermediate events + # Get the LAST/latest metadata event (final graph state) + for event in reversed(actual_inv.intermediate_data.invocation_events): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text and "[GraphMetadata]" in part.text: + # Extract metadata from text + import ast + + try: + # Parse the dict from text like "[GraphMetadata] {'graph_path': [...]}" + metadata_str = part.text.split("[GraphMetadata]", 1)[ + 1 + ].strip() + metadata = ast.literal_eval(metadata_str) + actual_path = metadata.get("graph_path") + if actual_path: + break + except Exception: + continue + if actual_path: + break + + # Compute score + score = 0.0 + status = EvalStatus.NOT_EVALUATED + + if expected_path is None: + # No expected path specified + status = EvalStatus.NOT_EVALUATED + 
def _iter_graph_metadata(intermediate_data) -> list:
  """Parse every ``[GraphMetadata]`` payload from invocation events, in order.

  GraphAgent observability events embed a Python-literal metadata dict in a
  text part after a ``[GraphMetadata]`` marker. This helper extracts and
  parses all of them, preserving event/part order so callers can pick the
  latest snapshot.

  Args:
    intermediate_data: ``Invocation.intermediate_data``; only
      ``InvocationEvents`` instances are inspected, anything else yields [].

  Returns:
    List of parsed metadata dicts (possibly empty). Unparseable payloads are
    skipped silently, matching the tolerant behavior of the evaluators.
  """
  # Lazy imports kept deliberately: the original imported these inside the
  # evaluator loops, presumably to avoid a circular import with eval_case.
  # Hoisting them here runs them once per call instead of once per event part.
  import ast

  from ...evaluation.eval_case import InvocationEvents

  if not isinstance(intermediate_data, InvocationEvents):
    return []

  parsed = []
  for event in intermediate_data.invocation_events:
    if not (event.content and event.content.parts):
      continue
    for part in event.content.parts:
      if not (part.text and "[GraphMetadata]" in part.text):
        continue
      payload = part.text.split("[GraphMetadata]", 1)[1].strip()
      try:
        parsed.append(ast.literal_eval(payload))
      except Exception:
        # Malformed metadata in one part must not abort evaluation.
        continue
  return parsed


def _aggregate_invocation_results(
    results: List[PerInvocationResult],
) -> EvaluationResult:
  """Fold per-invocation results into an overall ``EvaluationResult``.

  Overall score is the mean of per-invocation scores (0.0 when empty).
  Overall status is FAILED if any invocation failed, NOT_EVALUATED if every
  invocation was NOT_EVALUATED, and PASSED otherwise.
  """
  overall_status = EvalStatus.PASSED
  total_score = 0.0
  for result in results:
    if result.eval_status == EvalStatus.FAILED:
      overall_status = EvalStatus.FAILED
    total_score += result.score

  overall_score = total_score / len(results) if results else 0.0

  if results and all(
      r.eval_status == EvalStatus.NOT_EVALUATED for r in results
  ):
    overall_status = EvalStatus.NOT_EVALUATED

  return EvaluationResult(
      overall_score=overall_score,
      overall_eval_status=overall_status,
      per_invocation_results=results,
  )


def state_contains_keys(
    eval_metric: EvalMetric,
    actual_invocations: List[Invocation],
    expected_invocations: Optional[List[Invocation]],
    conversation_scenario: Optional[ConversationScenario] = None,
) -> EvaluationResult:
  """Evaluate if final state contains expected keys.

  Checks if session state contains all expected keys with correct values.

  Args:
    eval_metric: Metric configuration. May carry ``expected_state`` and
      (for testing) ``actual_state`` as custom fields.
    actual_invocations: Actual agent invocations.
    expected_invocations: Expected invocations (unused).
    conversation_scenario: Test scenario (unused here; state is taken from
      eval_metric custom fields or from invocation metadata events).

  Returns:
    EvaluationResult with scores based on state matching; an invocation
    PASSES only when every expected key matches (score == 1.0).
  """
  # Custom fields on eval_metric override event-derived state (test hook).
  expected_state = getattr(eval_metric, "expected_state", None)
  results: List[PerInvocationResult] = []

  for actual_inv in actual_invocations:
    actual_state = getattr(eval_metric, "actual_state", None)

    if actual_state is None and actual_inv.intermediate_data:
      # Use the LATEST metadata snapshot that carries a non-empty graph_state.
      for metadata in reversed(
          _iter_graph_metadata(actual_inv.intermediate_data)
      ):
        candidate = metadata.get("graph_state")
        if candidate:
          actual_state = candidate
          break

    score = 0.0
    if expected_state is None:
      status = EvalStatus.NOT_EVALUATED
    elif actual_state is None:
      status = EvalStatus.FAILED
    else:
      total_keys = len(expected_state)
      matched_keys = sum(
          1
          for key, expected_value in expected_state.items()
          if key in actual_state and actual_state[key] == expected_value
      )
      score = matched_keys / total_keys if total_keys > 0 else 0.0
      status = EvalStatus.PASSED if score >= 1.0 else EvalStatus.FAILED

    results.append(
        PerInvocationResult(
            actual_invocation=actual_inv,
            score=score,
            eval_status=status,
        )
    )

  return _aggregate_invocation_results(results)


def node_execution_count(
    eval_metric: EvalMetric,
    actual_invocations: List[Invocation],
    expected_invocations: Optional[List[Invocation]],
    conversation_scenario: Optional[ConversationScenario] = None,
) -> EvaluationResult:
  """Evaluate if nodes executed expected number of times.

  Checks node_invocations tracking reported via graph metadata events.

  Args:
    eval_metric: Metric configuration. May carry ``expected_node_counts``
      and (for testing) ``actual_node_counts`` as custom fields.
    actual_invocations: Actual agent invocations.
    expected_invocations: Expected invocations (unused).
    conversation_scenario: Test scenario (unused here).

  Returns:
    EvaluationResult with scores based on execution counts; an invocation
    PASSES only when every expected node count matches exactly.
  """
  expected_counts = getattr(eval_metric, "expected_node_counts", None)
  results: List[PerInvocationResult] = []

  for actual_inv in actual_invocations:
    # Custom field (test hook) takes precedence over event-derived counts.
    actual_counts = getattr(eval_metric, "actual_node_counts", {})

    if not actual_counts and actual_inv.intermediate_data:
      # Keep the LAST non-empty node_invocations snapshot (latest counts).
      for metadata in _iter_graph_metadata(actual_inv.intermediate_data):
        node_invocs = metadata.get("node_invocations", {})
        if node_invocs:
          actual_counts = node_invocs

    score = 0.0
    if expected_counts is None:
      status = EvalStatus.NOT_EVALUATED
    elif not actual_counts:
      status = EvalStatus.FAILED
    else:
      total_nodes = len(expected_counts)
      matched_nodes = sum(
          1
          for node_name, expected_count in expected_counts.items()
          if actual_counts.get(node_name, 0) == expected_count
      )
      score = matched_nodes / total_nodes if total_nodes > 0 else 0.0
      status = EvalStatus.PASSED if score >= 1.0 else EvalStatus.FAILED

    results.append(
        PerInvocationResult(
            actual_invocation=actual_inv,
            score=score,
            eval_status=status,
        )
    )

  return _aggregate_invocation_results(results)
+ +GraphAgent enables workflow creation using directed graphs where: +- Nodes are agents or functions +- Edges define allowed transitions with optional conditions +- State flows through the graph with configurable reducers +- Full checkpointing support via CheckpointService integration + +Key features: +- Directed graph workflows with conditional routing +- State management with custom reducers (OVERWRITE, APPEND, SUM, CUSTOM) +- Always-on observability: node lifecycle events for every execution +- Human-in-the-loop interrupts via InterruptService (retrospective feedback) +- CheckpointService integration for checkpoint/resume +- DatabaseSessionService support for persistence +- Cyclic execution with max_iterations +- Event-based state persistence (ADK-native) + +Inspired by adk-graph (Rust) and LangGraph patterns. + +Checkpointing Integration: + For checkpoint/resume functionality, use CheckpointService with CheckpointCallback: + + ```python + from google.adk.agents.graph import GraphAgent + from google.adk.checkpoints import CheckpointService, CheckpointCallback + from google.adk.sessions import InMemorySessionService + + # Create services + session_service = InMemorySessionService() + checkpoint_service = CheckpointService(session_service) + + # Create graph with checkpoint callback + graph = GraphAgent(name="workflow", checkpointing=True) + graph.add_node(...) 
+ graph.set_callbacks([ + CheckpointCallback(checkpoint_service, checkpoint_after=True) + ]) + + # Checkpoints are created automatically after each node + # Use checkpoint_service to list/delete/export/import checkpoints + ``` +""" + +from __future__ import annotations + +import ast +import asyncio +import json +import logging +import time +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING + +from google.genai import types +from pydantic import ConfigDict +from pydantic import Field +from typing_extensions import override + +from ...events.event import Event +from ...events.event_actions import EventActions +from ...telemetry import graph_tracing +from ...telemetry.tracing import tracer +from ...utils.feature_decorator import experimental +from ..base_agent import BaseAgent +from ..invocation_context import InvocationContext +from ..llm_agent import LlmAgent +from .callbacks import EdgeCallback +from .callbacks import NodeCallback +from .graph_agent_state import GraphAgentState +from .graph_edge import EdgeCondition +from .graph_interrupt_handler import GraphInterruptMixin +from .graph_node import GraphNode +from .graph_state import GraphState +from .graph_state import StateReducer +from .graph_telemetry import GraphTelemetryMixin +from .interrupt import InterruptAction +from .interrupt import InterruptConfig +from .interrupt import InterruptMode +from .interrupt_reasoner import InterruptReasoner +from .interrupt_service import InterruptMessage +from .interrupt_service import InterruptService +from .parallel import execute_parallel_group +from .parallel import ParallelNodeGroup + +if TYPE_CHECKING: + from .graph_agent_config import TelemetryConfig + +logger = logging.getLogger("google_adk." + __name__) + +# Keys stored in session.state by GraphAgent's own state_delta events. 
+# These must be excluded when syncing session.state β†’ state.data to avoid +# circular references (state.data["graph_data"] β†’ state.data) and to keep +# domain data clean of graph-internal bookkeeping. +_GRAPH_INTERNAL_KEYS = frozenset({ + "graph_data", + "graph_checkpoint", + "graph_cancelled", + "graph_cancelled_at_node", + "graph_task_cancelled", + "graph_can_resume", + "graph_iterations", + "graph_path", + "graph_partial_output", + "graph_state", +}) + + +_SAFE_NAMES = frozenset({ + "state", + "data", + "True", + "False", + "None", +}) +_SAFE_METHODS = frozenset({ + "get", + "get_parsed", + "get_str", + "get_dict", +}) +_SAFE_BUILTINS = frozenset({ + "len", + "min", + "max", + "abs", + "bool", + "int", + "float", + "str", + "isinstance", + "type", +}) + + +def _validate_condition_ast(node: ast.AST) -> None: + """Walk AST and reject any unsafe node types. + + Only allows: comparisons, boolean ops, unary not, attribute access, + safe method calls (.get, .get_parsed, .get_str, .get_dict), + constants, and whitelisted names. + + Raises: + ValueError: If an unsafe AST node is encountered. 
+ """ + if isinstance(node, ast.Expression): + _validate_condition_ast(node.body) + elif isinstance(node, ast.BoolOp): + for value in node.values: + _validate_condition_ast(value) + elif isinstance(node, ast.UnaryOp): + if not isinstance(node.op, ast.Not): + raise ValueError(f"Unsafe unary operator: {type(node.op).__name__}") + _validate_condition_ast(node.operand) + elif isinstance(node, ast.Compare): + _validate_condition_ast(node.left) + for comparator in node.comparators: + _validate_condition_ast(comparator) + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + if node.func.attr not in _SAFE_METHODS: + raise ValueError(f"Unsafe method call: .{node.func.attr}()") + _validate_condition_ast(node.func.value) + elif isinstance(node.func, ast.Name) and node.func.id in _SAFE_BUILTINS: + pass # Safe builtin call + else: + raise ValueError(f"Unsafe call: {ast.dump(node.func)}") + for arg in node.args: + _validate_condition_ast(arg) + for kw in node.keywords: + _validate_condition_ast(kw.value) + elif isinstance(node, ast.Attribute): + # Block dunder attribute access to prevent sandbox escape + # (e.g., state.__class__.__init__.__globals__) + if node.attr.startswith("_"): + raise ValueError(f"Unsafe attribute access: '{node.attr}'") + _validate_condition_ast(node.value) + elif isinstance(node, ast.Subscript): + _validate_condition_ast(node.value) + _validate_condition_ast(node.slice) + elif isinstance(node, ast.Name): + if node.id not in _SAFE_NAMES and node.id not in _SAFE_BUILTINS: + raise ValueError(f"Unsafe name: '{node.id}'") + elif isinstance(node, ast.Constant): + pass # string, int, float, bool, None literals are safe + elif isinstance(node, (ast.List, ast.Tuple)): + for elt in node.elts: + _validate_condition_ast(elt) + else: + raise ValueError(f"Unsafe expression node: {type(node).__name__}") + + +def _parse_condition_string(condition_str: str) -> Callable[[GraphState], bool]: + """Parse YAML condition string to a safe callable 
function. + + Conditions are parsed via AST and validated against a whitelist of + safe operations before compilation. This prevents arbitrary code + execution while supporting common condition expressions. + + Allowed in conditions: + - Names: state, data, metadata, True, False, None + - Methods: .get(), .get_parsed(), .get_str(), .get_dict() + - Builtins: len, min, max, abs, bool, int, float, str, isinstance, type + - Operators: ==, !=, <, >, <=, >=, is, is not, in, not in + - Boolean: and, or, not + - Literals: strings, numbers, booleans, None + + Examples: + "data.get('approved') is True" + "data.get('count', 0) < 10" + "'CONTINUE' in data.get('status', '')" + + Args: + condition_str: Python expression string to evaluate safely + + Returns: + Callable that takes GraphState and returns bool + + Raises: + ValueError: If condition contains unsafe expressions + """ + # Parse and validate at definition time (fail fast) + tree = ast.parse(condition_str, mode="eval") + _validate_condition_ast(tree.body) + code = compile(tree, "", "eval") + + def condition_func(state: GraphState) -> bool: + import builtins as _builtins_mod + + safe_builtins = { + name: getattr(_builtins_mod, name) for name in _SAFE_BUILTINS + } + namespace = { + "state": state, + "data": state.data, + } + try: + result = eval(code, {"__builtins__": safe_builtins}, namespace) # noqa: S307 + return bool(result) + except Exception as e: + logger.error( + f"Condition evaluation failed: '{condition_str}' - {e}", + exc_info=True, + ) + return False + + return condition_func + + +# Sentinel constants for graph boundaries +START = "__start__" +END = "__end__" + + +@experimental +class GraphAgent(GraphInterruptMixin, GraphTelemetryMixin, BaseAgent): # type: ignore[misc] + """Graph-based workflow agent for ADK. 
+ + GraphAgent is the fourth workflow agent type in ADK (alongside SequentialAgent, + LoopAgent, and ParallelAgent), enabling directed graph-based orchestration with + conditional routing, state management, and full checkpointing support. + + Workflow agents control execution flow through deterministic logic rather than LLM + reasoning, providing predictable, reliable, and structured agent orchestration. + + Features: + - Directed graph workflow with nodes (agents/functions) and edges + - Conditional routing based on state predicates + - Cyclic execution support (loops, iterative refinement, ReAct pattern) + - Always-on observability: node lifecycle events emitted for every execution + - Human-in-the-loop interrupts via InterruptService (retrospective feedback) + - CheckpointService integration for state persistence + - DatabaseSessionService support for persistence + - Full ADK event system integration + + Example: + >>> from google.adk.agents.graph import GraphAgent, GraphNode + >>> from google.adk.agents import LlmAgent + >>> from google.adk.checkpoints import CheckpointService, CheckpointCallback + >>> from google.adk.runners import Runner + >>> + >>> # Selective node-level checkpointing (only critical nodes) + >>> checkpoint_service = CheckpointService(session_service) + >>> checkpoint_cb = CheckpointCallback( + ... checkpoint_service, + ... checkpoint_before=False, + ... checkpoint_after=True, + ... checkpoint_nodes={"analyze", "process"}, # only these nodes + ... ) + >>> + >>> graph = GraphAgent( + ... name="workflow", + ... after_node_callback=checkpoint_cb.after_node, + ... 
) + >>> graph.add_node(GraphNode(name="analyze", agent=LlmAgent(...))) + >>> graph.add_node(GraphNode(name="process", agent=LlmAgent(...))) + >>> graph.add_edge("analyze", "process") + >>> graph.set_start("analyze") + >>> graph.set_end("process") + >>> + >>> # Run with automatic checkpointing at critical nodes + >>> runner = Runner(app_name="app", agent=graph) + >>> async for event in runner.run_async(...): + ... print(event) + >>> + >>> # Legacy: checkpoint after EVERY node (all-or-nothing) + >>> graph_legacy = GraphAgent(name="workflow", checkpointing=True) + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + nodes: Dict[str, GraphNode] = Field(default_factory=dict) + start_node: Optional[str] = None + end_nodes: List[str] = Field(default_factory=list) + max_iterations: int = 50 # Prevent infinite loops + checkpointing: bool = False + parallel_groups: Dict[str, Any] = Field( + default_factory=dict, + description="Parallel node groups for concurrent execution", + ) + interrupt_service: Optional[InterruptService] = Field( + default=None, + description="Optional InterruptService for dynamic runtime interrupts", + ) + interrupt_config: Optional[InterruptConfig] = Field( + default=None, + description="Configuration for interrupt timing and behavior", + ) + telemetry_config: Optional[Any] = Field( + default=None, + description="Configuration for OpenTelemetry instrumentation", + ) + before_node_callback: Optional[NodeCallback] = Field( + default=None, + description="Callback invoked before each node execution", + ) + after_node_callback: Optional[NodeCallback] = Field( + default=None, + description="Callback invoked after each node execution", + ) + on_edge_condition_callback: Optional[EdgeCallback] = Field( + default=None, + description="Callback invoked when evaluating edge conditions", + ) + + def __init__( + self, + name: str, + description: str = "", + max_iterations: int = 50, + checkpointing: bool = False, + interrupt_service: 
Optional[InterruptService] = None, + interrupt_config: Optional[InterruptConfig] = None, + telemetry_config: Optional[Any] = None, + before_node_callback: Optional[NodeCallback] = None, + after_node_callback: Optional[NodeCallback] = None, + on_edge_condition_callback: Optional[EdgeCallback] = None, + **kwargs: Any, + ) -> None: + """Initialize GraphAgent. + + Args: + name: Agent name + description: Agent description + max_iterations: Max iterations to prevent infinite loops + checkpointing: Enable state checkpointing after each node + Note: For full checkpoint/resume, use CheckpointCallback + interrupt_service: Optional InterruptService for dynamic runtime interrupts + interrupt_config: Configuration for interrupt timing and behavior + telemetry_config: Configuration for OpenTelemetry instrumentation + before_node_callback: Callback invoked before each node execution + after_node_callback: Callback invoked after each node execution + on_edge_condition_callback: Callback invoked when evaluating edge conditions + """ + super().__init__(name=name, description=description, **kwargs) + self.nodes = {} + self.start_node = None + self.end_nodes = [] + self.max_iterations = max_iterations + self.interrupt_service = interrupt_service + self.interrupt_config = interrupt_config + self.telemetry_config = telemetry_config + self.before_node_callback = before_node_callback + self.after_node_callback = after_node_callback + self.on_edge_condition_callback = on_edge_condition_callback + self.checkpointing = checkpointing + # parallel_groups initialized by Field default_factory + + def add_node( + self, + node: GraphNode | str, + agent: Optional[BaseAgent] = None, + function: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> "GraphAgent": + """Add a node to the graph. + + Supports two usage patterns: + 1. Pass a GraphNode directly + 2. 
Pass node name and agent/function (convenience method) + + Args: + node: GraphNode to add, or string name (convenience) + agent: Optional agent for convenience pattern + function: Optional function for convenience pattern + **kwargs: Additional GraphNode parameters (output_mapper, state_reducer, etc.) + + Returns: + Self for chaining + + Examples: + >>> # Pattern 1: GraphNode (explicit) + >>> graph.add_node(GraphNode(name="validate", agent=validator)) + + >>> # Pattern 2: Convenience (name + agent) + >>> graph.add_node("validate", agent=validator) + + >>> # Pattern 3: Convenience with kwargs + >>> graph.add_node("validate", agent=validator, state_reducer=StateReducer.OVERWRITE) + """ + if isinstance(node, GraphNode): + # Pattern 1: Direct GraphNode + if agent is not None or function is not None or kwargs: + raise ValueError( + "When passing a GraphNode, do not specify agent, function, or" + " kwargs" + ) + if node.name in self.nodes: + raise ValueError(f"Node '{node.name}' already exists in graph") + self._validate_node_configuration(node) + self.nodes[node.name] = node + elif isinstance(node, str): + # Pattern 2: Convenience (name + agent/function) + if agent is None and function is None: + raise ValueError( + "When passing node name as string, must specify agent or function" + ) + if agent is not None and function is not None: + raise ValueError("Cannot specify both agent and function") + + if node in self.nodes: + raise ValueError(f"Node '{node}' already exists in graph") + + graph_node = GraphNode( + name=node, agent=agent, function=function, **kwargs + ) + self._validate_node_configuration(graph_node) + self.nodes[node] = graph_node + else: + raise TypeError( + f"node must be GraphNode or str, got {type(node).__name__}" + ) + + # Register statically-known agents from the node in sub_agents + graph_node = self.nodes[node.name if isinstance(node, GraphNode) else node] + self._register_node_agents(graph_node) + + return self + + def _get_node_agent(self, node: 
"GraphNode") -> Optional[BaseAgent]: + """Extract the primary agent from a graph node, including pattern nodes. + + Args: + node: GraphNode (or pattern subclass) to inspect + + Returns: + The agent associated with this node, or None + """ + if node.agent is not None: + return node.agent + from .patterns import DynamicNode + from .patterns import NestedGraphNode + + if isinstance(node, NestedGraphNode): + return node.graph_agent + if isinstance(node, DynamicNode): + return node.fallback_agent + return None + + def _register_node_agents(self, node: "GraphNode") -> None: + """Register statically-known agents from a node into sub_agents. + + Handles regular agent nodes and pattern nodes (NestedGraphNode, + DynamicNode). Skips DynamicParallelGroup (runtime-only agents). + + Args: + node: GraphNode to extract agents from + + Raises: + ValueError: If agent name collides with graph name, duplicates + an existing sub_agent name, or agent already has a parent. + """ + agent = self._get_node_agent(node) + if agent is None: + return + + # Identity check: same instance already registered (e.g., shared across nodes) + if any(agent is sa for sa in self.sub_agents): + return + + # Reject agent name that shadows the graph itself β€” find_agent() would + # return the graph instead of the sub-agent, causing silent bugs. + if agent.name == self.name: + raise ValueError( + f"Node agent name '{agent.name}' collides with GraphAgent name." + " Rename the agent to avoid find_agent() ambiguity." + ) + + # Validate unique name among existing sub_agents + for sa in self.sub_agents: + if sa.name == agent.name: + raise ValueError( + f"Duplicate sub_agent name '{agent.name}'. Another node already" + " registered an agent with this name." 
+ ) + + # Single-parent constraint (matches BaseAgent behavior) + existing_parent = getattr(agent, "parent_agent", None) + if existing_parent is not None: + raise ValueError( + f"Agent '{agent.name}' already has a parent agent" + f" '{existing_parent.name}', cannot add to '{self.name}'" + ) + + agent.parent_agent = self + self.sub_agents.append(agent) + + @override + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Find agent by name, searching sub_agents then graph nodes as fallback. + + Overrides BaseAgent.find_sub_agent to also search graph node agents + that may not be in sub_agents (e.g., agents added to nodes before + registration). NestedGraphNode agents are recursively searched. + + Note: DynamicNode runtime-selected agents (chosen by agent_selector + at execution time) are NOT discoverable via find_sub_agent because + they don't exist until the graph runs. Only DynamicNode.fallback_agent + (if set) is registered and searchable. + + Args: + name: The agent name to find + + Returns: + The matching agent, or None + """ + # Standard sub_agents search first + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + # Fallback: search graph nodes for agents not in sub_agents + for node in self.nodes.values(): + agent = self._get_node_agent(node) + if agent is not None: + if agent.name == name: + return agent + if result := agent.find_agent(name): + return result + return None + + def _validate_node_configuration(self, node: "GraphNode") -> None: + """Validate node configuration before adding to graph. + + Emits warnings for potential misconfiguration issues. 
+ + Args: + node: GraphNode to validate + """ + # Warn if output_schema present but was auto-defaulted + if isinstance(node.agent, LlmAgent): + if node.agent.output_schema and node.agent.output_key: + # Check if it looks like it was auto-defaulted (matches agent name) + if node.agent.output_key == node.agent.name: + logger.warning( + f"Node '{node.name}': Using auto-defaulted" + f" output_key='{node.agent.output_key}'. To silence this warning," + " explicitly set output_key on the LlmAgent." + ) + + def set_start(self, node_name: str) -> "GraphAgent": + """Set the starting node. + + Args: + node_name: Name of the start node + + Returns: + Self for chaining + + Raises: + ValueError: If node not found in graph + """ + if node_name not in self.nodes: + raise ValueError(f"Node {node_name} not found in graph") + self.start_node = node_name + return self + + def set_end(self, node_name: str) -> "GraphAgent": + """Mark a node as an end node. + + Args: + node_name: Name of the end node + + Returns: + Self for chaining + + Raises: + ValueError: If node not found in graph + """ + if node_name not in self.nodes: + raise ValueError(f"Node {node_name} not found in graph") + if node_name not in self.end_nodes: + self.end_nodes.append(node_name) + return self + + def add_edge( + self, + source_node: str, + target_node: str | EdgeCondition, + condition: Optional[Callable[[GraphState], bool]] = None, + priority: Optional[int] = None, + weight: Optional[float] = None, + ) -> "GraphAgent": + """Add an edge from source node to target node. + + Supports two usage patterns: + 1. Pass EdgeCondition as target_node (explicit) + 2. 
Pass target node name with optional params (convenience) + + Advanced routing features: + - Priority-based routing (higher priority evaluated first) + - Weighted random selection (probabilistic routing) + - Fallback edges (priority=0 always matches) + + Args: + source_node: Source node name + target_node: Target node name OR EdgeCondition object + condition: Optional condition (ignored if target_node is EdgeCondition) + priority: Optional priority (ignored if target_node is EdgeCondition) + weight: Optional weight (ignored if target_node is EdgeCondition) + + Returns: + Self for chaining + + Raises: + ValueError: If nodes not found in graph + TypeError: If target_node is not str or EdgeCondition + + Examples: + >>> # Pattern 1: EdgeCondition (explicit) + >>> graph.add_edge("validate", EdgeCondition( + ... target_node="process", + ... condition=lambda s: s.data.get("valid"), + ... priority=10 + ... )) + + >>> # Pattern 2: Convenience - simple edge + >>> graph.add_edge("validate", "process") + + >>> # Pattern 2: Convenience - conditional edge + >>> graph.add_edge("validate", "process", condition=lambda s: s.data.get("valid")) + + >>> # Pattern 2: Convenience - priority-based routing + >>> graph.add_edge("check", "critical", condition=lambda s: s.data["score"] > 0.9, priority=10) + >>> graph.add_edge("check", "normal", priority=0) # Fallback + + >>> # Pattern 2: Convenience - weighted random routing + >>> graph.add_edge("start", "server_a", condition=lambda s: True, priority=1, weight=0.5) + >>> graph.add_edge("start", "server_b", condition=lambda s: True, priority=1, weight=0.3) + """ + if source_node not in self.nodes: + raise ValueError(f"Source node {source_node} not found") + + if isinstance(target_node, EdgeCondition): + # Pattern 1: EdgeCondition + if condition is not None or priority is not None or weight is not None: + raise ValueError( + "When passing EdgeCondition, do not specify condition, priority, or" + " weight" + ) + if target_node.target_node not in 
self.nodes: + raise ValueError(f"Target node {target_node.target_node} not found") + + # Check for duplicate edge + if ( + hasattr(self.nodes[source_node], "edges") + and self.nodes[source_node].edges is not None + ): + for existing_edge in self.nodes[source_node].edges: + if existing_edge.target_node == target_node.target_node: + raise ValueError( + f"Edge from '{source_node}' to '{target_node.target_node}'" + " already exists. Cannot add duplicate edge." + ) + + self.nodes[source_node].edges.append(target_node) + self.nodes[source_node]._sorted_edges_cache = None + + elif isinstance(target_node, str): + # Pattern 2: Convenience + if target_node not in self.nodes: + raise ValueError(f"Target node {target_node} not found") + + # Check for duplicate edge + if ( + hasattr(self.nodes[source_node], "edges") + and self.nodes[source_node].edges is not None + ): + for existing_edge in self.nodes[source_node].edges: + if existing_edge.target_node == target_node: + raise ValueError( + f"Edge from '{source_node}' to '{target_node}' already exists." + " Cannot add duplicate edge." + ) + + # If priority or weight specified, create EdgeCondition + if priority is not None or weight is not None: + edge_condition = EdgeCondition( + target_node=target_node, + condition=condition, + priority=priority if priority is not None else 1, + weight=weight if weight is not None else 1.0, + ) + self.nodes[source_node].edges.append(edge_condition) + self.nodes[source_node]._sorted_edges_cache = None + else: + # Simple edge (no priority/weight) + self.nodes[source_node].add_edge(target_node, condition) + else: + raise TypeError( + "target_node must be str or EdgeCondition, got" + f" {type(target_node).__name__}" + ) + + return self + + def add_parallel_group( + self, + group_id: str, + group: "ParallelNodeGroup", + ) -> "GraphAgent": + """Add a parallel node group for concurrent execution. 
+ + Args: + group_id: Unique identifier for the group + group: ParallelNodeGroup configuration + + Returns: + Self for chaining + + Raises: + ValueError: If nodes in group not found + + Example: + >>> from google.adk.agents.graph import ParallelNodeGroup, JoinStrategy + >>> graph.add_parallel_group( + ... "fetch_group", + ... ParallelNodeGroup( + ... nodes=["fetch_user", "fetch_products"], + ... join_strategy=JoinStrategy.WAIT_ALL + ... ) + ... ) + """ + # Validate all nodes exist + for node_name in group.nodes: + if node_name not in self.nodes: + raise ValueError(f"Node {node_name} not found in graph") + + self.parallel_groups[group_id] = group + return self + + def _find_parallel_group(self, node_name: str) -> Optional[Tuple[str, Any]]: + """Find if a node is part of a parallel group. + + Args: + node_name: Node name to check + + Returns: + Tuple of (group_id, ParallelNodeGroup) if found, None otherwise + """ + for group_id, group in self.parallel_groups.items(): + if node_name in group.nodes: + return (group_id, group) + return None + + # Export methods moved to graph_export.py + # rewind_to_node moved to graph_rewind.py + + # Telemetry methods inherited from GraphTelemetryMixin + + async def _execute_node( + self, + node: GraphNode, + state: GraphState, + ctx: InvocationContext, + effective_config: Optional[TelemetryConfig] = None, + output_holder: Optional[Dict[str, Any]] = None, + iteration: int = 0, + ) -> AsyncGenerator[Event, None]: + """Execute a single node and yield events. + + Output is stored in output_holder["output"] for the caller. 
+ + Args: + node: GraphNode to execute + state: Current graph state + ctx: Invocation context + effective_config: Effective telemetry config (merged parent + own) + output_holder: Mutable dict to store node output + iteration: Current iteration number + + Yields: + Events from node execution + """ + # Determine node type + node_type = "function" if node.function else "agent" + start_time = time.time() + + # Create telemetry span for node execution + with graph_tracing.tracer.start_as_current_span( + f"graph_node {node.name}" + ) as span: + # Add attributes with additional_attributes support + attrs = self._get_telemetry_attributes( + { + graph_tracing.GRAPH_NODE_NAME: node.name, + graph_tracing.GRAPH_NODE_TYPE: node_type, + graph_tracing.GRAPH_NODE_ITERATION: iteration, + graph_tracing.GRAPH_AGENT_NAME: self.name, + }, + effective_config=effective_config, + ) + for key, value in attrs.items(): + span.set_attribute(key, value) + + try: + # Map state to node input with telemetry + mapper_start = time.time() + node_input = node.input_mapper(state) + mapper_latency_ms = (time.time() - mapper_start) * 1000 + + # Record input mapper telemetry (check sampling) + if self._should_sample(effective_config=effective_config): + is_default_mapper = ( + node.input_mapper.__name__ == "_default_input_mapper" + ) + graph_tracing.record_mapper( + node_name=node.name, + mapper_type="input", + agent_name=self.name, + latency_ms=mapper_latency_ms, + is_default=is_default_mapper, + ) + + # Execute node (agent or function) + output = "" + if node.agent: + # Create new context with updated user_content for this node + node_content = types.Content( + role="user", parts=[types.Part(text=node_input)] + ) + node_ctx = ctx.model_copy(update={"user_content": node_content}) + + # Execute ADK agent with updated context + # (BaseAgent will create invoke_agent span automatically) + async for event in node.agent.run_async(node_ctx): + # Extract output from final response + if event.content and 
event.content.parts: + output = event.content.parts[0].text or "" + yield event + # ADK resumability: pause when long-running tool detected. + # Function nodes don't set "pause" (they run synchronously, + # no tool calls) so output_holder.get("pause") is always + # falsy for them β€” safe to check unconditionally in caller. + if ctx.should_pause_invocation(event): + if output_holder is not None: + output_holder["output"] = output + output_holder["pause"] = True + return + elif node.function: + # Execute custom function (CRITICAL: no automatic span) + if asyncio.iscoroutinefunction(node.function): + output = await node.function(state, ctx) + else: + output = node.function(state, ctx) + else: # pragma: no cover + # Defensive: This should never happen due to GraphNode validation + raise ValueError(f"Node {node.name} has no agent or function") + + # Store output for caller retrieval + if output_holder is not None: + output_holder["output"] = output + + # Record success metrics (check sampling) + latency_ms = (time.time() - start_time) * 1000 + span.set_attribute("graph.node.success", True) + if self._should_sample(effective_config=effective_config): + graph_tracing.record_node_execution( + node_name=node.name, + node_type=node_type, + agent_name=self.name, + latency_ms=latency_ms, + success=True, + ) + + except Exception as e: + # Record failure metrics (check sampling) + latency_ms = (time.time() - start_time) * 1000 + span.set_attribute("graph.node.success", False) + span.set_attribute("graph.node.error", str(e)) + if self._should_sample(effective_config=effective_config): + graph_tracing.record_node_execution( + node_name=node.name, + node_type=node_type, + agent_name=self.name, + latency_ms=latency_ms, + success=False, + ) + raise + + # _check_interrupt_with_telemetry inherited from GraphInterruptMixin + + def _get_next_node_with_telemetry( + self, + current_node: GraphNode, + state: GraphState, + effective_config: Optional[TelemetryConfig] = None, + ) -> 
Optional[str]: + """Get next node with edge evaluation telemetry. + + Args: + current_node: Current graph node + state: Current graph state + effective_config: Effective telemetry config (merged parent + own) + + Returns: + Name of next node, or None if no edge matches + """ + # Track all condition results for detailed telemetry + condition_results = [] + + # Evaluate each edge with telemetry + for edge in current_node.edges: + start_time = time.time() + + # Create span for edge evaluation + with graph_tracing.tracer.start_as_current_span( + f"edge_condition {edge.target_node}" + ) as span: + # Add attributes with additional_attributes support + attrs = self._get_telemetry_attributes( + { + graph_tracing.GRAPH_EDGE_SOURCE: current_node.name, + graph_tracing.GRAPH_EDGE_TARGET: edge.target_node, + graph_tracing.GRAPH_EDGE_PRIORITY: edge.priority, + graph_tracing.GRAPH_AGENT_NAME: self.name, + }, + effective_config=effective_config, + ) + for key, value in attrs.items(): + span.set_attribute(key, value) + + try: + # Evaluate condition + result = edge.should_route(state) + span.set_attribute( + graph_tracing.GRAPH_EDGE_CONDITION_RESULT, str(result) + ) + + # Track condition result details for debugging + condition_results.append({ + "target_node": edge.target_node, + "condition_matched": result, + "condition_name": getattr(edge.condition, "__name__", ""), + "priority": edge.priority, + }) + + # Record metrics (check sampling) + if self._should_sample(effective_config=effective_config): + latency_ms = (time.time() - start_time) * 1000 + graph_tracing.record_edge_evaluation( + source_node=current_node.name, + target_node=edge.target_node, + agent_name=self.name, + condition_result=result, + latency_ms=latency_ms, + priority=edge.priority, + ) + + except Exception as e: + span.set_attribute("graph.edge.error", str(e)) + span.set_attribute(graph_tracing.GRAPH_EDGE_CONDITION_RESULT, "false") + raise + + # Add detailed condition results to GraphState for debugging + # This 
helps identify routing issues by showing ALL edge evaluations + if condition_results: + state.data["_debug_edge_evaluations"] = { + "source_node": current_node.name, + "evaluations": condition_results, + "timestamp": time.time(), + } + + # Use original get_next_node for routing logic + selected_node = current_node.get_next_node(state) + + # Log final node selection decision with all context + if selected_node: + # Find which edge was selected (if any) + selected_edge_info = next( + ( + r + for r in condition_results + if r["target_node"] == selected_node and r["condition_matched"] + ), + None, + ) + + # Add node selection to debug info + state.data.setdefault("_debug_node_selections", []).append({ + "from_node": current_node.name, + "to_node": selected_node, + "selected_edge": selected_edge_info, + "num_edges_evaluated": len(condition_results), + "timestamp": time.time(), + }) + + # Log structured selection event + graph_tracing.logger.debug( + f"Node selected: {current_node.name} -> {selected_node}", + extra={ + "source_node": current_node.name, + "selected_node": selected_node, + "condition_name": ( + selected_edge_info["condition_name"] + if selected_edge_info + else None + ), + "priority": ( + selected_edge_info["priority"] if selected_edge_info else None + ), + "edges_evaluated": len(condition_results), + "agent_name": self.name, + }, + ) + + return selected_node + + def _get_resume_state( + self, agent_state: GraphAgentState + ) -> Tuple[Optional[str], int, bool]: + """Get resume point from loaded agent state. + + Mirrors SequentialAgent._get_start_index() pattern. 
+ + Args: + agent_state: Loaded agent state (may have current_node from prior run) + + Returns: + Tuple of (start_node_name, start_iteration, is_resuming) + """ + if agent_state.current_node and agent_state.current_node in self.nodes: + return agent_state.current_node, agent_state.iteration, True + if agent_state.current_node and agent_state.current_node not in self.nodes: + logger.warning( + "Saved node '%s' no longer exists in graph. Restarting from '%s'.", + agent_state.current_node, + self.start_node, + ) + return self.start_node, 0, False + + async def _execute_callback( + self, + callback: Callable[..., Any], + callback_type: str, + current_node: GraphNode, + current_node_name: str, + state: GraphState, + iteration: int, + ctx: InvocationContext, + agent_state: GraphAgentState, + effective_config: Optional["TelemetryConfig"] = None, + output: str = "", + ) -> Optional[Event]: + """Execute a node callback (before_node or after_node) with telemetry. + + Args: + callback: The callback function to execute + callback_type: "before_node" or "after_node" + current_node: The current GraphNode + current_node_name: Name of the current node + state: Current graph state + iteration: Current iteration number + ctx: Invocation context + agent_state: Execution tracking state + effective_config: Effective telemetry config + output: Node output (only for after_node callbacks) + + Returns: + Event from callback, or None + """ + from .callbacks import NodeCallbackContext + + metadata: Dict[str, Any] = { + "agent_path": list(agent_state.agent_path), + "path": list(agent_state.path), + } + if callback_type == "after_node": + metadata["output"] = output + + callback_ctx = NodeCallbackContext( + node=current_node, + state=state, + iteration=iteration, + invocation_context=ctx, + metadata=metadata, + ) + + callback_start_time = time.time() + with graph_tracing.tracer.start_as_current_span( + f"graph_callback {callback_type}" + ) as cb_span: + attrs = self._get_telemetry_attributes( 
+ { + graph_tracing.GRAPH_CALLBACK_TYPE: callback_type, + graph_tracing.GRAPH_AGENT_NAME: self.name, + graph_tracing.GRAPH_NODE_NAME: current_node_name, + }, + effective_config=effective_config, + ) + for key, value in attrs.items(): + cb_span.set_attribute(key, value) + + try: + event = await callback(callback_ctx) + cb_span.set_attribute("graph.callback.success", True) + if self._should_sample(effective_config=effective_config): + callback_latency_ms = (time.time() - callback_start_time) * 1000 + graph_tracing.record_callback_execution( + callback_type=callback_type, + agent_name=self.name, + latency_ms=callback_latency_ms, + success=True, + ) + return event + + except Exception as e: + cb_span.set_attribute("graph.callback.success", False) + cb_span.set_attribute("graph.callback.error", str(e)) + if self._should_sample(effective_config=effective_config): + callback_latency_ms = (time.time() - callback_start_time) * 1000 + graph_tracing.record_callback_execution( + callback_type=callback_type, + agent_name=self.name, + latency_ms=callback_latency_ms, + success=False, + ) + logger.error( + "%s_callback failed for node '%s': %s", + callback_type, + current_node_name, + e, + exc_info=True, + ) + return None + + def _sync_state_and_reduce( + self, + current_node: GraphNode, + current_node_name: str, + state: GraphState, + ctx: InvocationContext, + output: str, + effective_config: Optional["TelemetryConfig"] = None, + agent_state: Optional[GraphAgentState] = None, + ) -> GraphState: + """Sync session state into GraphState and apply output_mapper + reducer. 
+ + Args: + current_node: The current GraphNode + current_node_name: Name of the current node + state: Current graph state (mutated in-place for session sync) + ctx: Invocation context + output: Node output string + effective_config: Effective telemetry config + agent_state: Execution tracking state for output key tracking + + Returns: + Updated GraphState after sync and reduction + """ + # Sync session state into GraphState.data + for _sk, _sv in ctx.session.state.items(): + if not _sk.startswith("_") and _sk not in _GRAPH_INTERNAL_KEYS: + state.data[_sk] = _sv + + # Apply output mapper with reducer + if output: + had_previous_value = current_node.name in state.data + reducer_start = time.time() + + # Snapshot keys before output_mapper to track what gets written + keys_before = set(state.data.keys()) + + prev_state = state + state = current_node.output_mapper(output, state) + if state is None: + state = prev_state + + # Track which keys were written by this node's output_mapper + if agent_state is not None: + keys_after = set(state.data.keys()) + written_keys = list(keys_after - keys_before) + # Also include the node name key if it was overwritten + if ( + current_node_name in state.data + and current_node_name not in written_keys + ): + written_keys.append(current_node_name) + agent_state.output_keys[current_node_name] = written_keys + + reducer_latency_ms = (time.time() - reducer_start) * 1000 + if self._should_sample(effective_config=effective_config): + graph_tracing.record_state_reducer( + node_name=current_node.name, + reducer_type=current_node.reducer.name, + state_key=current_node.name, + agent_name=self.name, + latency_ms=reducer_latency_ms, + had_previous_value=had_previous_value, + ) + is_default_mapper = ( + current_node.output_mapper.__name__ == "_default_output_mapper" + ) + graph_tracing.record_mapper( + node_name=current_node.name, + mapper_type="output", + agent_name=self.name, + latency_ms=reducer_latency_ms, + is_default=is_default_mapper, + ) + 
+ return state + + def _build_cancellation_events( + self, + ctx: InvocationContext, + agent_state: GraphAgentState, + current_node_name: str, + state: GraphState, + *, + state_key: str = "graph_cancelled", + message: str, + iteration: Optional[int] = None, + partial_output: Optional[str] = None, + path: Optional[List[str]] = None, + ) -> List[Event]: + """Build agent-state + cancellation events for graph abort scenarios. + + Consolidates the repeated pattern of saving agent state then yielding + a cancellation event with appropriate state_delta keys. + + Args: + ctx: Invocation context + agent_state: Execution tracking state (saved before cancel) + current_node_name: Node where cancellation occurred + state: Current graph state + state_key: Key for the cancellation flag (e.g. "graph_cancelled", + "graph_task_cancelled") + message: Human-readable cancellation message + iteration: Current iteration (included in state_delta when set) + partial_output: Partial node output (included when set) + path: Execution path (included in state_delta when set) + + Returns: + List of two events: [agent_state_event, cancellation_event] + """ + ctx.set_agent_state(self.name, agent_state=agent_state) + state_event = self._create_agent_state_event(ctx) + + state_delta: Dict[str, Any] = { + state_key: True, + "graph_cancelled_at_node": current_node_name, + "graph_data": state.data, + "graph_can_resume": True, + } + if iteration is not None: + state_delta["graph_iteration"] = iteration + if partial_output is not None: + state_delta["graph_partial_output"] = partial_output + if path is not None: + state_delta["graph_path"] = path + + cancel_event = Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"\u26a0\ufe0f {message}")] + ), + actions=EventActions( + escalate=False, + state_delta=state_delta, + ), + ) + return [state_event, cancel_event] + + async def _execute_parallel_phase( + self, + group_id: str, + parallel_group: "ParallelNodeGroup", + current_node: 
GraphNode, + current_node_name: str, + state: GraphState, + ctx: InvocationContext, + effective_config: Optional["TelemetryConfig"], + agent_state: GraphAgentState, + executed_parallel_groups: set[str], + result: Dict[str, Any], + ) -> AsyncGenerator[Event, None]: + """Execute a parallel node group phase with telemetry. + + Handles already-executed check, telemetry instrumentation, + parallel execution, group marking, and next-node routing. + Sets result["next"] to the next node name (or None if at end node). + + Args: + group_id: Parallel group identifier + parallel_group: ParallelNodeGroup configuration + current_node: Current GraphNode (for edge routing) + current_node_name: Name of current node + state: Current graph state + ctx: Invocation context + effective_config: Telemetry config + agent_state: Execution tracking state + executed_parallel_groups: Set of already-executed group IDs (mutated) + result: Mutable dict; sets result["next"] to next node name or None + + Yields: + Events from parallel execution + + Raises: + ValueError: If parallel group has no outgoing edges and node is not + an end node + """ + # Check if this group has already been executed + if group_id in executed_parallel_groups: + logger.info( + f"Skipping node '{current_node_name}' - already executed as" + f" part of parallel group '{group_id}'" + ) + next_node_name = self._get_next_node_with_telemetry( + current_node, state, effective_config=effective_config + ) + if next_node_name is None: + if current_node_name in self.end_nodes: + result["next"] = None + return + else: + raise ValueError( + f"Node {current_node_name} has no outgoing edges and is" + " not an end node" + ) + result["next"] = next_node_name + return + + # Execute entire parallel group + logger.info( + f"Executing parallel group '{group_id}' with nodes:" + f" {parallel_group.nodes}" + ) + + parallel_start_time = time.time() + with graph_tracing.tracer.start_as_current_span( + f"parallel_group {group_id}" + ) as pg_span: + 
attrs = self._get_telemetry_attributes( + { + graph_tracing.GRAPH_PARALLEL_NODE_COUNT: len( + parallel_group.nodes + ), + graph_tracing.GRAPH_PARALLEL_STRATEGY: ( + parallel_group.join_strategy.value + ), + graph_tracing.GRAPH_PARALLEL_WAIT_N: parallel_group.wait_n, + graph_tracing.GRAPH_AGENT_NAME: self.name, + }, + effective_config=effective_config, + ) + for key, value in attrs.items(): + pg_span.set_attribute(key, value) + + completed_count = 0 + async for event in execute_parallel_group( + parallel_group, + self.nodes, + state, + ctx, + self._execute_node, + ): + yield event + if event.author != self.name: + completed_count = min(completed_count + 1, len(parallel_group.nodes)) + + pg_span.set_attribute("graph.parallel.completed_count", completed_count) + if self._should_sample(effective_config=effective_config): + parallel_latency_ms = (time.time() - parallel_start_time) * 1000 + graph_tracing.record_parallel_group_execution( + agent_name=self.name, + node_count=len(parallel_group.nodes), + strategy=parallel_group.join_strategy.value, + latency_ms=parallel_latency_ms, + completed_count=completed_count, + ) + + # Mark group as executed + executed_parallel_groups.add(group_id) + agent_state.executed_parallel_groups = list(executed_parallel_groups) + + # Route to next node + next_node_name = self._get_next_node_with_telemetry( + current_node, state, effective_config=effective_config + ) + if next_node_name is None: + if current_node_name in self.end_nodes: + result["next"] = None + return + else: + raise ValueError( + f"Parallel group '{group_id}' has no outgoing edges and" + f" node '{current_node_name}' is not an end node" + ) + result["next"] = next_node_name + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core graph execution logic. + + Executes nodes in graph order, following conditional edges, + supporting loops and human-in-the-loop interrupts. 
+
+    Args:
+      ctx: Invocation context
+
+    Yields:
+      Events from graph execution
+
+    Raises:
+      ValueError: If start node not set or graph structure invalid
+    """
+    if not self.start_node:
+      raise ValueError("Start node not set. Call set_start() first.")
+
+    # Register session with InterruptService if enabled
+    if self.interrupt_service:
+      self.interrupt_service.register_session(ctx.session.id)
+
+    # Get effective telemetry config for nested graph inheritance
+    effective_config = self._get_effective_telemetry_config(ctx)
+
+    with tracer.start_as_current_span(
+        f"graph_agent_execution {self.name}"
+    ) as span:
+      span.set_attribute("graph_agent.name", self.name)
+      span.set_attribute("graph_agent.start_node", self.start_node)
+      span.set_attribute("graph_agent.max_iterations", self.max_iterations)
+      try:
+        # Load execution tracking state (BaseAgentState pattern)
+        agent_state = (
+            self._load_agent_state(ctx, GraphAgentState) or GraphAgentState()
+        )
+
+        # Store telemetry config for nested graph inheritance
+        if effective_config:
+          agent_state.telemetry_config_dict = effective_config.model_dump()
+
+        # Initialize domain data from session state or user input.
+        # Exclude graph-internal keys to prevent circular references
+        # (state.data["graph_data"] → state.data) and keep domain data clean.
+        domain_data = {
+            k: v
+            for k, v in ctx.session.state.items()
+            if not k.startswith("_") and k not in _GRAPH_INTERNAL_KEYS
+        }
+        if domain_data:
+          state = GraphState(data=domain_data)
+        else:
+          # Extract text from Content object
+          user_text = ""
+          if (
+              hasattr(ctx, "user_content")
+              and ctx.user_content
+              and ctx.user_content.parts
+          ):
+            user_text = (
+                ctx.user_content.parts[0].text
+                if ctx.user_content.parts[0].text
+                else ""
+            )
+          state = GraphState(data={"input": user_text})
+
+        # Track which parallel groups have been executed
+        executed_parallel_groups = set(agent_state.executed_parallel_groups)
+
+        # ADK resumability: resume from saved node or start fresh.
+ # + # Design note: SequentialAgent ONLY emits state events when + # ctx.is_resumable is True, because its state events serve only + # resumability. GraphAgent's state events serve multiple consumers + # (rewind, interrupts, telemetry) that are orthogonal to + # resumability. Therefore: + # - Per-iteration state events: always emitted (multi-consumer) + # - Resume skip: first iteration skipped when resuming (already + # persisted before pause, avoids duplicate) + # - end_of_agent: guarded by is_resumable (purely a resumability + # lifecycle signal, has no other consumers) + # - Interrupt/cancellation state saves: always emitted (they + # serve interrupt functionality, not just resumability) + current_node_name, iteration, resuming = self._get_resume_state( + agent_state + ) + pause_invocation = False + + while current_node_name and iteration < self.max_iterations: + iteration += 1 + current_node = self.nodes[current_node_name] + + # Check for immediate cancellation (ESC-like interrupt) + # Allows user to abort execution at any time, not just at pause points + if self.interrupt_service and not self.interrupt_service.is_active( + ctx.session.id + ): + logger.info( + "GraphAgent execution cancelled (immediate interrupt) for" + f" session {ctx.session.id}" + ) + for _ce in self._build_cancellation_events( + ctx, + agent_state, + current_node_name, + state, + message="Execution cancelled by user", + iteration=iteration, + path=list(agent_state.path), + ): + yield _ce + break # Exit immediately but state is saved + + # Track execution path in agent_state + agent_state.path.append(current_node_name) + agent_state.iteration = iteration + agent_state.current_node = current_node_name + agent_state.node_invocations.setdefault(current_node_name, []).append( + ctx.invocation_id + ) + + # ADK resumability: reset sub-agent states on cycle revisit + # (mirrors LoopAgent pattern at loop_agent.py:114) + # O(1) lookup via node_invocations instead of O(N) path.count() + if 
len(agent_state.node_invocations.get(current_node_name, [])) > 1: + if current_node.agent: + ctx.reset_sub_agent_states(current_node.agent.name) + # Reset parallel group tracking so groups can re-execute + # in cyclic workflows + executed_parallel_groups.clear() + agent_state.executed_parallel_groups = [] + + # Track agent path for nested graph support + if self.name not in agent_state.agent_path: + agent_state.agent_path.append(self.name) + + # Persist execution tracking via agent_state event. + # These events are consumed by rewind, interrupts, and telemetry + # (not just resumability), so they're always emitted. + # Skip only on first iteration when resuming (already persisted). + if not resuming: + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + else: + resuming = False # Only skip first iteration after resume + + # Handle BEFORE-node interrupt (validation timing) + if ( + self._should_interrupt_before(current_node_name) + and self.interrupt_service + ): + _b_events, _b_ctrl = await self._handle_before_node_interrupt( + current_node_name, current_node, state, ctx, agent_state + ) + for _e in _b_events: + yield _e + # Persist agent_state after interrupt handler may have mutated it + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + if _b_ctrl == "break": + break + elif _b_ctrl is not None: + if isinstance(_b_ctrl, tuple): + current_node_name = _b_ctrl[1] + continue + + # Check if current node is part of a parallel group + parallel_group_info = self._find_parallel_group(current_node_name) + if parallel_group_info: + group_id, parallel_group = parallel_group_info + _pg_result: Dict[str, Any] = {} + async for event in self._execute_parallel_phase( + group_id, + parallel_group, + current_node, + current_node_name, + state, + ctx, + effective_config, + agent_state, + executed_parallel_groups, + _pg_result, + ): + yield event + # Fire after_node_callback for 
parallel trigger node + if self.after_node_callback: + event = await self._execute_callback( + self.after_node_callback, + "after_node", + current_node, + current_node_name, + state, + iteration, + ctx, + agent_state, + effective_config, + ) + if event: + yield event + current_node_name = _pg_result.get("next") + if current_node_name is None: + break + continue + + # Invoke before_node_callback (custom observability) + if self.before_node_callback: + event = await self._execute_callback( + self.before_node_callback, + "before_node", + current_node, + current_node_name, + state, + iteration, + ctx, + agent_state, + effective_config, + ) + if event: + yield event + + # Execute node with immediate cancellation support + # Check cancellation while streaming events from node execution + output_holder: Dict[str, Any] = {"output": ""} + try: + async for event in self._execute_node( + current_node, + state, + ctx, + effective_config, + output_holder=output_holder, + iteration=iteration, + ): + # Check for immediate cancellation DURING node execution + if ( + self.interrupt_service + and not self.interrupt_service.is_active(ctx.session.id) + ): + logger.info( + "GraphAgent execution cancelled (immediate interrupt" + f" during node '{current_node_name}') for session" + f" {ctx.session.id}" + ) + for _ce in self._build_cancellation_events( + ctx, + agent_state, + current_node_name, + state, + message=( + f"Execution cancelled during node '{current_node_name}'" + ), + partial_output=output_holder["output"], + ): + yield _ce + return + yield event + except asyncio.CancelledError: + # Task cancelled externally (e.g., timeout, user abort) + logger.info( + f"GraphAgent task cancelled during node '{current_node_name}'" + f" for session {ctx.session.id}" + ) + for _ce in self._build_cancellation_events( + ctx, + agent_state, + current_node_name, + state, + state_key="graph_task_cancelled", + message=f"Task cancelled during node '{current_node_name}'", + 
partial_output=output_holder["output"], + ): + yield _ce + raise + + # ADK resumability: check if node execution was paused + if output_holder.get("pause"): + pause_invocation = True + return + + # Sync session state + apply output_mapper/reducer + output = output_holder["output"] + state = self._sync_state_and_reduce( + current_node, + current_node_name, + state, + ctx, + output, + effective_config, + agent_state=agent_state, + ) + + # Emit output_mapper changes as state_delta so domain data + # flows through ADK's event pipeline to session.state. + # This enables downstream LlmAgent nodes to read + # output_mapper results via dynamic instructions. + if output: + delta = {} + for _k, _v in state.data.items(): + if ( + not _k.startswith("_") + and _k not in _GRAPH_INTERNAL_KEYS + and ctx.session.state.get(_k) != _v + ): + delta[_k] = _v + if delta: + yield Event( + author=self.name, + actions=EventActions(state_delta=delta), + ) + + # Invoke after_node_callback (custom observability) + if self.after_node_callback: + event = await self._execute_callback( + self.after_node_callback, + "after_node", + current_node, + current_node_name, + state, + iteration, + ctx, + agent_state, + effective_config, + output=output, + ) + if event: + yield event + + # Emit graph metadata event for evaluation framework + # This will be captured in Invocation.intermediate_data by EvaluationGenerator + # Set partial=True so is_final_response() returns False (making it an intermediate event) + graph_metadata = { + "graph_node": current_node_name, + "graph_iteration": iteration, + "graph_path": list(agent_state.path), + "node_invocations": { + name: len(invocs) + for name, invocs in agent_state.node_invocations.items() + }, + "graph_state": dict(state.data), + } + yield Event( + author=f"{self.name}#metadata", + content=types.Content( + parts=[types.Part(text=f"[GraphMetadata] {graph_metadata}")] + ), + partial=True, # Mark as intermediate event + ) + + # Handle AFTER-node interrupt 
(retrospective feedback timing) + # This enables retrospective feedback: observe past, steer future + if ( + self._should_interrupt_after(current_node_name) + and self.interrupt_service + ): + _a_events, _a_ctrl = await self._handle_after_node_interrupt( + current_node_name, state, ctx, agent_state + ) + for _e in _a_events: + yield _e + # Persist agent_state after interrupt handler may have mutated it + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + if _a_ctrl == "break": + break + elif _a_ctrl is not None: + if isinstance(_a_ctrl, tuple): + current_node_name = _a_ctrl[1] + continue + + # Checkpointing - yield event with state_delta to persist checkpoint + # Note: For full checkpoint/resume functionality, use CheckpointCallback + if self.checkpointing: + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) + yield Event( + author=self.name, + content=types.Content( + parts=[types.Part(text=f"Checkpoint: {current_node_name}")] + ), + actions=EventActions( + state_delta={ + "graph_data": state.data, + "graph_checkpoint": { + "node": current_node_name, + "iteration": iteration, + }, + } + ), + ) + + # Inject transient execution data for edge conditions + state.data["_graph_iteration"] = agent_state.iteration + state.data["_graph_path"] = list(agent_state.path) + state.data["_conditions"] = dict(agent_state.conditions) + + # Get next node via conditional routing + next_node_name = self._get_next_node_with_telemetry( + current_node, state, effective_config=effective_config + ) + + # Clean up transient keys + for _tk in ("_graph_iteration", "_graph_path", "_conditions"): + state.data.pop(_tk, None) + if next_node_name is None: + # No more edges - check if we're at an end node + if current_node_name in self.end_nodes: + break + else: + # Not at an end node and no edges - error + raise ValueError( + f"Node {current_node_name} has no outgoing edges and is not" + " an end 
node" + ) + + current_node_name = next_node_name + + # Record iteration metrics (check sampling) + if self._should_sample(effective_config=effective_config): + graph_tracing.record_graph_iteration( + agent_name=self.name, + iteration=iteration, + path_length=len(agent_state.path), + ) + + # ADK resumability: skip final response + end_of_agent when paused + if not pause_invocation: + # Final response - yield event with graph metadata + # Include last node's output ONLY if: + # 1. explicit final_output is set, OR + # 2. last node was a function (doesn't yield events, so we need to show output) + # Don't include output for agent nodes (they already yielded their output) + final_output = state.data.get("final_output", "") + + # If no explicit final_output, check if last node was a function + if not final_output and current_node_name: + last_node = self.nodes.get(current_node_name) + if last_node and last_node.function: + # Function node - include its output + final_output = state.data.get(current_node_name, "") + + response_text = f"{final_output}" + + yield Event( + author=self.name, + content=types.Content(parts=[types.Part(text=response_text)]), + actions=EventActions( + state_delta={ + "graph_data": state.data, + "graph_iterations": iteration, + "graph_path": list(agent_state.path), + } + ), + ) + # end_of_agent is guarded by is_resumable because it is purely a + # resumability lifecycle signal (tells the runner "this agent is + # done, don't re-run it on resume"). Unlike per-iteration state + # events which serve rewind/interrupts/telemetry, end_of_agent + # has no other consumers. 
+ if ctx.is_resumable: + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + + finally: + # Unregister session from InterruptService and finalize tracing + if self.interrupt_service: + self.interrupt_service.unregister_session(ctx.session.id) + span.set_attribute("graph_agent.completed", True) + + # Interrupt methods inherited from GraphInterruptMixin + + @override + @classmethod + def _parse_config( + cls, + config: Any, # GraphAgentConfig + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parse GraphAgentConfig and return kwargs for GraphAgent constructor. + + Args: + config: GraphAgentConfig instance + config_abs_path: Absolute path to config file + kwargs: Base kwargs from BaseAgent + + Returns: + Updated kwargs with graph-specific configuration + """ + from .graph_agent_config import GraphAgentConfig + + if not isinstance(config, GraphAgentConfig): + return kwargs + + # Basic graph config + if hasattr(config, "start_node") and config.start_node: + kwargs["start_node"] = config.start_node + + if hasattr(config, "max_iterations") and config.max_iterations: + kwargs["max_iterations"] = config.max_iterations + + if hasattr(config, "checkpointing"): + kwargs["checkpointing"] = config.checkpointing + + # Interrupt configuration + if hasattr(config, "interrupt_config") and config.interrupt_config: + from .interrupt import InterruptConfig + from .interrupt import InterruptMode + + interrupt_cfg = config.interrupt_config + if interrupt_cfg.mode: # None = disabled, only process if mode is set + mode = InterruptMode(interrupt_cfg.mode) + kwargs["interrupt_config"] = InterruptConfig(mode=mode) + + # Callbacks + from ..config_agent_utils import resolve_code_reference + + if ( + hasattr(config, "before_node_callback_ref") + and config.before_node_callback_ref + ): + kwargs["before_node_callback"] = resolve_code_reference( + config.before_node_callback_ref + ) + + if ( + hasattr(config, 
"after_node_callback_ref") + and config.after_node_callback_ref + ): + kwargs["after_node_callback"] = resolve_code_reference( + config.after_node_callback_ref + ) + + if ( + hasattr(config, "on_edge_condition_callback_ref") + and config.on_edge_condition_callback_ref + ): + kwargs["on_edge_condition_callback"] = resolve_code_reference( + config.on_edge_condition_callback_ref + ) + + return kwargs + + @override + @classmethod + def from_config( + cls, + config: Any, # GraphAgentConfig + config_abs_path: str, + ) -> "GraphAgent": + """Creates a GraphAgent from a YAML config. + + This method performs post-construction setup to add nodes, edges, + end nodes, and parallel groups from the config. + + Args: + config: GraphAgentConfig instance + config_abs_path: Absolute path to config file + + Returns: + Configured GraphAgent instance + """ + from ..config_agent_utils import resolve_agent_reference + from ..config_agent_utils import resolve_code_reference + from .graph_agent_config import GraphAgentConfig + + # Create base agent instance + graph_instance = super().from_config(config, config_abs_path) + + # Type assertion: we know this is a GraphAgent because cls is GraphAgent + assert isinstance( + graph_instance, cls + ), "Expected GraphAgent instance from super().from_config()" + graph: GraphAgent = graph_instance # type: ignore[assignment] + + if not isinstance(config, GraphAgentConfig): + return graph + + # Add nodes + if hasattr(config, "nodes") and config.nodes: + for node_config in config.nodes: + # Resolve sub-agents for this node + sub_agents = [] + if node_config.sub_agents: + for agent_ref in node_config.sub_agents: + agent = resolve_agent_reference(agent_ref, config_abs_path) + sub_agents.append(agent) + + # Resolve function ref + function = None + if node_config.function_ref: + function = resolve_code_reference(node_config.function_ref) + + # Create GraphNode + node = GraphNode( + name=node_config.name, + agent=sub_agents[0] if sub_agents else None, + 
function=function, + ) + graph.add_node(node) + + # Add edges + if hasattr(config, "edges") and config.edges: + from .graph_edge import EdgeCondition + + for edge_config in config.edges: + condition = None + if edge_config.condition: + # Parse string condition to callable + condition = _parse_condition_string(edge_config.condition) + + # Create EdgeCondition with priority and weight support + edge = EdgeCondition( + target_node=edge_config.target_node, + condition=condition, + priority=edge_config.priority, + weight=edge_config.weight, + ) + + # Add edge directly to the node's edges list + if edge_config.source_node in graph.nodes: + graph.nodes[edge_config.source_node].edges.append(edge) + graph.nodes[edge_config.source_node]._sorted_edges_cache = None + else: + raise ValueError( + f"Source node {edge_config.source_node} not found in graph" + ) + + # Set start node + if hasattr(config, "start_node") and config.start_node: + graph.set_start(config.start_node) + + # Set end nodes + if hasattr(config, "end_nodes") and config.end_nodes: + for end_node in config.end_nodes: + graph.set_end(end_node) + + # Add parallel groups + if hasattr(config, "parallel_groups") and config.parallel_groups: + from .parallel import ErrorPolicy + from .parallel import JoinStrategy + from .parallel import ParallelNodeGroup + + for pg_config in config.parallel_groups: + join_strategy = JoinStrategy(pg_config.join_strategy) + error_policy = ErrorPolicy(pg_config.error_policy) + + parallel_group = ParallelNodeGroup( + nodes=pg_config.nodes, + join_strategy=join_strategy, + error_policy=error_policy, + wait_n=pg_config.wait_n, + ) + # Store parallel group (keyed by first node name for simplicity) + graph.parallel_groups[pg_config.nodes[0]] = parallel_group + + return graph diff --git a/src/google/adk/agents/graph/graph_agent_config.py b/src/google/adk/agents/graph/graph_agent_config.py new file mode 100644 index 0000000000..3d44d51cf6 --- /dev/null +++ 
b/src/google/adk/agents/graph/graph_agent_config.py @@ -0,0 +1,307 @@ +"""Config definition for GraphAgent.""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional + +from pydantic import BaseModel # type: ignore[attr-defined] +from pydantic import ConfigDict +from pydantic import Field + +from ...utils.feature_decorator import experimental +from ..base_agent_config import BaseAgentConfig # type: ignore[attr-defined] +from ..common_configs import AgentRefConfig +from ..common_configs import CodeConfig + + +@experimental +class GraphNodeConfig(BaseModel): # type: ignore[misc] + """Configuration for a single node in the graph. + + A node can contain either an agent reference or a function reference, + plus optional mappers and reducers for state management. + """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Node name") + + # Node can reference an agent (sub_agents) OR a function + sub_agents: Optional[List[AgentRefConfig]] = Field( + default=None, + description="Sub-agents for this node", + ) + + function_ref: Optional[str] = Field( + default=None, + description="Reference to a function (e.g., 'module.function_name')", + ) + + input_mapper_ref: Optional[str] = Field( + default=None, + description="Reference to custom input mapper function", + ) + + output_mapper_ref: Optional[str] = Field( + default=None, + description="Reference to custom output mapper function", + ) + + reducer: str = Field( + default="overwrite", + description="State reducer strategy: overwrite|append|sum|custom", + ) + + custom_reducer_ref: Optional[str] = Field( + default=None, + description="Reference to custom reducer function (if reducer=custom)", + ) + + +@experimental +class GraphEdgeConfig(BaseModel): # type: ignore[misc] + """Configuration for an edge between nodes. + + Edges can have optional conditions for conditional routing. 
+ """ + + model_config = ConfigDict(extra="forbid") + + source_node: str = Field(description="Source node name") + + target_node: str = Field(description="Target node name") + + condition: Optional[str] = Field( + default=None, + description=( + "AST-safe condition expression evaluated against graph state." + " Allowed names: state, data, metadata, True, False, None." + " Allowed methods: .get(), .get_parsed(), .get_str()," + " .get_dict(). Supports comparisons, boolean ops, 'in'." + " Example: \"data.get('approved') is True\"" + ), + ) + + priority: int = Field( + default=1, + description="Edge priority for routing (higher = evaluated first)", + ) + + weight: float = Field( + default=1.0, + description="Edge weight for weighted random routing", + ) + + +@experimental +class InterruptConfigYaml(BaseModel): # type: ignore[misc] + """Configuration for interrupt handling.""" + + model_config = ConfigDict(extra="forbid") + + mode: Optional[Literal["before", "after", "both"]] = Field( + default=None, + description="Interrupt mode (None = disabled, before|after|both)", + ) + + interrupt_service: Optional[CodeConfig] = Field( + default=None, + description="Interrupt service configuration (CodeConfig)", + ) + + +@experimental +class ParallelGroupConfig(BaseModel): # type: ignore[misc] + """Configuration for parallel node execution.""" + + model_config = ConfigDict(extra="forbid") + + nodes: List[str] = Field( + description="List of node names to execute in parallel" + ) + + join_strategy: str = Field( + default="all", + description="Join strategy: all|any|n", + ) + + error_policy: str = Field( + default="fail_fast", + description="Error policy: fail_fast|continue|collect", + ) + + wait_n: int = Field( + default=1, + description="Number of nodes to wait for (when join_strategy=n)", + ) + + +@experimental +class TelemetryConfig(BaseModel): # type: ignore[misc] + """Configuration for GraphAgent telemetry. 
+ + Controls OpenTelemetry instrumentation for graph workflow execution. + """ + + model_config = ConfigDict(extra="forbid") + + enabled: bool = Field( + default=True, + description="Enable/disable all telemetry collection", + ) + + trace_nodes: bool = Field( + default=True, + description="Enable tracing for node executions", + ) + + trace_edges: bool = Field( + default=True, + description="Enable tracing for edge condition evaluations", + ) + + trace_iterations: bool = Field( + default=True, + description="Enable metrics for graph iterations", + ) + + trace_parallel_groups: bool = Field( + default=True, + description="Enable tracing for parallel group executions", + ) + + trace_callbacks: bool = Field( + default=True, + description="Enable tracing for callback executions", + ) + + trace_interrupts: bool = Field( + default=True, + description="Enable tracing for interrupt checks", + ) + + sampling_rate: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Sampling rate for telemetry (0.0-1.0, 1.0 = 100%)", + ) + + additional_attributes: Optional[Dict[str, Any]] = Field( + default=None, + description="Additional custom attributes to add to all telemetry", + ) + + +@experimental +class GraphAgentConfig(BaseAgentConfig): # type: ignore[misc] + """The config for the YAML schema of a GraphAgent. + + This config supports defining graph structure, nodes, edges, and + advanced features like interrupts and parallel execution. 
+ + Example YAML: + ```yaml + agent_class: GraphAgent + name: my_graph + description: My graph workflow + start_node: start + end_nodes: + - end + max_iterations: 10 + checkpointing: true + nodes: + - name: start + sub_agents: + - agent1 + - name: middle + sub_agents: + - agent2 + - name: end + sub_agents: + - agent3 + edges: + - source_node: start + target_node: middle + - source_node: middle + target_node: end + ``` + """ + + model_config = ConfigDict(extra="forbid") + + agent_class: str = Field( + default="GraphAgent", + description=( + "The value is used to uniquely identify the GraphAgent class." + ), + ) + + start_node: str = Field(description="Name of the starting node") + + end_nodes: List[str] = Field( + default_factory=list, + description="List of end node names", + ) + + max_iterations: int = Field( + default=20, + description="Maximum iterations for cyclic graphs", + ) + + checkpointing: bool = Field( + default=False, + description="Enable automatic checkpointing", + ) + + checkpoint_service: Optional[CodeConfig] = Field( + default=None, + description="Checkpoint service configuration (CodeConfig)", + ) + + # Graph structure + nodes: List[GraphNodeConfig] = Field( + default_factory=list, + description="List of node configurations", + ) + + edges: List[GraphEdgeConfig] = Field( + default_factory=list, + description="List of edge configurations", + ) + + # Advanced features + interrupt_config: Optional[InterruptConfigYaml] = Field( + default=None, + description="Interrupt configuration", + ) + + telemetry_config: Optional[TelemetryConfig] = Field( + default=None, + description="Telemetry configuration for OpenTelemetry instrumentation", + ) + + parallel_groups: List[ParallelGroupConfig] = Field( + default_factory=list, + description="List of parallel execution group configurations", + ) + + # Callbacks (following ADK CodeConfig pattern) + before_node_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed before 
each node", + ) + + after_node_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed after each node", + ) + + on_edge_condition_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Callbacks executed when evaluating edge conditions", + ) diff --git a/src/google/adk/agents/graph/graph_agent_state.py b/src/google/adk/agents/graph/graph_agent_state.py new file mode 100644 index 0000000000..af2d6dcdda --- /dev/null +++ b/src/google/adk/agents/graph/graph_agent_state.py @@ -0,0 +1,48 @@ +"""Execution tracking state for GraphAgent. + +Follows ADK's BaseAgentState pattern: persisted via +ctx.agent_states / Event.actions.agent_state. + +Domain data (node outputs) remains in GraphState via state_delta. +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from pydantic import Field + +from ...utils.feature_decorator import experimental +from ..base_agent import BaseAgentState + + +@experimental +class GraphAgentState(BaseAgentState): # type: ignore[misc] + """Execution tracking state for GraphAgent. + + Serialized via model_dump(mode='json'), restored via model_validate(). 
+ """ + + current_node: str = "" + prev_node: str = "" + iteration: int = 0 + execution_start: float = 0.0 + + path: List[str] = Field(default_factory=list) + node_invocations: Dict[str, List[str]] = Field(default_factory=dict) + conditions: Dict[str, Any] = Field(default_factory=dict) + rerun_guidance: str = "" + + interrupt_history: List[Dict[str, Any]] = Field(default_factory=list) + interrupt_todos: List[Dict[str, Any]] = Field(default_factory=list) + last_interrupt_decision: Optional[Dict[str, Any]] = None + + telemetry_config_dict: Optional[Dict[str, Any]] = None + + output_keys: Dict[str, List[str]] = Field(default_factory=dict) + + agent_path: List[str] = Field(default_factory=list) + executed_parallel_groups: List[str] = Field(default_factory=list) diff --git a/src/google/adk/agents/graph/graph_edge.py b/src/google/adk/agents/graph/graph_edge.py new file mode 100644 index 0000000000..79517ba1e9 --- /dev/null +++ b/src/google/adk/agents/graph/graph_edge.py @@ -0,0 +1,97 @@ +"""Conditional edges for graph routing.""" + +from __future__ import annotations + +from typing import Callable +from typing import Optional + +from .graph_state import GraphState + + +class EdgeCondition: + """Conditional edge that routes based on state. + + Edges connect nodes in the graph and can have optional conditions, + priorities, and weights for advanced routing strategies. 
+ + Example: + ```python + # Unconditional edge (always taken) + edge = EdgeCondition(target_node="next_node") + + # Conditional edge (taken if score > 0.8) + edge = EdgeCondition( + target_node="high_score_handler", + condition=lambda state: state.data.get("score", 0) > 0.8 + ) + + # Priority-based routing (higher priority evaluated first) + edge1 = EdgeCondition( + target_node="critical_path", + condition=lambda state: state.data.get("is_critical", False), + priority=10 # High priority + ) + edge2 = EdgeCondition( + target_node="normal_path", + priority=5 # Lower priority + ) + + # Weighted random selection (among matching edges) + edge1 = EdgeCondition(target_node="path_a", weight=0.7) # 70% chance + edge2 = EdgeCondition(target_node="path_b", weight=0.3) # 30% chance + + # Fallback edge (priority=0 always matches if no other edge matched) + edge_fallback = EdgeCondition( + target_node="default_handler", + priority=0 # Fallback priority + ) + ``` + """ + + def __init__( + self, + target_node: str, + condition: Optional[Callable[[GraphState], bool]] = None, + priority: int = 1, + weight: float = 1.0, + ): + """Initialize edge condition. + + Args: + target_node: Name of the target node + condition: Function that returns True if this edge should be taken. + If None, edge is always taken (unconditional). + priority: Priority for edge evaluation (higher = evaluated first). + Priority 0 is special: treated as fallback (always matches if reached). + Default is 1 (normal priority). + weight: Weight for weighted random selection among matching edges. + Only used when multiple edges match. Higher weight = higher probability. + Default is 1.0. + """ + self.target_node = target_node + self.has_condition = condition is not None + self.condition = condition or (lambda _: True) + self.priority = priority + self.weight = weight + + def should_route(self, state: GraphState) -> bool: + """Check if this edge should be taken given the current state. 
+ + Args: + state: Current graph state + + Returns: + True if edge condition is satisfied, False otherwise + """ + # Priority 0 is fallback - always matches + if self.priority == 0: + return True + return self.condition(state) + + def __repr__(self) -> str: + """String representation for debugging.""" + return ( + f"EdgeCondition(target={self.target_node}, " + f"priority={self.priority}, weight={self.weight}, " + f"has_condition={self.has_condition})" + ) diff --git a/src/google/adk/agents/graph/graph_events.py b/src/google/adk/agents/graph/graph_events.py new file mode 100644 index 0000000000..1a632b0427 --- /dev/null +++ b/src/google/adk/agents/graph/graph_events.py @@ -0,0 +1,90 @@ +"""Typed event streams for GraphAgent execution.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + +from .graph_state import GraphState + + +class GraphEventType(str, Enum): + """Types of graph execution events.""" + + NODE_START = "node_start" # Node execution starting + NODE_END = "node_end" # Node execution completed + EDGE_TRAVERSAL = "edge_traversal" # Edge being traversed + CHECKPOINT = "checkpoint" # Checkpoint created + INTERRUPT = "interrupt" # Human-in-the-loop interrupt + ERROR = "error" # Execution error + COMPLETE = "complete" # Graph execution complete + + +class GraphEvent(BaseModel): # type: ignore[misc] + """Typed event for graph execution streaming. + + GraphEvent provides structured events for monitoring and debugging + graph execution. These events can be streamed to track execution progress, + state changes, and control flow. 
+ + Example: + ```python + async for event in graph.run_async_with_events(ctx): + if event.event_type == GraphEventType.NODE_START: + print(f"Starting node: {event.node_name}") + elif event.event_type == GraphEventType.CHECKPOINT: + print(f"Checkpoint at iteration {event.iteration}") + ``` + """ + + event_type: GraphEventType + timestamp: str = Field(description="ISO timestamp of event") + + # Node information + node_name: Optional[str] = None + iteration: Optional[int] = None + + # State information + graph_state: Optional[Dict[str, Any]] = None + state_delta: Optional[Dict[str, Any]] = None + + # Edge information + source_node: Optional[str] = None + target_node: Optional[str] = None + edge_condition_result: Optional[bool] = None + + # Interrupt information + interrupt_mode: Optional[str] = None + interrupt_message: Optional[str] = None + + # Error information + error_message: Optional[str] = None + error_type: Optional[str] = None + + # Checkpoint information + checkpoint_id: Optional[str] = None + + # Additional metadata + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class GraphStreamMode(str, Enum): + """Stream modes for graph execution. + + Different stream modes provide different levels of detail: + - VALUES: Only final node outputs + - UPDATES: State updates after each node + - MESSAGES: All agent messages and events + - DEBUG: All events including edges, checkpoints, interrupts + """ + + VALUES = "values" # Stream final values only + UPDATES = "updates" # Stream state updates + MESSAGES = "messages" # Stream all messages + DEBUG = "debug" # Stream all debug events diff --git a/src/google/adk/agents/graph/graph_export.py b/src/google/adk/agents/graph/graph_export.py new file mode 100644 index 0000000000..5eef97ca22 --- /dev/null +++ b/src/google/adk/agents/graph/graph_export.py @@ -0,0 +1,210 @@ +"""Graph export functions for visualization. + +Standalone functions that export graph structure and execution data +in D3-compatible JSON format. 
Separated from GraphAgent for +single-responsibility. +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .graph_agent import GraphAgent + + +def export_graph_structure(graph: GraphAgent) -> Dict[str, Any]: + """Export graph structure in D3-compatible JSON format. + + Args: + graph: GraphAgent instance to export + + Returns: + Dictionary with nodes, links, and metadata suitable for + D3.js or other graph visualization tools. + """ + nodes = [] + links = [] + + for node_name, node in graph.nodes.items(): + node_data = { + "id": node_name, + "type": "agent" if node.agent else "function", + "name": node.name, + } + nodes.append(node_data) + + for node_name, node in graph.nodes.items(): + for edge in node.edges: + link_data = { + "source": node_name, + "target": edge.target_node, + "conditional": edge.has_condition, + } + links.append(link_data) + + metadata = { + "start_node": graph.start_node, + "end_nodes": graph.end_nodes, + "checkpointing": graph.checkpointing, + "max_iterations": graph.max_iterations, + } + + return { + "nodes": nodes, + "links": links, + "metadata": metadata, + "directed": True, + } + + +def export_graph_with_execution( + graph: GraphAgent, + execution_history: Optional[List[Dict[str, Any]]] = None, + state_history: Optional[List[Dict[str, Any]]] = None, + interrupt_markers: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Export graph with execution history, state evolution, and interrupts. + + Enhanced D3-compatible format including runtime information for + interactive visualization and replay. 
+ + Args: + graph: GraphAgent instance to export + execution_history: List of executed nodes with timestamps + state_history: List of state snapshots after each node + interrupt_markers: List of interrupt events + + Returns: + Enhanced dictionary with execution data overlaid on graph structure. + """ + base_structure = export_graph_structure(graph) + nodes = base_structure["nodes"] + links = base_structure["links"] + + if execution_history: + node_executions: Dict[str, List[Dict[str, Any]]] = {} + for exec_record in execution_history: + node_name = exec_record["node"] + if node_name not in node_executions: + node_executions[node_name] = [] + node_executions[node_name].append(exec_record) + + for node in nodes: + node_id = node["id"] + if node_id in node_executions: + node["executions"] = node_executions[node_id] + node["execution_count"] = len(node_executions[node_id]) + statuses = [ + e.get("status", "unknown") for e in node_executions[node_id] + ] + node["status_summary"] = { + "success": statuses.count("success"), + "error": statuses.count("error"), + "interrupted": statuses.count("interrupted"), + } + else: + node["executions"] = [] + node["execution_count"] = 0 + + if execution_history: + link_traversals: Dict[Tuple[str, str], int] = {} + for i in range(len(execution_history) - 1): + source = execution_history[i]["node"] + target = execution_history[i + 1]["node"] + link_key = (source, target) + link_traversals[link_key] = link_traversals.get(link_key, 0) + 1 + + for link in links: + link_key = (link["source"], link["target"]) + link["traversals"] = link_traversals.get(link_key, 0) + + if interrupt_markers: + node_interrupts: Dict[str, List[Dict[str, Any]]] = {} + for interrupt in interrupt_markers: + node_name = interrupt.get("node") + if node_name: + if node_name not in node_interrupts: + node_interrupts[node_name] = [] + node_interrupts[node_name].append(interrupt) + + for node in nodes: + node_id = node["id"] + if node_id in node_interrupts: + 
node["interrupt_markers"] = node_interrupts[node_id] + node["interrupt_count"] = len(node_interrupts[node_id]) + + return { + "nodes": nodes, + "links": links, + "metadata": base_structure["metadata"], + "execution_history": execution_history or [], + "state_history": state_history or [], + "interrupt_markers": interrupt_markers or [], + "directed": True, + "enhanced": True, + } + + +def export_execution_timeline( + execution_history: List[Dict[str, Any]], + state_history: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Export execution timeline for temporal visualization. + + Creates a timeline view of graph execution suitable for + Gantt charts, timeline visualizations, or replay UIs. + + Args: + execution_history: List of executed nodes with timestamps + state_history: Optional state snapshots at each step + + Returns: + Dictionary with timeline, total_duration, total_steps, iterations. + """ + if not execution_history: + return { + "timeline": [], + "total_duration": 0, + "total_steps": 0, + "iterations": 0, + } + + timeline = [] + for i, exec_record in enumerate(execution_history): + step_data = { + "step": i, + "node": exec_record["node"], + "timestamp": exec_record.get("timestamp", 0), + "iteration": exec_record.get("iteration", 1), + "status": exec_record.get("status", "unknown"), + } + + if i < len(execution_history) - 1: + next_timestamp = execution_history[i + 1].get("timestamp", 0) + step_data["duration"] = next_timestamp - step_data["timestamp"] + else: + step_data["duration"] = 0 + + if state_history and i < len(state_history): + step_data["state"] = state_history[i].get("state", {}) + + timeline.append(step_data) + + total_duration = 0 + if len(timeline) > 1: + total_duration = timeline[-1]["timestamp"] - timeline[0]["timestamp"] + + max_iteration = max((step["iteration"] for step in timeline), default=0) + + return { + "timeline": timeline, + "total_duration": total_duration, + "total_steps": len(timeline), + "iterations": 
max_iteration, + } diff --git a/src/google/adk/agents/graph/graph_interrupt_handler.py b/src/google/adk/agents/graph/graph_interrupt_handler.py new file mode 100644 index 0000000000..9a1fe11136 --- /dev/null +++ b/src/google/adk/agents/graph/graph_interrupt_handler.py @@ -0,0 +1,475 @@ +"""Interrupt handling mixin for GraphAgent. + +Extracts interrupt-related methods to keep GraphAgent focused on +core graph execution. The mixin pattern allows potential reuse +by other agent types that need step-level interrupt semantics. +""" + +from __future__ import annotations + +import logging +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING + +from google.genai import types + +from ...events.event import Event +from ...events.event_actions import EventActions +from ...telemetry import graph_tracing +from .graph_state import GraphState +from .interrupt import InterruptAction +from .interrupt import InterruptMode + +if TYPE_CHECKING: + from ..invocation_context import InvocationContext + from .graph_agent_config import TelemetryConfig + from .graph_agent_state import GraphAgentState + from .graph_node import GraphNode + from .interrupt_service import InterruptMessage + +logger = logging.getLogger("google_adk." + __name__) + + +class GraphInterruptMixin: + """Mixin providing interrupt handling for graph-based agents. 
+ + Expects the host class to have: + - self.interrupt_service: Optional[InterruptService] + - self.interrupt_config: Optional[InterruptConfig] + - self.name: str + - self._get_telemetry_attributes() (from AgentTelemetryMixin) + - self._should_sample() (from AgentTelemetryMixin) + - self._get_next_node_with_telemetry() (from GraphAgent) + """ + + # Type stubs for fields provided by the host class + interrupt_service: Any + interrupt_config: Any + name: str + + async def _check_interrupt_with_telemetry( + self, + session_id: str, + mode: str, + effective_config: Optional[TelemetryConfig] = None, + ) -> Optional[Any]: + """Check interrupt with telemetry. + + Args: + session_id: Session identifier + mode: Interrupt mode (before, after, both) + effective_config: Effective telemetry config (merged parent + own) + + Returns: + Interrupt message if any, None otherwise + """ + if not self.interrupt_service: + return None + + with graph_tracing.tracer.start_as_current_span("interrupt_check") as span: + attrs = self._get_telemetry_attributes( # type: ignore[attr-defined] + { + graph_tracing.GRAPH_INTERRUPT_MODE: mode, + graph_tracing.GRAPH_SESSION_ID: session_id, + graph_tracing.GRAPH_AGENT_NAME: self.name, + }, + effective_config=effective_config, + ) + for key, value in attrs.items(): + span.set_attribute(key, value) + + interrupt_message = await self.interrupt_service.check_interrupt( + session_id + ) + + if self._should_sample(effective_config=effective_config): # type: ignore[attr-defined] + graph_tracing.record_interrupt_check( + mode=mode, agent_name=self.name, session_id=session_id + ) + + return interrupt_message + + async def _dispatch_interrupt_action( + self, + action_result: str | Tuple[str, str], + ctx: InvocationContext, + timing: str, + current_node: Optional[GraphNode] = None, + state: Optional[GraphState] = None, + ) -> str | Tuple[str, str] | None: + """Route an interrupt action result to the appropriate control flow. 
+ + Shared logic for both before-node and after-node interrupt handlers. + + Args: + action_result: Result from _process_interrupt_message + ctx: Invocation context + timing: "before" or "after" + current_node: The GraphNode (needed for "skip" in before-node) + state: Current graph state (needed for "skip" routing) + + Returns: + Control flow signal: None, "rerun", "break", ("go_back", target), + or ("skip", next_node). + """ + if isinstance(action_result, tuple): + action, target_node = action_result + if action == "go_back": + return ("go_back", target_node) + elif action_result == "rerun": + return "rerun" + elif action_result == "skip" and timing == "before": + next_node_name = self._get_next_node_with_telemetry(current_node, state) # type: ignore[attr-defined] + return ("skip", next_node_name) if next_node_name else "break" + elif action_result == "pause": + try: + resumed = await self.interrupt_service.wait_if_paused(ctx.session.id) + if not resumed: + logger.info( + "GraphAgent execution cancelled for session %s", + ctx.session.id, + ) + return "break" + except TimeoutError as e: + logger.warning("Interrupt wait timeout: %s", e) + return "break" + + return None + + async def _handle_before_node_interrupt( + self, + current_node_name: str, + current_node: GraphNode, + state: GraphState, + ctx: InvocationContext, + agent_state: GraphAgentState, + ) -> Tuple[List[Event], str | Tuple[str, str] | None]: + """Handle a BEFORE-node interrupt and return events + routing control. + + Args: + current_node_name: Name of the node about to execute. + current_node: The GraphNode about to execute (needed for "skip"). + state: Current graph state. + ctx: Invocation context. + agent_state: Execution tracking state. + + Returns: + Tuple of (events_to_yield, control) where control is: + - None: proceed to normal node execution. + - "rerun": re-run current node (continue the loop). + - "break": exit the main loop immediately. + - ("go_back", target_node): jump to target_node. 
+ - ("skip", next_node | None): skip node, route to next_node. + """ + assert self.interrupt_service is not None + interrupt_message = await self._check_interrupt_with_telemetry( + ctx.session.id, "before" + ) + if not interrupt_message: + return [], None + + action_result = await self._process_interrupt_message( + interrupt_message, state, current_node_name, ctx, agent_state + ) + + should_escalate = ( + action_result == "pause" + if isinstance(action_result, str) + else (isinstance(action_result, tuple) and action_result[0] == "pause") + ) + + event = Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + "\U0001f6d1 INTERRUPT (BEFORE):" + f" {interrupt_message.text}" + ) + ) + ] + ), + actions=EventActions( + escalate=should_escalate, + state_delta={ + "interrupt_message": interrupt_message.text, + "interrupt_timing": "before", + "interrupt_node": current_node_name, + }, + ), + ) + + control = await self._dispatch_interrupt_action( + action_result, + ctx, + "before", + current_node=current_node, + state=state, + ) + return [event], control + + async def _handle_after_node_interrupt( + self, + current_node_name: str, + state: GraphState, + ctx: InvocationContext, + agent_state: GraphAgentState, + ) -> Tuple[List[Event], str | Tuple[str, str] | None]: + """Handle an AFTER-node interrupt and return events + routing control. + + Args: + current_node_name: Name of the node that just executed. + state: Current graph state (includes the node's output). + ctx: Invocation context. + agent_state: Execution tracking state. + + Returns: + Tuple of (events_to_yield, control) where control is: + - None: accept results and proceed to next node. + - "rerun": re-run current node. + - "break": exit the main loop. + - ("go_back", target_node): jump to target_node. 
+ """ + assert self.interrupt_service is not None + interrupt_message = await self._check_interrupt_with_telemetry( + ctx.session.id, "after" + ) + if not interrupt_message: + return [], None + + action_result = await self._process_interrupt_message( + interrupt_message, state, current_node_name, ctx, agent_state + ) + + should_escalate = ( + action_result == "pause" + if isinstance(action_result, str) + else (isinstance(action_result, tuple) and action_result[0] == "pause") + ) + + state_delta_dict: Dict[str, Any] = { + "interrupt_message": interrupt_message.text, + "interrupt_timing": "after", + "interrupt_metadata": interrupt_message.metadata, + "interrupt_action": interrupt_message.action, + "interrupt_node": current_node_name, + } + + event = Event( + author=self.name, + content=types.Content( + parts=[ + types.Part( + text=( + "\U0001f6d1 INTERRUPT (AFTER):" + f" {interrupt_message.text}" + ) + ) + ] + ), + actions=EventActions( + escalate=should_escalate, state_delta=state_delta_dict + ), + ) + + control = await self._dispatch_interrupt_action( + action_result, + ctx, + "after", + ) + return [event], control + + async def _process_interrupt_message( + self, + message: InterruptMessage, + state: GraphState, + current_node_name: str, + ctx: InvocationContext, + agent_state: GraphAgentState, + ) -> str | Tuple[str, str]: + """Process interrupt message using LLM reasoner if configured. 
+ + Args: + message: InterruptMessage from human + state: Current graph state + current_node_name: Name of the current node + ctx: Invocation context + agent_state: Execution tracking state + + Returns: + Action string, or tuple (action, target_node) for go_back + """ + agent_state.interrupt_history.append({ + "text": message.text, + "action": message.action, + "metadata": message.metadata or {}, + "timestamp": time.time(), + "node": agent_state.current_node, + "iteration": agent_state.iteration, + }) + + if self.interrupt_config and self.interrupt_config.reasoner: + logger.debug("Using InterruptReasoner to decide action") + action_obj = await self.interrupt_config.reasoner.reason_about_interrupt( + message, state, current_node_name, ctx, agent_state + ) + agent_state.last_interrupt_decision = { + "action": action_obj.action, + "reasoning": action_obj.reasoning, + "parameters": action_obj.parameters, + "node": current_node_name, + "timestamp": time.time(), + } + logger.info( + "InterruptReasoner decided: %s - %s", + action_obj.action, + action_obj.reasoning, + ) + else: + action_obj = InterruptAction( + action=message.action or "continue", + reasoning="Direct action from interrupt message", + parameters=message.metadata or {}, + ) + + return await self._execute_interrupt_action( + action_obj, state, ctx, agent_state + ) + + async def _execute_interrupt_action( + self, + action: InterruptAction, + state: GraphState, + ctx: InvocationContext, + agent_state: GraphAgentState, + ) -> str | Tuple[str, str]: + """Execute interrupt action based on LLM reasoner decision. 
+ + Args: + action: InterruptAction from reasoner + state: Current graph state + ctx: Invocation context + agent_state: Execution tracking state + + Returns: + Action string, or tuple (action, target_node) for go_back + """ + if action.action == "defer": + agent_state.interrupt_todos.append({ + "message": action.parameters.get("message", ""), + "metadata": action.parameters, + "timestamp": time.time(), + "node": agent_state.current_node, + "iteration": agent_state.iteration, + }) + logger.info( + "Deferred interrupt to todos: %s", + action.parameters.get("message", ""), + ) + return "continue" + + elif action.action == "rerun": + if action.parameters.get("guidance"): + agent_state.rerun_guidance = action.parameters["guidance"] + logger.info( + "Rerunning with guidance: %s", + action.parameters["guidance"], + ) + return "rerun" + + elif action.action == "go_back": + steps = action.parameters.get("steps", 1) + current_path = list(agent_state.path) + + if len(current_path) >= steps + 1: + target_node = current_path[-(steps + 1)] + nodes_to_clear = current_path[-steps:] + agent_state.path = current_path[:-steps] + + for node_name in nodes_to_clear: + # Use tracked output keys if available, fall back to node name + tracked_keys = agent_state.output_keys.get(node_name) + if tracked_keys: + for key in tracked_keys: + state.data.pop(key, None) + else: + logger.warning( + "go_back: no tracked output_keys for node '%s', " + "falling back to clearing key '%s'", + node_name, + node_name, + ) + state.data.pop(node_name, None) + + logger.info( + "Going back %d steps to node '%s' (cleared: %s)", + steps, + target_node, + nodes_to_clear, + ) + return ("go_back", target_node) + else: + logger.warning( + "Cannot go back %d steps, only %d nodes in path. 
Continuing.", + steps, + len(current_path), + ) + return "continue" + + elif action.action == "pause": + return "pause" + + elif action.action == "skip": + logger.info("Skipping current node execution") + return "skip" + + elif action.action == "update_state": + if action.parameters: + for key in action.parameters: + if key.startswith("_") or key.startswith("graph_"): + raise ValueError( + f"Cannot update reserved key '{key}'. " + "Reserved prefixes: '_', 'graph_'" + ) + state.data.update(action.parameters) + logger.info( + "Interrupt updated state: %s", + list(action.parameters.keys()), + ) + return "continue" + + elif action.action == "change_condition": + if action.parameters: + agent_state.conditions.update(action.parameters) + logger.info("Interrupt changed conditions: %s", action.parameters) + return "continue" + + else: # "continue" or unknown + logger.info("Interrupt action: %s", action.action) + return "continue" + + def _should_interrupt_before(self, node_name: str) -> bool: + """Check if should interrupt before this node.""" + if not self.interrupt_config: + return False + mode = self.interrupt_config.mode + nodes = self.interrupt_config.nodes + return mode in (InterruptMode.BEFORE, InterruptMode.BOTH) and ( + nodes is None or node_name in nodes + ) + + def _should_interrupt_after(self, node_name: str) -> bool: + """Check if should interrupt after this node.""" + if not self.interrupt_config: + return False + mode = self.interrupt_config.mode + nodes = self.interrupt_config.nodes + return mode in (InterruptMode.AFTER, InterruptMode.BOTH) and ( + nodes is None or node_name in nodes + ) diff --git a/src/google/adk/agents/graph/graph_node.py b/src/google/adk/agents/graph/graph_node.py new file mode 100644 index 0000000000..996b5b1f5e --- /dev/null +++ b/src/google/adk/agents/graph/graph_node.py @@ -0,0 +1,231 @@ +"""Graph node wrapper for agents and functions.""" + +from __future__ import annotations + +from copy import deepcopy +import logging +import 
random +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +from ..base_agent import BaseAgent +from .graph_edge import EdgeCondition +from .graph_state import GraphState +from .graph_state import StateReducer + +if TYPE_CHECKING: + from ..llm_agent import LlmAgent + +logger = logging.getLogger(__name__) + + +class GraphNode: + """A node in the graph that can wrap ANY ADK agent or custom function. + + Supports all BaseAgent types: + - LLMAgent: Single agent with tools + - SequentialAgent: Chain of agents executed in sequence + - ParallelAgent: Multiple agents executed concurrently + - LoopAgent: Iterative agent execution + - GraphAgent: Recursive graph workflows (graphs within graphs!) + - Custom agents: Any subclass of BaseAgent + + This enables powerful composition patterns like: + - Nested workflows (GraphAgent containing other GraphAgents) + - Validation pipelines (SequentialAgent as a node) + - Parallel analysis (ParallelAgent as a node) + """ + + def __init__( + self, + name: str, + agent: Optional[BaseAgent] = None, + function: Optional[Callable[..., Any]] = None, + input_mapper: Optional[Callable[[GraphState], str]] = None, + output_mapper: Optional[Callable[[str, GraphState], GraphState]] = None, + reducer: StateReducer = StateReducer.OVERWRITE, + custom_reducer: Optional[Callable[..., Any]] = None, + ): + """Initialize graph node. 
+ + Args: + name: Node name + agent: ANY ADK BaseAgent subclass (LLMAgent, SequentialAgent, ParallelAgent, + LoopAgent, GraphAgent, or custom agents) + function: Custom async function to execute (alternative to agent) + input_mapper: Maps GraphState to agent input + output_mapper: Maps agent output back to GraphState + reducer: Strategy for merging output into state + custom_reducer: Custom reduction function + """ + if agent is None and function is None: + raise ValueError("Either agent or function must be provided") + + self.name = name + self.agent = agent + self.function = function + self.input_mapper = input_mapper or self._default_input_mapper + self.output_mapper = output_mapper or self._default_output_mapper + self.reducer = reducer + self.custom_reducer = custom_reducer + self.edges: List[EdgeCondition] = [] + self._sorted_edges_cache: Optional[List[tuple[int, EdgeCondition]]] = None + + # Auto-default output_key for LlmAgent with output_schema + from ..llm_agent import LlmAgent + + if isinstance(agent, LlmAgent): + if agent.output_schema and not agent.output_key: + logger.info( + f"Node '{name}': Auto-defaulting output_key to agent name" + f" '{agent.name}' (LlmAgent has output_schema but no explicit" + " output_key)" + ) + # Use model_copy to avoid mutating the caller's agent instance + self.agent = agent.model_copy(update={"output_key": agent.name}) + + def add_edge( + self, target_node: str, condition: Optional[Callable[..., Any]] = None + ) -> None: + """Add an edge to another node. + + Args: + target_node: Name of the target node + condition: Optional condition function for conditional routing + """ + self.edges.append(EdgeCondition(target_node, condition)) + self._sorted_edges_cache = None # Invalidate cache + + def _default_input_mapper(self, state: GraphState) -> str: + """Default input mapper: extract 'input' or 'messages' from state. 
+ + Args: + state: Current graph state + + Returns: + Input string for the node + """ + return str(state.data.get("input", state.data.get("messages", ""))) + + def _default_output_mapper( + self, output: str, state: GraphState + ) -> GraphState: + """Default output mapper: store output in state with node name as key. + + Applies the configured reducer strategy to merge the output into state. + + Args: + output: Node execution output + state: Current graph state + + Returns: + New graph state with output merged + """ + new_state = GraphState(data=deepcopy(state.data)) + + if self.reducer == StateReducer.OVERWRITE: + new_state.data[self.name] = output + elif self.reducer == StateReducer.APPEND: + if self.name not in new_state.data: + new_state.data[self.name] = [] + new_state.data[self.name].append(output) + elif self.reducer == StateReducer.SUM: + existing = new_state.data.get(self.name) + if existing is None: + # Infer zero-value from output type: "" for str, 0 for int/float, [] for list + existing = type(output)() + try: + new_state.data[self.name] = existing + output + except TypeError: + raise TypeError( + f"StateReducer.SUM: cannot add {type(existing).__name__} + " + f"{type(output).__name__} for node '{self.name}'. " + "Ensure consistent types or use a custom output_mapper." + ) + elif self.reducer == StateReducer.CUSTOM and self.custom_reducer: + new_state.data[self.name] = self.custom_reducer( + new_state.data.get(self.name), output + ) + + return new_state + + def get_next_node(self, state: GraphState) -> Optional[str]: + """Determine next node based on conditional edges with priority and weight. + + Enhanced routing logic: + 1. Sort edges by priority (highest first), preserving insertion order + 2. Evaluate conditions in priority order + 3. Within same priority, return first match (insertion order) + 4. If edges have different weights, use weighted random selection + 5. 
Priority 0 edges are fallbacks (always match if no higher priority matched)
+
+    Args:
+      state: Current graph state
+
+    Returns:
+      Name of next node, or None if no edge matches
+    """
+    if not self.edges:
+      return None
+
+    # Use cached sorted edges (sort once, invalidated by add_edge)
+    if self._sorted_edges_cache is None:
+      indexed_edges = [(i, e) for i, e in enumerate(self.edges)]
+      self._sorted_edges_cache = sorted(
+          indexed_edges, key=lambda x: (-x[1].priority, x[0])
+      )
+    sorted_edges = self._sorted_edges_cache
+
+    # Group edges by priority
+    current_priority = sorted_edges[0][1].priority
+    matching_edges: list[tuple[int, EdgeCondition]] = []
+
+    for idx, edge in sorted_edges:
+      # If we've moved to a lower priority and already have matches, stop
+      if edge.priority < current_priority and matching_edges:
+        break
+
+      # Update current priority
+      current_priority = edge.priority
+
+      # Check if edge matches
+      if edge.should_route(state):
+        matching_edges.append((idx, edge))
+
+    # No matching edges
+    if not matching_edges:
+      return None
+
+    # Single matching edge - return it
+    if len(matching_edges) == 1:
+      return matching_edges[0][1].target_node
+
+    # Multiple matching edges at same priority
+    # Check if weights are all the same (default behavior: first match)
+    weights = [e.weight for _, e in matching_edges]
+    all_same_weight = len(set(weights)) == 1
+
+    if all_same_weight:
+      # All weights equal - return first match in insertion order
+      return matching_edges[0][1].target_node
+
+    # Different weights - use weighted random selection
+    total_weight = sum(weights)
+    if total_weight == 0:
+      # All weights are 0 — guard against ZeroDivisionError in weighted
+      # random selection below. Fall back to first match in insertion order.
+      return matching_edges[0][1].target_node
+
+    # Weighted random choice: rand_value is uniform in [0, total_weight).
+    # Edge i owns the half-open interval [cumulative_before_i, cumulative_i),
+    # so the strict "<" comparison is required: with "<=", a zero-weight edge
+    # could be selected when rand_value lands exactly on a cumulative boundary
+    # (e.g. 0.0 for a leading zero-weight edge), even though its selection
+    # probability should be exactly zero.
+    rand_value = random.random() * total_weight
+    cumulative = 0.0
+    for _, edge in matching_edges:
+      cumulative += edge.weight
+      if rand_value < cumulative:
+        return edge.target_node
+
+    # Fallback (shouldn't reach here, but safety against float rounding
+    # pushing rand_value to the very top of the range)
+    return matching_edges[-1][1].target_node
diff --git a/src/google/adk/agents/graph/graph_rewind.py b/src/google/adk/agents/graph/graph_rewind.py
new file mode 100644
index 0000000000..ba74baa786
--- /dev/null
+++ b/src/google/adk/agents/graph/graph_rewind.py
@@ -0,0 +1,92 @@
+"""Graph rewind functionality.
+
+Standalone function for rewinding graph execution to a specific node.
+Integrates with ADK's Runner.rewind_async for temporal navigation.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+  from .graph_agent import GraphAgent
+
+
+async def rewind_to_node(
+    graph: GraphAgent,
+    session_service: Any,
+    app_name: str,
+    user_id: str,
+    session_id: str,
+    node_name: str,
+    invocation_index: int = -1,
+) -> None:
+  """Rewind graph execution to before a specific node execution.
+ + Enables temporal navigation within graph workflows: + - Retry failed nodes with different inputs + - Explore alternative execution paths + - Debug workflow issues + - Select specific iteration in loops + + Args: + graph: GraphAgent instance + session_service: Session service instance + app_name: Application name + user_id: User ID + session_id: Session ID + node_name: Node to rewind to + invocation_index: Which invocation (-1 for most recent) + + Raises: + ValueError: If node has not been executed yet + ValueError: If invocation_index is out of range + """ + session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise ValueError(f"Session not found: {session_id}") + + # Extract node_invocations from latest agent_state event + all_node_invocations: dict[str, list[str]] = {} + for event in reversed(session.events): + if ( + event.actions + and event.actions.agent_state + and "node_invocations" in (event.actions.agent_state or {}) + ): + all_node_invocations = event.actions.agent_state["node_invocations"] + break + + node_invocations = all_node_invocations.get(node_name, []) + if not node_invocations: + raise ValueError( + f"Node '{node_name}' has not been executed yet." + f" Available nodes: {list(all_node_invocations.keys())}" + ) + + if invocation_index < -len(node_invocations) or ( + invocation_index >= len(node_invocations) + ): + raise ValueError( + f"Invocation index {invocation_index} out of range. " + f"Node '{node_name}' has" + f" {len(node_invocations)} invocations." 
+ ) + + invocation_id = node_invocations[invocation_index] + + from ...runners import Runner + + runner = Runner( + app_name=app_name, + agent=graph, + session_service=session_service, + ) + await runner.rewind_async( + user_id=user_id, + session_id=session_id, + rewind_before_invocation_id=invocation_id, + ) diff --git a/src/google/adk/agents/graph/graph_state.py b/src/google/adk/agents/graph/graph_state.py new file mode 100644 index 0000000000..62bb873644 --- /dev/null +++ b/src/google/adk/agents/graph/graph_state.py @@ -0,0 +1,118 @@ +"""Graph state management with typed state and reducers.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type +from typing import TypeVar + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +from .state_utils import parse_state_value +from .state_utils import PydanticJSONEncoder +from .state_utils import state_value_as_dict +from .state_utils import state_value_as_str + +# Re-export for backward compat +__all__ = ["GraphState", "StateReducer", "PydanticJSONEncoder"] + +T = TypeVar("T", bound=BaseModel) + + +class StateReducer(str, Enum): + """State reduction strategies for merging node outputs. + + Defines how node outputs are merged into the graph state: + - OVERWRITE: Replace existing value with new value + - APPEND: Append new value to list (creates list if needed) + - SUM: Accumulate values using + operator (works for strings, numbers, lists) + - CUSTOM: Use custom reducer function + """ + + OVERWRITE = "overwrite" + APPEND = "append" + SUM = "sum" + CUSTOM = "custom" + + +class GraphState(BaseModel): # type: ignore[misc] + """Domain data container for graph execution. + + GraphState holds node outputs and intermediate results as the graph + executes. Execution tracking (iteration, path, etc.) is handled + separately by GraphAgentState. 
+ + Example: + ```python + state = GraphState( + data={"input": "user query", "result": "agent response"}, + ) + ``` + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + data: Dict[str, Any] = Field( + default_factory=dict, description="Node outputs and intermediate results" + ) + + def data_to_json(self, indent: int = 2) -> str: + """Convert state.data to JSON string (automatically handles Pydantic models). + + Args: + indent: JSON indentation level (default: 2) + + Returns: + JSON string of state.data + """ + import json + + return json.dumps(self.data, cls=PydanticJSONEncoder, indent=indent) + + def get_parsed( + self, key: str, schema: Type[T], default: Optional[T] = None + ) -> Optional[T]: + """Get state value with automatic JSON-string parsing. + + Handles both dict and JSON-string representations transparently. + + Args: + key: State data key (usually agent output_key) + schema: Pydantic model to parse into + default: Value to return if key missing or parse fails + + Returns: + Parsed Pydantic model instance or default + """ + return parse_state_value(self.data.get(key), schema, default) + + def get_str(self, key: str, default: str = "") -> str: + """Get state value as string (for non-schema agent outputs). + + Args: + key: State data key + default: Value to return if key missing + + Returns: + String value or default + """ + return state_value_as_str(self.data.get(key), default) + + def get_dict( + self, key: str, default: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Get state value as dict with JSON-string fallback. 
+ + Args: + key: State data key + default: Value to return if key missing or parse fails + + Returns: + Dict value or default + """ + return state_value_as_dict(self.data.get(key), default) diff --git a/src/google/adk/agents/graph/graph_telemetry.py b/src/google/adk/agents/graph/graph_telemetry.py new file mode 100644 index 0000000000..29810b96b0 --- /dev/null +++ b/src/google/adk/agents/graph/graph_telemetry.py @@ -0,0 +1,183 @@ +"""Telemetry mixins for agent observability. + +Two-layer design for reusability: +- AgentTelemetryMixin: Generic telemetry (any agent with telemetry_config) +- GraphTelemetryMixin: Graph-specific trace toggles (nodes, edges, etc.) +""" + +from __future__ import annotations + +import random +from typing import Any +from typing import Dict +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..invocation_context import InvocationContext + from .graph_agent_config import TelemetryConfig + + +class AgentTelemetryMixin: + """Generic telemetry mixin for any agent with a telemetry_config field. + + Expects the host class to have: + - self.telemetry_config: Optional[TelemetryConfig] + """ + + telemetry_config: Any # Declared for type checking + name: str # Provided by host agent class + + def _is_telemetry_enabled(self) -> bool: + """Check if telemetry is enabled globally.""" + if not self.telemetry_config: + return True # Default to enabled if no config + return bool(self.telemetry_config.enabled) + + def _should_sample( + self, effective_config: Optional[TelemetryConfig] = None + ) -> bool: + """Check if current operation should be sampled. + + Uses random sampling to control telemetry volume. + + Args: + effective_config: Effective telemetry config (merged parent + own). + If None, uses self.telemetry_config. 
+ """ + config = effective_config or self.telemetry_config + if not config: + return True # Default to 100% sampling if no config + return bool(random.random() < config.sampling_rate) + + def _get_telemetry_attributes( + self, + base_attributes: Dict[str, Any], + effective_config: Optional[TelemetryConfig] = None, + ) -> Dict[str, Any]: + """Get telemetry attributes including custom attributes. + + Args: + base_attributes: Base attributes for the telemetry event + effective_config: Effective config. If None, uses self.telemetry_config. + + Returns: + Combined attributes with additional custom attributes + """ + config = effective_config or self.telemetry_config + if not config or not config.additional_attributes: + return base_attributes + + combined = dict(config.additional_attributes) + combined.update(base_attributes) + return combined + + def _get_parent_telemetry_config( + self, ctx: InvocationContext + ) -> Optional[Dict[str, Any]]: + """Get parent telemetry config from agent_states. + + Used for nested agents to inherit telemetry settings from parent. + + Args: + ctx: Invocation context with agent_states + """ + if not ctx.agent_states: + return None + for agent_name, state_dict in ctx.agent_states.items(): + if agent_name != self.name and isinstance(state_dict, dict): + config = state_dict.get("telemetry_config_dict") + if config is not None: + return dict(config) if isinstance(config, dict) else None + return None + + def _get_effective_telemetry_config( + self, ctx: InvocationContext + ) -> Optional[TelemetryConfig]: + """Get effective telemetry config by merging parent and own config. + + Own config takes precedence over parent config. 
+ + Args: + ctx: Invocation context with session state + """ + parent_config_dict = self._get_parent_telemetry_config(ctx) + + if not parent_config_dict: + return self.telemetry_config # type: ignore[no-any-return] + + if not self.telemetry_config: + from .graph_agent_config import TelemetryConfig + + return TelemetryConfig(**parent_config_dict) + + # Merge: own config takes precedence + merged_dict = parent_config_dict.copy() + own_dict = self.telemetry_config.model_dump() + for key, value in own_dict.items(): + if key == "additional_attributes" and value is not None: + parent_attrs = merged_dict.get("additional_attributes") or {} + own_attrs = value or {} + merged_dict["additional_attributes"] = {**parent_attrs, **own_attrs} + elif value is not None: + merged_dict[key] = value + + from .graph_agent_config import TelemetryConfig + + return TelemetryConfig(**merged_dict) + + +class GraphTelemetryMixin(AgentTelemetryMixin): + """Graph-specific telemetry toggles. + + Extends AgentTelemetryMixin with granular trace controls for + graph execution components (nodes, edges, iterations, etc.). 
+ """ + + def _should_trace_nodes(self) -> bool: + """Check if node execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_nodes) + + def _should_trace_edges(self) -> bool: + """Check if edge evaluation tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_edges) + + def _should_trace_iterations(self) -> bool: + """Check if graph iteration metrics are enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_iterations) + + def _should_trace_parallel_groups(self) -> bool: + """Check if parallel group execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_parallel_groups) + + def _should_trace_callbacks(self) -> bool: + """Check if callback execution tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_callbacks) + + def _should_trace_interrupts(self) -> bool: + """Check if interrupt check tracing is enabled.""" + if not self._is_telemetry_enabled(): + return False + if not self.telemetry_config: + return True + return bool(self.telemetry_config.trace_interrupts) diff --git a/src/google/adk/agents/graph/interrupt.py b/src/google/adk/agents/graph/interrupt.py new file mode 100644 index 0000000000..8bc32e0f7b --- /dev/null +++ b/src/google/adk/agents/graph/interrupt.py @@ -0,0 +1,72 @@ +"""Human-in-the-loop interrupt modes and configuration for graph execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any +from typing import Dict +from 
typing import List
+from typing import Optional
+
+
+class InterruptMode(str, Enum):
+  """When to interrupt execution for human-in-the-loop.
+
+  Interrupts allow pausing graph execution at specific nodes
+  to enable human review, approval, or intervention before continuing.
+
+  Example:
+    ```python
+    from google.adk.agents.graph import GraphAgent, InterruptMode, InterruptConfig
+
+    graph = GraphAgent(
+        name="workflow",
+        interrupt_config=InterruptConfig(
+            mode=InterruptMode.AFTER,  # Check after node execution
+            nodes=["critical_node"],  # Only these nodes
+        )
+    )
+    ```
+  """
+
+  BEFORE = "before"  # Interrupt before node execution
+  AFTER = "after"  # Interrupt after node execution
+  BOTH = "both"  # Interrupt both before and after node execution
+
+
+@dataclass
+class InterruptConfig:
+  """Configuration for interrupt behavior in GraphAgent.
+
+  Attributes:
+    mode: When to check for interrupts (BEFORE, AFTER, or BOTH)
+    nodes: Specific node names to interrupt, None = all nodes
+    reasoner: Optional LLM-based interrupt reasoner for intelligent decisions
+  """
+
+  mode: InterruptMode = InterruptMode.AFTER
+  nodes: Optional[List[str]] = None
+  reasoner: Optional[Any] = None  # InterruptReasoner (avoiding circular import)
+
+
+@dataclass
+class InterruptAction:
+  """Action to take after processing an interrupt message.
+
+  Returned by interrupt reasoner to indicate what action to take.
+
+  Attributes:
+    action: Action type ("continue", "rerun", "go_back", "pause", "defer",
+      "skip", "update_state", "change_condition")
+    reasoning: LLM's reasoning for this decision (optional)
+    parameters: Additional parameters for the action (e.g., target_node, guidance)
+  """
+
+  action: str
+  reasoning: str = ""
+  # Optional so the dataclass default is immutable; __post_init__ normalizes
+  # None to {} so the attribute is always a dict after construction.
+  parameters: Optional[Dict[str, Any]] = None
+
+  def __post_init__(self) -> None:
+    """Initialize parameters dict if None."""
+    if self.parameters is None:
+      self.parameters = {}
diff --git a/src/google/adk/agents/graph/interrupt_reasoner.py b/src/google/adk/agents/graph/interrupt_reasoner.py
new file mode 100644
index 0000000000..ff24ae0966
--- /dev/null
+++ b/src/google/adk/agents/graph/interrupt_reasoner.py
@@ -0,0 +1,268 @@
+"""LLM-based interrupt reasoning for GraphAgent.
+
+This module provides an LLM agent that intelligently reasons about
+interrupt messages and decides what action to take based on context.
+
+The InterruptReasoner is a2a compatible and can be used as a standard
+LlmAgent in the ADK framework.
+
+Example:
+  ```python
+  from google.adk.agents.graph import (
+      GraphAgent,
+      InterruptConfig,
+      InterruptMode,
+  )
+  from google.adk.agents.graph.interrupt_reasoner import (
+      InterruptReasoner,
+      InterruptReasonerConfig,
+  )
+
+  # Create reasoner with custom config
+  reasoner = InterruptReasoner(InterruptReasonerConfig(
+      model="gemini-2.0-flash-exp",
+      available_actions=["continue", "rerun", "go_back", "pause", "defer"],
+  ))
+
+  # Use in GraphAgent
+  graph = GraphAgent(
+      name="my_graph",
+      interrupt_config=InterruptConfig(
+          mode=InterruptMode.AFTER,
+          reasoner=reasoner,
+      ),
+  )
+  ```
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+import json
+import logging
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+  from ..llm_agent import LlmAgent as LlmAgentType
+else:
+  LlmAgentType = Any
+
+from pydantic import BaseModel
+
+from google import genai
+
+from ..llm_agent import LlmAgent
+from .graph_state import GraphState
+from .interrupt import InterruptAction
+from .interrupt_service import InterruptMessage
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+class InterruptDecision(BaseModel):  # type: ignore[misc]
+  """Structured output schema for InterruptReasoner LLM responses.
+
+  Used as output_schema on the LlmAgent to get API-level JSON enforcement.
+  The LLM returns valid JSON matching this schema without markdown wrapping.
+  """
+
+  action: str
+  reasoning: str = ""
+  parameters: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class InterruptReasonerConfig:
+  """Configuration for InterruptReasoner.
+
+  Attributes:
+    model: LLM model to use for reasoning (default: gemini-2.0-flash-exp)
+    instruction: System instruction for the reasoner
+    available_actions: List of available actions the reasoner can choose
+    custom_actions: Dict of custom action handlers (extensible)
+    include_state_in_prompt: Whether to include full state in prompt (default: True)
+    max_state_size: Maximum state size to include in prompt (default: 10000)
+  """
+
+  model: str = "gemini-2.0-flash-exp"
+  instruction: str = (
+      "You are an interrupt reasoning agent for a graph-based workflow system."
+      " Analyze interrupt messages from humans and decide what action to take"
+      " based on the current execution context, node output, and state."
+  )
+  # Optional so the dataclass defaults are immutable; __post_init__ fills in
+  # the real defaults, so both attributes are non-None after construction.
+  available_actions: Optional[List[str]] = None
+  custom_actions: Optional[Dict[str, Callable[..., Any]]] = None
+  include_state_in_prompt: bool = True
+  max_state_size: int = 10000
+  fallback_action: str = "continue"
+
+  def __post_init__(self) -> None:
+    """Initialize default values."""
+    if self.available_actions is None:
+      self.available_actions = [
+          "continue",
+          "rerun",
+          "go_back",
+          "pause",
+          "defer",
+          "skip",
+      ]
+    if self.custom_actions is None:
+      self.custom_actions = {}
+
+
+class InterruptReasoner(LlmAgent):  # type: ignore[misc]
+  """LLM agent that reasons about interrupt messages and decides actions.
+
+  This agent receives interrupt messages, analyzes the execution context,
+  and uses an LLM to intelligently decide what action to take.
+
+  The reasoner is a2a compatible and can be used as a standard ADK agent.
+
+  Attributes:
+    config: InterruptReasonerConfig for this reasoner
+    available_actions: List of actions the reasoner can choose from
+    custom_actions: Dictionary of custom action handlers
+  """
+
+  def __init__(
+      self,
+      config: InterruptReasonerConfig,
+      name: str = "interrupt_reasoner",
+      **kwargs: Any,
+  ):
+    """Initialize InterruptReasoner.
+

    Args:
      config: Configuration for the reasoner
      name: Agent name (default: "interrupt_reasoner")
      **kwargs: Additional arguments passed to LlmAgent
    """
    # output_schema enforces the InterruptDecision JSON shape at the API
    # level; output_key mirrors the agent name.
    super().__init__(
        name=name,
        model=config.model,
        instruction=config.instruction,
        output_schema=InterruptDecision,
        output_key=name,
        **kwargs,
    )
    # Store in private attributes (Pydantic allows these without declaring
    # them as model fields).
    self._config = config
    self._available_actions = config.available_actions
    self._custom_actions = config.custom_actions

  async def reason_about_interrupt(
      self,
      message: InterruptMessage,
      state: GraphState,
      current_node: str,
      ctx: Any,  # InvocationContext
      agent_state: Any = None,  # GraphAgentState
  ) -> InterruptAction:
    """Use LLM to reason about interrupt message and decide action.

    Args:
      message: Interrupt message from human
      state: Current graph state
      current_node: Node that just executed (or is about to execute)
      ctx: Invocation context
      agent_state: Execution tracking state (GraphAgentState)

    Returns:
      InterruptAction with decision (action, reasoning, parameters). On any
      reasoning/parsing error, an InterruptAction carrying the configured
      fallback action is returned instead of raising.
    """
    # Build reasoning prompt from the interrupt message plus execution context.
    prompt = self._build_reasoning_prompt(
        message, state, current_node, agent_state
    )

    logger.debug(
        f"InterruptReasoner: reasoning about interrupt at node '{current_node}'"
    )

    # Call LLM via self.run_async() on a copy of the invocation context whose
    # user_content is replaced with the reasoning prompt.
    try:
      content = genai.types.Content(
          role="user", parts=[genai.types.Part(text=prompt)]
      )
      node_ctx = ctx.model_copy(update={"user_content": content})

      # Keep only the text of the last event that carries content — i.e. the
      # final response (note: only the first part of that event is read).
      response_text = ""
      async for event in self.run_async(node_ctx):
        if event.content and event.content.parts:
          response_text = event.content.parts[0].text or ""

      # output_schema=InterruptDecision requests API-level JSON enforcement;
      # a malformed response falls through to the except handler below.
      decision = InterruptDecision.model_validate_json(response_text.strip())
      # Reject actions outside the configured set; substitute the fallback.
      validated_action = (
          decision.action
          if decision.action in self._available_actions
          else self._config.fallback_action
      )
      return InterruptAction(
          action=validated_action,
          reasoning=decision.reasoning,
          parameters=decision.parameters or {},
      )

    except Exception as e:
      # Deliberately broad: reasoning must never crash the workflow, so any
      # LLM/parse failure degrades to the configured fallback action.
      logger.error("InterruptReasoner: Error during reasoning: %s", e)
      return InterruptAction(
          action=self._config.fallback_action,
          reasoning=f"Reasoning error: {e}",
          parameters={},
      )

  def _build_reasoning_prompt(
      self,
      message: InterruptMessage,
      state: GraphState,
      current_node: str,
      agent_state: Any = None,
  ) -> str:
    """Build reasoning prompt for LLM.

    Args:
      message: Interrupt message
      state: Current graph state
      current_node: Current node name
      agent_state: Execution tracking state (GraphAgentState); assumed to
        expose .path and .iteration when provided — TODO confirm

    Returns:
      Formatted prompt string
    """
    # Use type-safe serialization (handles Pydantic models) and cap the
    # serialized state at max_state_size to bound prompt length.
    state_str = state.data_to_json()
    if len(state_str) > self._config.max_state_size:
      state_str = state_str[: self._config.max_state_size] + "\n... (truncated)"

    # Execution-history context; both degrade gracefully when no
    # agent_state was supplied.
    path = list(agent_state.path) if agent_state else []
    iteration = agent_state.iteration if agent_state else "unknown"

    # Build prompt; state inclusion is gated by include_state_in_prompt.
    prompt = f"""
Current Situation:
- Node: {current_node}
- State: {state_str if self._config.include_state_in_prompt else "