diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 968426b8b4..6327760e44 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -787,7 +787,8 @@ async def step(self): ) return - if not llm_resp.tools_call_name: + has_tool_calls = bool(llm_resp.tools_call_name) + if not has_tool_calls: await self._complete_with_assistant_response(llm_resp) # 返回 LLM 结果 @@ -800,18 +801,19 @@ async def step(self): ), ), ) - if llm_resp.result_chain: - yield AgentResponse( - type="llm_result", - data=AgentResponseData(chain=llm_resp.result_chain), - ) - elif llm_resp.completion_text: - yield AgentResponse( - type="llm_result", - data=AgentResponseData( - chain=MessageChain().message(llm_resp.completion_text), - ), - ) + if not has_tool_calls: + if llm_resp.result_chain: + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=llm_resp.result_chain), + ) + elif llm_resp.completion_text: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain().message(llm_resp.completion_text), + ), + ) # 如果有工具调用,还需处理工具调用 if llm_resp.tools_call_name: diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 821ece702c..cfc40a3551 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -24,16 +24,18 @@ TOOL_CALL_PROMPT = ( "When using tools: " - "never return an empty response; " - "briefly explain the purpose before calling a tool; " + "you may return only tool calls when no user-facing message is needed; " + 'do not emit placeholder text such as "No response"; ' + "briefly explain the purpose before calling a tool only when it helps the user; " "follow the tool schema exactly and do not invent parameters; " "after execution, briefly summarize the result for the user; " "keep the conversation style consistent." ) TOOL_CALL_PROMPT_SKILLS_LIKE_MODE = ( - "You MUST NOT return an empty response, especially after invoking a tool." - " Before calling any tool, provide a brief explanatory message to the user stating the purpose of the tool call." + "You may return only tool calls when no user-facing message is needed." + ' Do not emit placeholder text such as "No response".' + " Before calling any tool, provide a brief explanatory message to the user only when it helps." " Tool schemas are provided in two stages: first only name and description; " "if you decide to use a tool, the full parameter schema will be provided in " "a follow-up step. Do not guess arguments before you see the schema." diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 74d0691085..3f7bbf64ab 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -260,6 +260,35 @@ async def text_chat(self, **kwargs) -> LLMResponse: ) +class PreToolTextThenFinalProvider(MockProvider): + def __init__( + self, pre_tool_text: str, reasoning_content: str | None = None + ): + super().__init__() + self.pre_tool_text = pre_tool_text + self.reasoning_content = reasoning_content + + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + func_tool = kwargs.get("func_tool") + if func_tool is None or self.call_count > 1: + return LLMResponse( + role="assistant", + completion_text="final answer", + usage=TokenUsage(input_other=10, output=5), + ) + + return LLMResponse( + role="assistant", + completion_text=self.pre_tool_text, + reasoning_content=self.reasoning_content, + tools_call_name=["test_tool"], + tools_call_args=[{"query": "test"}], + tools_call_ids=["call_pre_tool"], + usage=TokenUsage(input_other=10, output=5), + ) + + class SequentialToolProvider(MockProvider): def __init__(self, tool_sequence: list[str]): super().__init__() @@ -498,6 +527,97 @@ async def test_normal_completion_without_max_step( assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用" +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pre_tool_text", + ["*No response*", "I will check this first."], +) +async def test_tool_call_turn_does_not_emit_pre_tool_llm_result(pre_tool_text: str): + tool = FunctionTool( + name="test_tool", + description="test tool", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + provider = PreToolTextThenFinalProvider(pre_tool_text) + request = ProviderRequest( + prompt="run tool", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + streaming=False, + ) + + responses = [] + async for response in runner.step_until_done(3): + responses.append(response) + + llm_result_texts = [ + resp.data["chain"].get_plain_text(with_other_comps_mark=True) + for resp in responses + if resp.type == "llm_result" + ] + + assert pre_tool_text not in llm_result_texts + assert any(resp.type == "tool_call" for resp in responses) + assert "final answer" in llm_result_texts + + +@pytest.mark.asyncio +async def test_tool_call_turn_still_emits_reasoning_content(): + tool = FunctionTool( + name="test_tool", + description="test tool", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + provider = PreToolTextThenFinalProvider( + pre_tool_text="*No response*", + reasoning_content="thinking...", + ) + request = ProviderRequest( + prompt="run tool", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + streaming=False, + ) + + responses = [] + async for response in runner.step_until_done(3): + responses.append(response) + + reasoning_texts = [ + resp.data["chain"].get_plain_text(with_other_comps_mark=True) + for resp in responses + if resp.type == "llm_result" and resp.data["chain"].type == "reasoning" + ] + llm_result_texts = [ + resp.data["chain"].get_plain_text(with_other_comps_mark=True) + for resp in responses + if resp.type == "llm_result" + ] + + assert "thinking..." in reasoning_texts + assert "*No response*" not in llm_result_texts + + @pytest.mark.asyncio async def test_max_step_with_streaming( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks