diff --git a/strands-py/src/strands/multiagent/__init__.py b/strands-py/src/strands/multiagent/__init__.py index 8dd78c38cb..ad99944a8d 100644 --- a/strands-py/src/strands/multiagent/__init__.py +++ b/strands-py/src/strands/multiagent/__init__.py @@ -9,12 +9,10 @@ """ from .base import MultiAgentBase, MultiAgentResult, Status -from .graph import EdgeCondition, EdgeConditionWithContext, GraphBuilder, GraphResult +from .graph import GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult __all__ = [ - "EdgeCondition", - "EdgeConditionWithContext", "GraphBuilder", "GraphResult", "MultiAgentBase", diff --git a/strands-py/src/strands/multiagent/graph.py b/strands-py/src/strands/multiagent/graph.py index 8489d8bc43..146a31563e 100644 --- a/strands-py/src/strands/multiagent/graph.py +++ b/strands-py/src/strands/multiagent/graph.py @@ -16,13 +16,11 @@ import asyncio import copy -import inspect -import json import logging import time from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Protocol, TypeGuard, cast +from typing import Any, cast from opentelemetry import trace as trace_api @@ -64,42 +62,6 @@ _DEFAULT_GRAPH_ID = "default_graph" -class EdgeConditionWithContext(Protocol): - """Protocol for edge conditions that receive invocation_state. - - This allows conditions to make routing decisions based on runtime context - passed during graph invocation, such as feature flags, user roles, or - environment-specific configuration. - - Designed with **kwargs for future extensibility without breaking changes. - - Not @runtime_checkable because the expected use case is a function or lambda, - and isinstance() checks cannot structurally distinguish callable signatures. - Dispatch uses _is_context_condition() with inspect.signature() instead. - """ - - def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool: - """Evaluate whether the edge should be traversed.""" - ... - - -LegacyEdgeCondition = Callable[["GraphState"], bool] -EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext - - -def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: - """Check if a condition function accepts invocation_state parameter. - - Uses inspect.signature() for reliable detection, returning a TypeGuard - so mypy can narrow the type at call sites. - """ - try: - sig = inspect.signature(condition) - return "invocation_state" in sig.parameters - except (ValueError, TypeError): - return False - - @dataclass class GraphState: """Graph execution state. @@ -185,35 +147,17 @@ class GraphEdge: from_node: "GraphNode" to_node: "GraphNode" - condition: EdgeCondition | None = None - _is_context_condition_cached: bool | None = field(default=None, init=False, repr=False, compare=False) + condition: Callable[[GraphState], bool] | None = None def __hash__(self) -> int: """Return hash for GraphEdge based on from_node and to_node.""" return hash((self.from_node.node_id, self.to_node.node_id)) - def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool: - """Check if this edge should be traversed based on condition. - - Args: - state: The current graph execution state. - invocation_state: Runtime context passed during graph invocation. - New-style conditions (EdgeConditionWithContext) receive this parameter. - Legacy conditions (Callable[[GraphState], bool]) are called with state only. - """ - condition = self.condition - if condition is None: + def should_traverse(self, state: GraphState) -> bool: + """Check if this edge should be traversed based on condition.""" + if self.condition is None: return True - if self._check_is_context_condition(condition): - return condition(state, invocation_state=invocation_state or {}) - legacy_condition = cast(LegacyEdgeCondition, condition) - return legacy_condition(state) - - def _check_is_context_condition(self, condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: - """Check and cache whether this edge's condition accepts invocation_state.""" - if self._is_context_condition_cached is None: - self._is_context_condition_cached = _is_context_condition(condition) - return self._is_context_condition_cached + return self.condition(state) @dataclass @@ -332,14 +276,9 @@ def add_edge( self, from_node: str | GraphNode, to_node: str | GraphNode, - condition: EdgeCondition | None = None, + condition: Callable[[GraphState], bool] | None = None, ) -> GraphEdge: - """Add an edge between two nodes with optional condition function. - - The condition can be either: - - A legacy callable: Callable[[GraphState], bool] - receives only graph state - - A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state - """ + """Add an edge between two nodes with optional condition function that receives full GraphState.""" def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode: if isinstance(node, str): @@ -552,7 +491,6 @@ def __init__( self._resume_next_nodes: list[GraphNode] = [] self._resume_from_session = False - self._current_invocation_state: dict[str, Any] = {} self.id = id run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) @@ -631,10 +569,6 @@ async def stream_async( if invocation_state is None: invocation_state = {} - if self.session_manager is not None: - self._validate_invocation_state(invocation_state) - self._current_invocation_state = invocation_state - await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("task=<%s> | starting graph execution", task) @@ -955,7 +889,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ # Check if at least one incoming edge condition is satisfied for edge in incoming_edges: if edge.from_node in completed_batch: - if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): + if edge.should_traverse(self.state): logger.debug( "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id ) @@ -1191,7 +1125,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: and edge.from_node in self.state.completed_nodes and edge.from_node.node_id in self.state.results ): - if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): + if edge.should_traverse(self.state): dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] if not dependency_results: @@ -1252,20 +1186,6 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: interrupts=interrupts, ) - @staticmethod - def _validate_invocation_state(invocation_state: dict[str, Any]) -> None: - """Validate that invocation_state is JSON-serializable. - - Raises: - TypeError: If invocation_state contains non-JSON-serializable values. - """ - try: - json.dumps(invocation_state) - except (TypeError, ValueError) as e: - raise TypeError( - f"invocation_state must be JSON-serializable for session persistence: {e}" - ) from e - def serialize_state(self) -> dict[str, Any]: """Serialize the current graph state to a dictionary.""" compute_nodes = self._compute_ready_nodes_for_resume() @@ -1281,7 +1201,6 @@ def serialize_state(self) -> dict[str, Any]: "next_nodes_to_execute": next_nodes, "current_task": encode_bytes_values(self.state.task), "execution_order": [n.node_id for n in self.state.execution_order], - "invocation_state": self._current_invocation_state, "_internal_state": { "interrupt_state": self._interrupt_state.to_dict(), }, @@ -1304,10 +1223,6 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: internal_state = payload["_internal_state"] self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) - invocation_state = payload.get("invocation_state", {}) - self._validate_invocation_state(invocation_state) - self._current_invocation_state = invocation_state - if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1331,40 +1246,11 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: incoming = [e for e in self.edges if e.to_node is node] if not incoming: ready_nodes.append(node) - elif self._is_node_ready_for_resume(node, incoming, completed_nodes): + elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): ready_nodes.append(node) return ready_nodes - def _is_node_ready_for_resume( - self, - node: GraphNode, - incoming: list[GraphEdge], - completed_nodes: set[GraphNode], - ) -> bool: - """Check if a node is ready for resume, accounting for conditional edges. - - A node is ready if all TRAVERSABLE incoming edges have their source completed. - Edges whose condition evaluates to False are excluded from the check — they - represent paths that were intentionally skipped. - - Re-evaluates conditions (rather than caching traversal results) intentionally: - invocation_state may change between invocations, so conditions must reflect - current runtime context. This means condition logic changes between serialize - and resume will also take effect — consistent with the graph being defined in code. - """ - traversable_edges = [ - e - for e in incoming - # Short-circuit: skip signature inspection + cache lookup for unconditional edges. - if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state) - ] - - if not traversable_edges: - return False - - return all(e.from_node in completed_nodes for e in traversable_edges) - def _from_dict(self, payload: dict[str, Any]) -> None: self.state.status = Status(payload["status"]) # Hydrate completed nodes & results diff --git a/strands-py/tests/strands/multiagent/test_graph.py b/strands-py/tests/strands/multiagent/test_graph.py index 7ff31dd19b..a6085627c3 100644 --- a/strands-py/tests/strands/multiagent/test_graph.py +++ b/strands-py/tests/strands/multiagent/test_graph.py @@ -16,26 +16,6 @@ from strands.types._events import MultiAgentNodeCancelEvent -def _make_graph( - nodes: dict, - edges=None, - state=None, - invocation_state=None, - interrupt_state=None, -) -> Graph: - """Create a minimally-valid Graph instance for unit tests without invoking __init__.""" - graph = Graph.__new__(Graph) - graph.nodes = nodes - graph.edges = edges if edges is not None else set() - graph.state = state or GraphState() - graph._current_invocation_state = invocation_state or {} - graph._interrupt_state = interrupt_state or _InterruptState() - graph._resume_from_session = False - graph._resume_next_nodes = [] - graph.id = "test_graph" - return graph - - def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None): """Create a mock Agent with specified properties.""" agent = Mock(spec=Agent) @@ -2480,14 +2460,14 @@ def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): node_d = GraphNode(node_id="D", executor=create_mock_agent("D")) node_e = GraphNode(node_id="E", executor=create_mock_agent("E")) - graph = _make_graph( - nodes={"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b), - GraphEdge(from_node=node_b, to_node=node_c), - GraphEdge(from_node=node_d, to_node=node_e), - }, - ) + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b), + GraphEdge(from_node=node_b, to_node=node_c), + GraphEdge(from_node=node_d, to_node=node_e), + ] + graph.state = GraphState() # When A completes, only B should be ready (not E) ready = graph._find_newly_ready_nodes([node_a]) @@ -2498,425 +2478,3 @@ def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): ready = graph._find_newly_ready_nodes([node_d]) ready_ids = {n.node_id for n in ready} assert ready_ids == {"E"}, f"Expected only E, got {ready_ids}" - - -# ============================================================================= -# Tests for EdgeConditionWithContext (invocation_state in edge conditions) -# ============================================================================= - - -class TestEdgeConditionProtocol: - """Tests for the EdgeConditionWithContext protocol and dispatch logic.""" - - def test_legacy_condition_still_works(self): - """Verify Callable[[GraphState], bool] conditions work unchanged.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - def legacy_condition(state: GraphState) -> bool: - return len(state.completed_nodes) > 0 - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=legacy_condition) - - assert not edge.should_traverse(GraphState()) - assert edge.should_traverse(GraphState(completed_nodes={node_a})) - - def test_legacy_condition_not_affected_by_invocation_state(self): - """Legacy conditions should work even when invocation_state is passed.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - def legacy_condition(state: GraphState) -> bool: - return True - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=legacy_condition) - assert edge.should_traverse(GraphState(), invocation_state={"key": "value"}) - - def test_new_style_condition_receives_invocation_state(self): - """Verify EdgeConditionWithContext receives invocation_state kwarg.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - received_invocation_state = {} - - def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - received_invocation_state.update(invocation_state) - return invocation_state.get("enable_path", False) - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - - # Without the flag, should not traverse - assert not edge.should_traverse(GraphState(), invocation_state={"enable_path": False}) - assert received_invocation_state == {"enable_path": False} - - # With the flag, should traverse - received_invocation_state.clear() - assert edge.should_traverse(GraphState(), invocation_state={"enable_path": True}) - assert received_invocation_state == {"enable_path": True} - - def test_condition_none_always_traverses(self): - """Verify edges without conditions always traverse.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=None) - assert edge.should_traverse(GraphState()) - assert edge.should_traverse(GraphState(), invocation_state={"anything": True}) - - def test_new_style_condition_with_kwargs_extensibility(self): - """Verify conditions with **kwargs work for future extensibility.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - def extensible_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return True - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=extensible_condition) - assert edge.should_traverse(GraphState(), invocation_state={}) - - def test_invocation_state_defaults_to_empty_dict_when_none(self): - """Verify graceful behavior when invocation_state is None.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - - received = [] - - def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - received.append(invocation_state) - return True - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - assert edge.should_traverse(GraphState(), invocation_state=None) - assert received == [{}] - - -class TestInvocationStatePropagation: - """Tests that invocation_state flows correctly through graph execution paths.""" - - def test_is_node_ready_with_conditions_passes_invocation_state(self): - """Verify _is_node_ready_with_conditions passes invocation_state to edge conditions.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_b.dependencies.add(node_a) - - received_state = {} - - def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - received_state.update(invocation_state) - return invocation_state.get("activate", False) - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - - graph = _make_graph( - nodes={"A": node_a, "B": node_b}, - edges={edge}, - state=GraphState(completed_nodes={node_a}), - invocation_state={"activate": True}, - ) - - assert graph._is_node_ready_with_conditions(node_b, [node_a]) - assert received_state == {"activate": True} - - def test_is_node_ready_with_conditions_invocation_state_false(self): - """Verify condition returning False blocks node readiness.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_b.dependencies.add(node_a) - - def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("activate", False) - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - - graph = _make_graph( - nodes={"A": node_a, "B": node_b}, - edges={edge}, - state=GraphState(completed_nodes={node_a}), - invocation_state={"activate": False}, - ) - - assert not graph._is_node_ready_with_conditions(node_b, [node_a]) - - def test_build_node_input_passes_invocation_state(self): - """Verify _build_node_input uses invocation_state for edge condition evaluation.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_b.dependencies.add(node_a) - - def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("include_dep", False) - - edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - - mock_result = AgentResult( - message={"role": "assistant", "content": [{"text": "result from A"}]}, - stop_reason="end_turn", - state={}, - metrics=Mock( - accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, - accumulated_metrics={"latencyMs": 100.0}, - ), - ) - - graph = _make_graph( - nodes={"A": node_a, "B": node_b}, - edges={edge}, - state=GraphState( - task="test task", - completed_nodes={node_a}, - results={"A": NodeResult(result=mock_result)}, - ), - invocation_state={"include_dep": False}, - ) - - # With condition=False, dependency is excluded -> gets raw task - node_input = graph._build_node_input(node_b) - assert any("test task" in str(block) for block in node_input) - - # With condition=True, dependency result is included - graph._current_invocation_state = {"include_dep": True} - node_input = graph._build_node_input(node_b) - input_text = " ".join(str(block) for block in node_input) - assert "result from A" in input_text - - -class TestResumeDeadlockFix: - """Tests for the _compute_ready_nodes_for_resume deadlock fix with conditional edges.""" - - def test_resume_skips_false_condition_edges(self): - """Graph: A->(cond=False)->B, A->(unconditional)->C. C should be ready on resume.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) - node_b.dependencies.add(node_a) - node_c.dependencies.add(node_a) - - def always_false(state: GraphState) -> bool: - return False - - graph = _make_graph( - nodes={"A": node_a, "B": node_b, "C": node_c}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), - GraphEdge(from_node=node_a, to_node=node_c), # unconditional - }, - state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), - ) - - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - # C should be ready (unconditional edge from A), B should not (condition=False) - assert "C" in ready_ids - assert "B" not in ready_ids - - def test_resume_diamond_with_conditional_skip(self): - """Exact scenario from issue comment: A->(cond=True)->B->C, A->(cond=False)->C. - - When condition is False, B is skipped. C has two incoming edges: - - B->C (unconditional, but B never ran) - - A->C (condition=False, should be excluded from readiness check) - - Without the fix, C is stuck because all() requires both edges satisfied. - With the fix, the A->C edge is excluded (condition=False), and since there - are no other traversable edges with incomplete sources, we need B->C. - But B never ran, so C can't be ready via B->C either. - - The correct fix scenario: when condition selects the FAST path (True), - B runs and C should be ready via B->C (excluding A->C which is False). - """ - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) - node_b.dependencies.add(node_a) - node_c.dependencies.add(node_a) - node_c.dependencies.add(node_b) - - def use_fast_path(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("fast", False) - - def skip_direct(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return not invocation_state.get("fast", False) - - graph = _make_graph( - nodes={"A": node_a, "B": node_b, "C": node_c}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b, condition=use_fast_path), - GraphEdge(from_node=node_a, to_node=node_c, condition=skip_direct), - GraphEdge(from_node=node_b, to_node=node_c), - }, - state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}), - invocation_state={"fast": True}, - ) - - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - # C should be ready: A->C edge excluded (condition=False), B->C is unconditional and B completed - assert "C" in ready_ids - - def test_resume_all_conditions_false_blocks_node(self): - """If ALL incoming edges have conditions that are False, node should not be ready.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_b.dependencies.add(node_a) - - def always_false(state: GraphState) -> bool: - return False - - graph = _make_graph( - nodes={"A": node_a, "B": node_b}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), - }, - state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), - ) - - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - assert "B" not in ready_ids - - def test_resume_with_invocation_state_condition(self): - """Condition uses invocation_state; on resume with same state, correct routing.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) - node_b.dependencies.add(node_a) - node_c.dependencies.add(node_a) - - def check_role(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("role") == "admin" - - def check_not_admin(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("role") != "admin" - - graph = _make_graph( - nodes={"A": node_a, "B": node_b, "C": node_c}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b, condition=check_role), - GraphEdge(from_node=node_a, to_node=node_c, condition=check_not_admin), - }, - state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), - ) - - # As admin: only B should be ready - graph._current_invocation_state = {"role": "admin"} - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - assert ready_ids == {"B"} - - # As non-admin: only C should be ready - graph._current_invocation_state = {"role": "user"} - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - assert ready_ids == {"C"} - - def test_resume_mixed_conditional_unconditional_edges(self): - """Node with both conditional (False) and unconditional edges: ready if unconditional source completed.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) - node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) - node_b.dependencies.add(node_a) - node_c.dependencies.add(node_a) - node_c.dependencies.add(node_b) - - def always_false(state: GraphState) -> bool: - return False - - graph = _make_graph( - nodes={"A": node_a, "B": node_b, "C": node_c}, - edges={ - GraphEdge(from_node=node_a, to_node=node_b), # unconditional - GraphEdge(from_node=node_a, to_node=node_c, condition=always_false), # conditional (False) - GraphEdge(from_node=node_b, to_node=node_c), # unconditional - }, - state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}), - ) - - ready = graph._compute_ready_nodes_for_resume() - ready_ids = {n.node_id for n in ready} - # C should be ready: A->C is excluded (condition=False), B->C is unconditional and B completed - assert "C" in ready_ids - - -class TestSerializationWithInvocationState: - """Tests for serialization/deserialization of invocation_state.""" - - def test_serialize_includes_invocation_state(self): - """Verify invocation_state appears in serialized payload.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - - graph = _make_graph( - nodes={"A": node_a}, - state=GraphState(status=Status.COMPLETED, completed_nodes={node_a}, task="test"), - invocation_state={"feature_flag": True, "user_id": "123"}, - ) - - serialized = graph.serialize_state() - assert "invocation_state" in serialized - assert serialized["invocation_state"] == {"feature_flag": True, "user_id": "123"} - - def test_deserialize_restores_invocation_state(self): - """Verify invocation_state is restored on deserialization.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - - graph = _make_graph(nodes={"A": node_a}) - - payload = { - "status": "completed", - "completed_nodes": [], - "next_nodes_to_execute": [], - "invocation_state": {"role": "admin"}, - } - graph.deserialize_state(payload) - assert graph._current_invocation_state == {"role": "admin"} - - def test_deserialize_missing_invocation_state_defaults_empty(self): - """Backwards compat: old serialized payloads without invocation_state still work.""" - node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - - graph = _make_graph(nodes={"A": node_a}) - - payload = { - "status": "completed", - "completed_nodes": [], - "next_nodes_to_execute": [], - } - graph.deserialize_state(payload) - assert graph._current_invocation_state == {} - - -class TestConditionSignatureDetection: - """Tests for the _is_context_condition helper.""" - - def test_detects_legacy_condition(self): - """Legacy condition without invocation_state param.""" - from strands.multiagent.graph import _is_context_condition - - def legacy(state: GraphState) -> bool: - return True - - assert not _is_context_condition(legacy) - - def test_detects_new_style_condition(self): - """New-style condition with invocation_state param.""" - from strands.multiagent.graph import _is_context_condition - - def new_style(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return True - - assert _is_context_condition(new_style) - - def test_detects_positional_invocation_state(self): - """Condition with invocation_state as positional param (also supported).""" - from strands.multiagent.graph import _is_context_condition - - def positional(state: GraphState, invocation_state: dict) -> bool: - return True - - assert _is_context_condition(positional) - - def test_lambda_without_invocation_state(self): - """Lambda conditions (legacy pattern).""" - from strands.multiagent.graph import _is_context_condition - - cond = lambda state: len(state.completed_nodes) > 0 # noqa: E731 - assert not _is_context_condition(cond) diff --git a/strands-py/tests_integ/test_multiagent_graph.py b/strands-py/tests_integ/test_multiagent_graph.py index 2b7f17547c..b80a0f82dd 100644 --- a/strands-py/tests_integ/test_multiagent_graph.py +++ b/strands-py/tests_integ/test_multiagent_graph.py @@ -15,7 +15,7 @@ MessageAddedEvent, ) from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status -from strands.multiagent.graph import GraphBuilder, GraphState +from strands.multiagent.graph import GraphBuilder from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -586,319 +586,3 @@ async def failing_after_two(*args, **kwargs): assert result.status == Status.COMPLETED assert len(result.execution_order) == 5 assert all(node.node_id == "loop_node" for node in result.execution_order) - - -@pytest.mark.asyncio -async def test_conditional_routing_with_invocation_state(): - """Test that edge conditions can use invocation_state for routing decisions. - - Graph structure: - entry -> (condition: use_detailed) -> detailed_agent - -> (condition: not use_detailed) -> brief_agent - """ - detailed_agent = Agent( - name="detailed", - model="us.amazon.nova-pro-v1:0", - system_prompt="Provide a very detailed, multi-paragraph explanation.", - ) - brief_agent = Agent( - name="brief", - model="us.amazon.nova-lite-v1:0", - system_prompt="Provide a one-sentence answer.", - ) - entry_agent = Agent( - name="entry", - model="us.amazon.nova-lite-v1:0", - system_prompt="You are a router. Just say 'routing complete'.", - ) - - def use_detailed(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("detail_level") == "high" - - def use_brief(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("detail_level") != "high" - - builder = GraphBuilder() - builder.add_node(entry_agent, "entry") - builder.add_node(detailed_agent, "detailed") - builder.add_node(brief_agent, "brief") - builder.add_edge("entry", "detailed", condition=use_detailed) - builder.add_edge("entry", "brief", condition=use_brief) - builder.set_entry_point("entry") - graph = builder.build() - - # With detail_level=high, only detailed_agent should execute - result = await graph.invoke_async("What is Python?", invocation_state={"detail_level": "high"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "detailed" in executed_nodes - assert "brief" not in executed_nodes - - # With detail_level=low, only brief_agent should execute - result = await graph.invoke_async("What is Python?", invocation_state={"detail_level": "low"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "brief" in executed_nodes - assert "detailed" not in executed_nodes - - -@pytest.mark.asyncio -async def test_legacy_conditions_unaffected_by_invocation_state(): - """Test that existing graphs with legacy conditions still work when invocation_state is passed.""" - agent1 = Agent( - name="agent1", - model="us.amazon.nova-lite-v1:0", - system_prompt="You are agent 1. Just say hello.", - ) - agent2 = Agent( - name="agent2", - model="us.amazon.nova-lite-v1:0", - system_prompt="You are agent 2. Just say goodbye.", - ) - - def legacy_condition(state: GraphState) -> bool: - return any(n.node_id == "agent1" for n in state.completed_nodes) - - builder = GraphBuilder() - builder.add_node(agent1, "agent1") - builder.add_node(agent2, "agent2") - builder.add_edge("agent1", "agent2", condition=legacy_condition) - builder.set_entry_point("agent1") - graph = builder.build() - - # Legacy condition should still work fine even with invocation_state passed - result = await graph.invoke_async("Hello", invocation_state={"some_key": "some_value"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "agent1" in executed_nodes - assert "agent2" in executed_nodes - - -@pytest.mark.asyncio -async def test_condition_combining_graph_state_and_invocation_state(): - """Test condition that uses both GraphState and invocation_state for decisions. - - Graph: entry -> (condition: completed entry AND feature_flag) -> feature_agent -> final - -> (unconditional) -> final - """ - entry_agent = Agent( - name="entry", - model="us.amazon.nova-lite-v1:0", - system_prompt="Just say 'entry done'.", - ) - feature_agent = Agent( - name="feature", - model="us.amazon.nova-lite-v1:0", - system_prompt="Execute the new feature path. Say 'feature executed'.", - ) - final_agent = Agent( - name="final", - model="us.amazon.nova-lite-v1:0", - system_prompt="Summarize. Say 'done'.", - ) - - def feature_gate(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - entry_completed = any(n.node_id == "entry" for n in state.completed_nodes) - flag_enabled = invocation_state.get("enable_feature_x", False) - return entry_completed and flag_enabled - - builder = GraphBuilder() - builder.add_node(entry_agent, "entry") - builder.add_node(feature_agent, "feature") - builder.add_node(final_agent, "final") - builder.add_edge("entry", "feature", condition=feature_gate) - builder.add_edge("entry", "final") - builder.add_edge("feature", "final") - builder.set_entry_point("entry") - graph = builder.build() - - # With flag disabled: entry -> final (feature skipped) - result = await graph.invoke_async("Run task", invocation_state={"enable_feature_x": False}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "entry" in executed_nodes - assert "final" in executed_nodes - assert "feature" not in executed_nodes - - # With flag enabled: entry -> feature -> final - result = await graph.invoke_async("Run task", invocation_state={"enable_feature_x": True}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "entry" in executed_nodes - assert "feature" in executed_nodes - assert "final" in executed_nodes - - -@pytest.mark.asyncio -async def test_diamond_graph_conditional_convergence(): - """Test diamond graph where one path is conditionally skipped and downstream converges. - - Graph: - entry -> (cond=True) -> fast_path -> merger - -> (cond=False) -> slow_path -> merger - - This tests the deadlock fix: merger should still execute even though slow_path's - incoming edge evaluates to False. - """ - entry_agent = Agent( - name="entry", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'routed'.", - ) - fast_agent = Agent( - name="fast", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'fast result'.", - ) - slow_agent = Agent( - name="slow", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'slow result'.", - ) - merger_agent = Agent( - name="merger", - model="us.amazon.nova-lite-v1:0", - system_prompt="Merge results. Say 'merged'.", - ) - - def is_fast_mode(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("mode") == "fast" - - def is_slow_mode(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("mode") == "slow" - - builder = GraphBuilder() - builder.add_node(entry_agent, "entry") - builder.add_node(fast_agent, "fast") - builder.add_node(slow_agent, "slow") - builder.add_node(merger_agent, "merger") - builder.add_edge("entry", "fast", condition=is_fast_mode) - builder.add_edge("entry", "slow", condition=is_slow_mode) - builder.add_edge("fast", "merger") - builder.add_edge("slow", "merger") - builder.set_entry_point("entry") - graph = builder.build() - - # Fast mode: entry -> fast -> merger (slow skipped, merger not deadlocked) - result = await graph.invoke_async("Process", invocation_state={"mode": "fast"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert executed_nodes == {"entry", "fast", "merger"} - - # Slow mode: entry -> slow -> merger (fast skipped) - result = await graph.invoke_async("Process", invocation_state={"mode": "slow"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert executed_nodes == {"entry", "slow", "merger"} - - -@pytest.mark.asyncio -async def test_invocation_state_persisted_on_resume(tmp_path): - """Test that invocation_state is serialized and correctly used on resume after failure. - - Verifies that when a graph fails mid-execution and resumes from persisted state, - the invocation_state is restored and edge conditions evaluate correctly. - """ - session_id = f"invocation_state_resume_{uuid4()}" - session_manager = FileSessionManager(session_id=session_id, storage_dir=str(tmp_path)) - - agent1 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 1 done'.", name="agent1") - agent2 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 2 done'.", name="agent2") - agent3 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 3 done'.", name="agent3") - - def requires_premium(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("tier") == "premium" - - builder = GraphBuilder() - builder.add_node(agent1, "step1") - builder.add_node(agent2, "step2") - builder.add_node(agent3, "step3") - builder.add_edge("step1", "step2", condition=requires_premium) - builder.add_edge("step1", "step3") - builder.add_edge("step2", "step3") - builder.set_entry_point("step1") - builder.set_session_manager(session_manager) - graph = builder.build() - - # First invocation: step2 fails, step1 completed - async def failing_stream(*args, **kwargs): - raise Exception("Simulated failure in step2") - yield - - with patch.object(agent2, "stream_async", side_effect=failing_stream): - try: - await graph.invoke_async("Premium task", invocation_state={"tier": "premium"}) - raise AssertionError("Expected exception") - except Exception as e: - assert "Simulated failure in step2" in str(e) - - # Verify invocation_state was persisted - persisted = session_manager.read_multi_agent(session_id, graph.id) - assert persisted is not None - assert persisted.get("invocation_state") == {"tier": "premium"} - assert "step1" in persisted["completed_nodes"] - assert "step2" in persisted["failed_nodes"] - - # Resume: step2 should retry (condition still True), then step3 - result = await graph.invoke_async("Premium task", invocation_state={"tier": "premium"}) - assert result.status == Status.COMPLETED - executed_nodes = {n.node_id for n in result.execution_order} - assert "step2" in executed_nodes - assert "step3" in executed_nodes - - session_manager.delete_session(session_id) - - -@pytest.mark.asyncio -async def test_invocation_state_streaming_with_conditional_edges(): - """Test that streaming events are correctly emitted for conditional edge graphs. - - Verifies that only the activated path's nodes produce stream events. - """ - agent_a = Agent( - name="agent_a", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'A done'.", - ) - agent_b = Agent( - name="agent_b", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'B done'.", - ) - agent_c = Agent( - name="agent_c", - model="us.amazon.nova-lite-v1:0", - system_prompt="Say 'C done'.", - ) - - def go_to_b(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("path") == "B" - - def go_to_c(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: - return invocation_state.get("path") == "C" - - builder = GraphBuilder() - builder.add_node(agent_a, "A") - builder.add_node(agent_b, "B") - builder.add_node(agent_c, "C") - builder.add_edge("A", "B", condition=go_to_b) - builder.add_edge("A", "C", condition=go_to_c) - builder.set_entry_point("A") - graph = builder.build() - - # Stream with path=B — only A and B should have events - events = [] - async for event in graph.stream_async("Do something", invocation_state={"path": "B"}): - events.append(event) - - node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] - started_nodes = {e["node_id"] for e in node_start_events} - assert "A" in started_nodes - assert "B" in started_nodes - assert "C" not in started_nodes - - # Verify final result - result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] - assert len(result_events) >= 1 - final_result = result_events[-1]["result"] - assert final_result.status == Status.COMPLETED