diff --git a/haystack_experimental/components/agents/human_in_the_loop/strategies.py b/haystack_experimental/components/agents/human_in_the_loop/strategies.py index f8317e31..4f2251e1 100644 --- a/haystack_experimental/components/agents/human_in_the_loop/strategies.py +++ b/haystack_experimental/components/agents/human_in_the_loop/strategies.py @@ -68,9 +68,8 @@ def run( Optional unique identifier for the tool call. This can be used to track and correlate the decision with a specific tool invocation. :param confirmation_strategy_context: - Optional dictionary for passing request-scoped resources. Useful in web/server environments - to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients) - that strategies can use for non-blocking user interaction. + Optional dictionary for passing request-scoped resources. Not used by this strategy but included for + interface compatibility. :returns: A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a @@ -140,7 +139,8 @@ async def run_async( :param tool_call_id: Optional unique identifier for the tool call. :param confirmation_strategy_context: - Optional dictionary for passing request-scoped resources. + Optional dictionary for passing request-scoped resources. Not used by this strategy but included for + interface compatibility. :returns: A ToolExecutionDecision indicating whether to execute the tool with the given parameters. @@ -263,7 +263,8 @@ async def run_async( :param tool_call_id: Optional unique identifier for the tool call. :param confirmation_strategy_context: - Optional dictionary for passing request-scoped resources. + Optional dictionary for passing request-scoped resources. Not used by this strategy but included for + interface compatibility. :raises HITLBreakpointException: Always raises an `HITLBreakpointException` exception to signal that user confirmation is required. diff --git a/haystack_experimental/components/agents/human_in_the_loop/types.py b/haystack_experimental/components/agents/human_in_the_loop/types.py index a2468b03..2d193eb0 100644 --- a/haystack_experimental/components/agents/human_in_the_loop/types.py +++ b/haystack_experimental/components/agents/human_in_the_loop/types.py @@ -49,7 +49,7 @@ def update_after_confirmation( confirmation_result: ConfirmationUIResult, ) -> None: """Update the policy based on the confirmation UI result.""" - pass + return def to_dict(self) -> dict[str, Any]: """Serialize the policy to a dictionary.""" @@ -64,11 +64,12 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfirmationPolicy": class ConfirmationStrategy(Protocol): def run( self, + *, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: str | None = None, - **kwargs: dict[str, Any] | None, + confirmation_strategy_context: dict[str, Any] | None = None, ) -> ToolExecutionDecision: """ Run the confirmation strategy for a given tool and its parameters. @@ -78,9 +79,8 @@ def run( :param tool_params: The parameters to be passed to the tool. :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate the decision with a specific tool invocation. - :param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context` - for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server - environments. + :param confirmation_strategy_context: Optional context dictionary for passing request-scoped resources + (e.g., WebSocket connections, async queues) in web/server environments. :returns: The result of the confirmation strategy (e.g., tool output, rejection message, etc.). @@ -89,11 +89,12 @@ def run( async def run_async( self, + *, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: str | None = None, - **kwargs: dict[str, Any] | None, + confirmation_strategy_context: dict[str, Any] | None = None, ) -> ToolExecutionDecision: """ Async version of run. Run the confirmation strategy for a given tool and its parameters. @@ -105,9 +106,8 @@ async def run_async( :param tool_params: The parameters to be passed to the tool. :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate the decision with a specific tool invocation. - :param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context` - for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server - environments. + :param confirmation_strategy_context: Optional context dictionary for passing request-scoped resources + (e.g., WebSocket connections, async queues) in web/server environments. :returns: The result of the confirmation strategy (e.g., tool output, rejection message, etc.). diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index a7596441..55df0bd6 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -129,6 +129,7 @@ def run_agent( snapshot = None if snapshot_file_path: snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + assert snapshot.agent_snapshot is not None # Add any new tool execution decisions to the snapshot if tool_execution_decisions: @@ -152,6 +153,7 @@ def run_pipeline_with_agent( snapshot = None if snapshot_file_path: snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + assert snapshot.agent_snapshot is not None # Add any new tool execution decisions to the snapshot if tool_execution_decisions: @@ -175,6 +177,7 @@ async def run_agent_async( snapshot = None if snapshot_file_path: snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + assert snapshot.agent_snapshot is not None # Add any new tool execution decisions to the snapshot if tool_execution_decisions: @@ -284,6 +287,7 @@ def test_from_dict(self, tools, confirmation_strategies, monkeypatch): assert deserialized_agent.to_dict() == agent.to_dict() assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator) assert len(deserialized_agent.tools) == 1 + assert isinstance(deserialized_agent.tools[0], Tool) assert deserialized_agent.tools[0].name == "addition_tool" assert isinstance(deserialized_agent._tool_invoker, type(agent._tool_invoker)) assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"], BlockingConfirmationStrategy) @@ -316,6 +320,7 @@ def test_get_tool_calls_and_descriptions_from_snapshot_no_mutation_of_snapshot(s original_snapshot = copy.deepcopy(loaded_snapshot) # Extract tool calls and descriptions + assert loaded_snapshot.agent_snapshot is not None _ = get_tool_calls_and_descriptions_from_snapshot( agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True ) @@ -341,6 +346,7 @@ def test_run_blocking_confirmation_strategy_modify(self, tools): result = agent.run([ChatMessage.from_user("What is 2+2?")]) assert isinstance(result["last_message"], ChatMessage) + assert result["last_message"].text is not None assert "5" in result["last_message"].text @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @@ -362,6 +368,7 @@ def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path): while result is None: # Load the latest snapshot from disk and prep data for front-end loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path)) + assert loaded_snapshot.agent_snapshot is not None serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot( agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True ) @@ -379,6 +386,7 @@ def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path): # Step 3: Final result last_message = result["last_message"] assert isinstance(last_message, ChatMessage) + assert last_message.text is not None assert "5" in last_message.text @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @@ -402,6 +410,7 @@ def test_run_in_pipeline_breakpoint_confirmation_strategy_modify(self, tools, tm while result is None: # Load the latest snapshot from disk and prep data for front-end loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path)) + assert loaded_snapshot.agent_snapshot is not None serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot( agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True ) @@ -419,6 +428,7 @@ def test_run_in_pipeline_breakpoint_confirmation_strategy_modify(self, tools, tm # Step 3: Final result last_message = result["agent"]["last_message"] assert isinstance(last_message, ChatMessage) + assert last_message.text is not None assert "5" in last_message.text @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @@ -440,6 +450,7 @@ async def test_run_async_blocking_confirmation_strategy_modify(self, tools): result = await agent.run_async([ChatMessage.from_user("What is 2+2?")]) assert isinstance(result["last_message"], ChatMessage) + assert result["last_message"].text is not None assert "5" in result["last_message"].text @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @@ -462,6 +473,7 @@ async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tm while result is None: # Load the latest snapshot from disk and prep data for front-end loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path)) + assert loaded_snapshot.agent_snapshot is not None serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot( agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True ) @@ -479,6 +491,7 @@ async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tm # Step 3: Final result last_message = result["last_message"] assert isinstance(last_message, ChatMessage) + assert last_message.text is not None assert "5" in last_message.text