Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions strands-py/src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
"""

from .base import MultiAgentBase, MultiAgentResult, Status
from .graph import EdgeCondition, EdgeConditionWithContext, GraphBuilder, GraphResult
from .graph import GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"EdgeCondition",
"EdgeConditionWithContext",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
136 changes: 11 additions & 125 deletions strands-py/src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@

import asyncio
import copy
import inspect
import json
import logging
import time
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, Protocol, TypeGuard, cast
from typing import Any, cast

from opentelemetry import trace as trace_api

Expand Down Expand Up @@ -64,42 +62,6 @@
_DEFAULT_GRAPH_ID = "default_graph"


class EdgeConditionWithContext(Protocol):
"""Protocol for edge conditions that receive invocation_state.

This allows conditions to make routing decisions based on runtime context
passed during graph invocation, such as feature flags, user roles, or
environment-specific configuration.

Designed with **kwargs for future extensibility without breaking changes.

Not @runtime_checkable because the expected use case is a function or lambda,
and isinstance() checks cannot structurally distinguish callable signatures.
Dispatch uses _is_context_condition() with inspect.signature() instead.
"""

def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool:
"""Evaluate whether the edge should be traversed."""
...


LegacyEdgeCondition = Callable[["GraphState"], bool]
EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext


def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
"""Check if a condition function accepts invocation_state parameter.

Uses inspect.signature() for reliable detection, returning a TypeGuard
so mypy can narrow the type at call sites.
"""
try:
sig = inspect.signature(condition)
return "invocation_state" in sig.parameters
except (ValueError, TypeError):
return False


@dataclass
class GraphState:
"""Graph execution state.
Expand Down Expand Up @@ -185,35 +147,17 @@ class GraphEdge:

from_node: "GraphNode"
to_node: "GraphNode"
condition: EdgeCondition | None = None
_is_context_condition_cached: bool | None = field(default=None, init=False, repr=False, compare=False)
condition: Callable[[GraphState], bool] | None = None

def __hash__(self) -> int:
"""Return hash for GraphEdge based on from_node and to_node."""
return hash((self.from_node.node_id, self.to_node.node_id))

def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool:
"""Check if this edge should be traversed based on condition.

Args:
state: The current graph execution state.
invocation_state: Runtime context passed during graph invocation.
New-style conditions (EdgeConditionWithContext) receive this parameter.
Legacy conditions (Callable[[GraphState], bool]) are called with state only.
"""
condition = self.condition
if condition is None:
def should_traverse(self, state: GraphState) -> bool:
"""Check if this edge should be traversed based on condition."""
if self.condition is None:
return True
if self._check_is_context_condition(condition):
return condition(state, invocation_state=invocation_state or {})
legacy_condition = cast(LegacyEdgeCondition, condition)
return legacy_condition(state)

def _check_is_context_condition(self, condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
"""Check and cache whether this edge's condition accepts invocation_state."""
if self._is_context_condition_cached is None:
self._is_context_condition_cached = _is_context_condition(condition)
return self._is_context_condition_cached
return self.condition(state)


@dataclass
Expand Down Expand Up @@ -332,14 +276,9 @@ def add_edge(
self,
from_node: str | GraphNode,
to_node: str | GraphNode,
condition: EdgeCondition | None = None,
condition: Callable[[GraphState], bool] | None = None,
) -> GraphEdge:
"""Add an edge between two nodes with optional condition function.

The condition can be either:
- A legacy callable: Callable[[GraphState], bool] - receives only graph state
- A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state
"""
"""Add an edge between two nodes with optional condition function that receives full GraphState."""

def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode:
if isinstance(node, str):
Expand Down Expand Up @@ -552,7 +491,6 @@ def __init__(

self._resume_next_nodes: list[GraphNode] = []
self._resume_from_session = False
self._current_invocation_state: dict[str, Any] = {}
self.id = id

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
Expand Down Expand Up @@ -631,10 +569,6 @@ async def stream_async(
if invocation_state is None:
invocation_state = {}

if self.session_manager is not None:
self._validate_invocation_state(invocation_state)
self._current_invocation_state = invocation_state

await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))

logger.debug("task=<%s> | starting graph execution", task)
Expand Down Expand Up @@ -955,7 +889,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
# Check if at least one incoming edge condition is satisfied
for edge in incoming_edges:
if edge.from_node in completed_batch:
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
if edge.should_traverse(self.state):
logger.debug(
"from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id
)
Expand Down Expand Up @@ -1191,7 +1125,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
and edge.from_node in self.state.completed_nodes
and edge.from_node.node_id in self.state.results
):
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
if edge.should_traverse(self.state):
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]

if not dependency_results:
Expand Down Expand Up @@ -1252,20 +1186,6 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult:
interrupts=interrupts,
)

@staticmethod
def _validate_invocation_state(invocation_state: dict[str, Any]) -> None:
"""Validate that invocation_state is JSON-serializable.

Raises:
TypeError: If invocation_state contains non-JSON-serializable values.
"""
try:
json.dumps(invocation_state)
except (TypeError, ValueError) as e:
raise TypeError(
f"invocation_state must be JSON-serializable for session persistence: {e}"
) from e

def serialize_state(self) -> dict[str, Any]:
"""Serialize the current graph state to a dictionary."""
compute_nodes = self._compute_ready_nodes_for_resume()
Expand All @@ -1281,7 +1201,6 @@ def serialize_state(self) -> dict[str, Any]:
"next_nodes_to_execute": next_nodes,
"current_task": encode_bytes_values(self.state.task),
"execution_order": [n.node_id for n in self.state.execution_order],
"invocation_state": self._current_invocation_state,
"_internal_state": {
"interrupt_state": self._interrupt_state.to_dict(),
},
Expand All @@ -1304,10 +1223,6 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
internal_state = payload["_internal_state"]
self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"])

invocation_state = payload.get("invocation_state", {})
self._validate_invocation_state(invocation_state)
self._current_invocation_state = invocation_state

if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():
Expand All @@ -1331,40 +1246,11 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
incoming = [e for e in self.edges if e.to_node is node]
if not incoming:
ready_nodes.append(node)
elif self._is_node_ready_for_resume(node, incoming, completed_nodes):
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
ready_nodes.append(node)

return ready_nodes

def _is_node_ready_for_resume(
self,
node: GraphNode,
incoming: list[GraphEdge],
completed_nodes: set[GraphNode],
) -> bool:
"""Check if a node is ready for resume, accounting for conditional edges.

A node is ready if all TRAVERSABLE incoming edges have their source completed.
Edges whose condition evaluates to False are excluded from the check — they
represent paths that were intentionally skipped.

Re-evaluates conditions (rather than caching traversal results) intentionally:
invocation_state may change between invocations, so conditions must reflect
current runtime context. This means condition logic changes between serialize
and resume will also take effect — consistent with the graph being defined in code.
"""
traversable_edges = [
e
for e in incoming
# Short-circuit: skip signature inspection + cache lookup for unconditional edges.
if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]

if not traversable_edges:
return False

return all(e.from_node in completed_nodes for e in traversable_edges)

def _from_dict(self, payload: dict[str, Any]) -> None:
self.state.status = Status(payload["status"])
# Hydrate completed nodes & results
Expand Down
Loading
Loading