diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index beca48e32..800b15374 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -15,6 +15,7 @@ from openai._exceptions import ( APIStatusError as OpenAIAPIStatusError, ) +from typing_extensions import deprecated from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -42,6 +43,7 @@ from models.common.responses.types import ResponseInput from models.common.turn_summary import TurnSummary from models.config import Action +from utils.agents.query import retrieve_agent_response from utils.conversation_compaction import ( apply_compaction_blocking, configured_conversation_cache, @@ -68,7 +70,7 @@ build_turn_summary, deduplicate_referenced_documents, extract_vector_store_ids_from_tools, - get_topic_summary, + maybe_get_topic_summary, prepare_responses_params, ) from utils.shields import run_shield_moderation, validate_shield_ids_override @@ -226,12 +228,12 @@ async def query_endpoint_handler( client = await AsyncLlamaStackClientHolder().update_azure_token() # Retrieve response using Responses API - turn_summary = await retrieve_response( + turn_summary = await retrieve_agent_response( client, responses_params, moderation_result, endpoint_path, - original_input=compaction.original_input if compaction.compacted else None, + compaction.original_input if compaction.compacted else None, ) if moderation_result.decision == "passed": @@ -249,13 +251,15 @@ async def query_endpoint_handler( ) # Get topic summary for new conversation - if not user_conversation and query_request.generate_topic_summary: - logger.debug("Generating topic summary for new conversation") - topic_summary = await get_topic_summary( - query_request.query, client, responses_params.model - ) - else: - topic_summary = None + should_generate = not user_conversation and bool( + query_request.generate_topic_summary + ) + topic_summary = await maybe_get_topic_summary( + generate_topic_summary=should_generate, + input_text=query_request.query, + client=client, + model_id=responses_params.model, + ) logger.info("Consuming tokens") consume_query_tokens( @@ -301,6 +305,10 @@ async def query_endpoint_handler( ) +@deprecated( + "Deprecated in favor of utils.agents.query.retrieve_agent_response.", + stacklevel=2, +) async def retrieve_response( client: AsyncLlamaStackClient, responses_params: ResponsesApiParams, diff --git a/src/utils/agents/query.py b/src/utils/agents/query.py index 0322c43ad..e82006425 100644 --- a/src/utils/agents/query.py +++ b/src/utils/agents/query.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum -from typing import TypeAlias, cast +from typing import Optional, TypeAlias, cast from fastapi import HTTPException from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient @@ -33,6 +33,7 @@ from models.common.agents import AgentTurnAccumulator from models.common.moderation import ShieldModerationResult from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.responses.types import ResponseInput from models.common.turn_summary import TurnSummary from utils.agents.tool_processor import ( process_function_tool_call, @@ -281,6 +282,7 @@ async def retrieve_agent_response( responses_params: ResponsesApiParams, moderation_result: ShieldModerationResult, endpoint_path: str, + _original_input: Optional[ResponseInput] = None, ) -> TurnSummary: """Retrieve a turn summary from a blocking agent run. @@ -291,6 +293,7 @@ async def retrieve_agent_response( responses_params: Prepared Responses API parameters. moderation_result: Shield moderation outcome for the turn. endpoint_path: Endpoint path used for metric labeling. + _original_input: Original user input before the explicit-input rewrite. Returns: Turn summary for the completed agent run. diff --git a/src/utils/agents/tool_processor.py b/src/utils/agents/tool_processor.py index 6b0b910e2..77c7d516d 100644 --- a/src/utils/agents/tool_processor.py +++ b/src/utils/agents/tool_processor.py @@ -479,7 +479,7 @@ def summarize_mcp_tool_result( Tool result summary in LCS turn-summary format. """ content = cast(dict[str, Any], part.content) - if "tools" in content or "error" in content: + if "tools" in content: return summarize_mcp_list_tools_result(part, tool_round) return summarize_mcp_call_result(part, tool_round) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7fc2edfa0..efc741e42 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -8,7 +8,20 @@ import pytest from fastapi import Request, Response from fastapi.testclient import TestClient -from pytest_mock import MockerFixture +from llama_stack_api.openai_responses import OpenAIResponseObject +from llama_stack_client.types import VersionInfo +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + NativeToolCallPart, + NativeToolReturnPart, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool +from pydantic_ai.usage import RunUsage +from pytest_mock import AsyncMockType, MockerFixture from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, sessionmaker @@ -70,9 +83,6 @@ def create_mock_llm_response( # pylint: disable=too-many-arguments,too-many-pos Returns: Mock LLM response object with the specified configuration. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) mock_response.id = "response-123" @@ -154,6 +164,223 @@ def create_mock_tool_call( return mock_tool_call +def create_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "This is a test response about Ansible.", + response_id: str = "response-123", + input_tokens: int = 10, + output_tokens: int = 5, + model_response: Any = None, + new_messages: Optional[list[Any]] = None, +) -> Any: + """Create a mock AgentRunResult wired for retrieve_agent_response. + + Uses real pydantic-ai message types so build_turn_summary_from_agent_run + exercises the same path as production agent runs. + + Args: + mocker: pytest-mock fixture. + content: Assistant text content for the run. + response_id: Provider response identifier. + input_tokens: Input token count for the run. + output_tokens: Output token count for the run. + model_response: Optional pre-built ModelResponse. + new_messages: Optional message sequence returned by new_messages(). + + Returns: + Mock AgentRunResult compatible with build_turn_summary_from_agent_run. + """ + if model_response is None: + parts = [TextPart(content)] if content else [] + model_response = ModelResponse( + parts=parts, + finish_reason="stop", + provider_response_id=response_id, + ) + + messages = new_messages if new_messages is not None else [model_response] + run_result = mocker.MagicMock() + run_result.response = model_response + run_result.usage = RunUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + requests=1, + ) + run_result.new_messages.return_value = messages + return run_result + + +def create_file_search_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-rag", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 10, + output_tokens: int = 5, +) -> Any: + """Create an AgentRunResult containing a native file_search tool call.""" + call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": queries or ["test query"]}, + tool_call_id="call-fs-1", + ) + return_part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="call-fs-1", + content={ + "status": "success", + "results": results or [], + }, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def create_mcp_list_tools_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-mcplist", + server_label: str = "kubernetes-server", + tools: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 15, + output_tokens: int = 20, +) -> Any: + """Create an AgentRunResult containing an MCP list-tools native tool call.""" + call = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + args={"action": "list_tools"}, + tool_call_id="mcplist-101", + ) + return_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + tool_call_id="mcplist-101", + content={"tools": tools or []}, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def create_multi_tool_agent_run_result( + mocker: MockerFixture, + *, + content: str = "Based on documentation and calculations...", + response_id: str = "response-multi", + input_tokens: int = 40, + output_tokens: int = 60, +) -> Any: + """Create an AgentRunResult with file_search and function tool calls.""" + file_search_call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["Kubernetes deployment"]}, + tool_call_id="search-1", + ) + file_search_return = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="search-1", + content={"status": "success", "results": []}, + ) + function_call = ToolCallPart( + tool_name="calculate", + args={"operation": "sum"}, + tool_call_id="func-2", + ) + function_return = ToolReturnPart( + tool_name="calculate", + content={"result": 2}, + tool_call_id="func-2", + ) + model_response = ModelResponse( + parts=[ + file_search_call, + file_search_return, + function_call, + TextPart(content), + ], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + new_messages=[model_response, ModelRequest(parts=[function_return])], + ) + + +def set_query_agent_run( + mock_llama_stack_client: AsyncMockType, + mocker: MockerFixture, + **kwargs: Any, +) -> None: + """Configure mock agent.run return value for /query integration tests.""" + mock_llama_stack_client.query_agent.run.return_value = create_agent_run_result( + mocker, + **kwargs, + ) + + +def configure_query_agent_mock( + mocker: MockerFixture, + *, + run_result: Any = None, + run_side_effect: Any = None, +) -> Any: + """Patch build_agent for /query integration tests and return the mock agent. + + Args: + mocker: pytest-mock fixture. + run_result: AgentRunResult returned by agent.run(). + run_side_effect: Optional exception side effect for agent.run(). + + Returns: + Mock agent exposing AsyncMock run(). + """ + if run_result is None: + run_result = create_agent_run_result(mocker) + + mock_agent = mocker.AsyncMock() + if run_side_effect is not None: + mock_agent.run = mocker.AsyncMock(side_effect=run_side_effect) + else: + mock_agent.run = mocker.AsyncMock(return_value=run_result) + + build_agent_mock = mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + mock_agent.build_agent_mock = build_agent_mock + return mock_agent + + # ========================================== # Fixtures # ========================================== @@ -448,10 +675,6 @@ def mock_llama_stack_client_fixture( Yields: mock_client: The mocked Llama Stack client instance. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - from llama_stack_client.types import VersionInfo - # Patch AsyncLlamaStackClientHolder at multiple import locations # This ensures the mock is active both during app startup (app.main) # and during endpoint execution (query, conversations_v1, responses, etc.) @@ -484,6 +707,10 @@ def mock_llama_stack_client_fixture( mock_client.responses.create.return_value = mock_response + mock_agent = configure_query_agent_mock(mocker) + mock_client.query_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + # Mock models.list mock_model = mocker.MagicMock() mock_model.id = "test-provider/test-model" diff --git a/tests/integration/endpoints/test_query_byok_integration.py b/tests/integration/endpoints/test_query_byok_integration.py index b2a659f19..0af2d27d0 100644 --- a/tests/integration/endpoints/test_query_byok_integration.py +++ b/tests/integration/endpoints/test_query_byok_integration.py @@ -7,7 +7,6 @@ import pytest from fastapi import Request -from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client.types import VersionInfo from pytest_mock import AsyncMockType, MockerFixture @@ -17,6 +16,12 @@ from configuration import AppConfig from models.api.requests import QueryRequest from models.api.responses.successful import QueryResponse +from tests.integration.conftest import ( + configure_query_agent_mock, + create_agent_run_result, + create_file_search_agent_run_result, + create_mock_llm_response, +) # --------------------------------------------------------------------------- # Helpers @@ -87,8 +92,8 @@ def _make_vector_io_response( def _build_base_mock_client(mocker: MockerFixture) -> Any: """Build a base mock Llama Stack client with common stubs. - Configures models, shields, conversations, version, and a default - responses.create return value. + Configures models, shields, conversations, version, and a default agent.run + return value. responses.create remains available for topic summary generation. """ mock_client = mocker.AsyncMock() @@ -112,24 +117,27 @@ def _build_base_mock_client(mocker: MockerFixture) -> Any: # Version mock_client.inspect.version.return_value = VersionInfo(version="0.4.3") - # Default response - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-byok" - mock_output_item = mocker.MagicMock() - mock_output_item.type = "message" - mock_output_item.role = "assistant" - mock_output_item.content = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." + mock_client.responses.create.return_value = create_mock_llm_response( + mocker, + content="OpenShift overview", + input_tokens=10, + output_tokens=5, + ) + + mock_agent = configure_query_agent_mock( + mocker, + run_result=create_agent_run_result( + mocker, + content=( + "Based on the documentation, OpenShift is a Kubernetes distribution." + ), + response_id="response-byok", + input_tokens=50, + output_tokens=20, + ), ) - mock_output_item.refusal = None - mock_response.output = [mock_output_item] - mock_response.stop_reason = "end_turn" - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 50 - mock_usage.output_tokens = 20 - mock_response.usage = mock_usage - mock_client.responses.create.return_value = mock_response + mock_client.query_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock return mock_client @@ -171,8 +179,8 @@ def mock_byok_tool_rag_client_fixture( ) -> Generator[Any, None, None]: """Mock Llama Stack client with BYOK tool RAG (file_search) configured. - Configures vector_stores.list with a BYOK store and responses.create - to return a file_search_call output item alongside the assistant message. + Configures vector_stores.list with a BYOK store and agent.run to return + a file_search tool result alongside the assistant message. """ mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder") mock_client = _build_base_mock_client(mocker) @@ -190,54 +198,26 @@ def mock_byok_tool_rag_client_fixture( mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Response with file_search tool call - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-tool-rag" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-fs-1" - mock_tool_output.queries = ["What is OpenShift?"] - mock_tool_output.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-ocp-1" - mock_result.filename = "openshift-docs.txt" - mock_result.score = 0.92 - mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." - mock_result.attributes = { - "doc_url": "https://docs.redhat.com/ocp/overview", - "link": "https://docs.redhat.com/ocp/overview", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-ocp-1", - "filename": "openshift-docs.txt", - "score": 0.92, - "text": "OpenShift is a Kubernetes distribution by Red Hat.", - "attributes": { - "doc_url": "https://docs.redhat.com/ocp/overview", - }, - } - ) - mock_tool_output.results = [mock_result] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." + tool_run_result = create_file_search_agent_run_result( + mocker, + content=("Based on the documentation, OpenShift is a Kubernetes distribution."), + response_id="response-tool-rag", + queries=["What is OpenShift?"], + results=[ + { + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "score": 0.92, + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "title": "openshift-docs.txt", + "document_id": "doc-ocp-1", + }, + } + ], + input_tokens=60, + output_tokens=25, ) - mock_message.refusal = None - - mock_response.output = [mock_tool_output, mock_message] - mock_response.stop_reason = "end_turn" - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 60 - mock_usage.output_tokens = 25 - mock_response.usage = mock_usage - mock_client.responses.create.return_value = mock_response + mock_client.query_agent.run.return_value = tool_run_result mock_holder_class.return_value.get_client.return_value = mock_client yield mock_client @@ -317,7 +297,7 @@ async def test_query_byok_inline_rag_injects_context( Verifies: - vector_io.query is called for BYOK inline RAG - - RAG context is injected into the responses.create input + - RAG context is injected into the agent prompt - Response includes RAG chunks from inline sources """ _ = byok_config @@ -353,13 +333,10 @@ async def test_query_byok_inline_rag_injects_context( call_kwargs = mock_byok_client.vector_io.query.call_args.kwargs assert call_kwargs["query"] == "What is OpenShift?" - # Verify RAG context was injected into responses.create input - # Use call_args_list[0] — the first call is the main query; - # a second call may follow for topic summary generation. - create_kwargs = mock_byok_client.responses.create.call_args_list[0].kwargs - input_text = create_kwargs["input"] - assert "file_search found" in input_text - assert "OpenShift is a Kubernetes distribution" in input_text + # Verify RAG context was injected into the agent prompt + prompt = mock_byok_client.query_agent.run.call_args.args[0] + assert "file_search found" in prompt + assert "OpenShift is a Kubernetes distribution" in prompt # Verify RAG chunks are included in the response assert response.rag_chunks is not None @@ -802,47 +779,27 @@ async def test_query_byok_combined_inline_and_tool_rag( # pylint: disable=too-m mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Response includes file_search_call (tool RAG result) - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-combined" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-fs-combined" - mock_tool_output.queries = ["What is OpenShift?"] - mock_tool_output.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-tool-1" - mock_result.filename = "tool-doc.txt" - mock_result.score = 0.90 - mock_result.text = "Tool-based RAG result about OpenShift." - mock_result.attributes = {"doc_url": "https://example.com/tool-doc"} - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-tool-1", - "filename": "tool-doc.txt", - "score": 0.90, - "text": "Tool-based RAG result about OpenShift.", - "attributes": {"doc_url": "https://example.com/tool-doc"}, - } + # Agent run includes file_search tool RAG result + combined_run_result = create_file_search_agent_run_result( + mocker, + content="Combined answer from inline and tool RAG.", + response_id="response-combined", + queries=["What is OpenShift?"], + results=[ + { + "text": "Tool-based RAG result about OpenShift.", + "score": 0.90, + "attributes": { + "doc_url": "https://example.com/tool-doc", + "title": "tool-doc.txt", + "document_id": "doc-tool-1", + }, + } + ], + input_tokens=80, + output_tokens=30, ) - mock_tool_output.results = [mock_result] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Combined answer from inline and tool RAG." - mock_message.refusal = None - - mock_response.output = [mock_tool_output, mock_message] - mock_response.stop_reason = "end_turn" - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 80 - mock_usage.output_tokens = 30 - mock_response.usage = mock_usage - mock_client.responses.create.return_value = mock_response + mock_client.query_agent.run.return_value = combined_run_result mock_holder_class.return_value.get_client.return_value = mock_client diff --git a/tests/integration/endpoints/test_query_integration.py b/tests/integration/endpoints/test_query_integration.py index 725f01d25..bf4f0df55 100644 --- a/tests/integration/endpoints/test_query_integration.py +++ b/tests/integration/endpoints/test_query_integration.py @@ -7,7 +7,6 @@ import pytest from fastapi import HTTPException, Request, status -from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError from pytest_mock import AsyncMockType, MockerFixture from sqlalchemy.orm import Session @@ -25,7 +24,10 @@ from tests.integration.conftest import ( TEST_CONVERSATION_ID, TEST_NON_EXISTENT_ID, - create_mock_llm_response, + create_file_search_agent_run_result, + create_mcp_list_tools_agent_run_result, + create_multi_tool_agent_run_result, + set_query_agent_run, ) # File-specific test constants @@ -115,7 +117,7 @@ async def test_query_v2_endpoint_handles_connection_error( """ _ = test_config - mock_llama_stack_client.responses.create.side_effect = APIConnectionError( + mock_llama_stack_client.query_agent.run.side_effect = APIConnectionError( request=mocker.Mock() ) @@ -453,50 +455,24 @@ async def test_query_v2_endpoint_with_tool_calls( """ _ = test_config - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-789" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-1" - mock_tool_output.queries = ["What is Ansible"] - mock_tool_output.status = "completed" - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-1" - mock_result.filename = "ansible-docs.txt" - mock_result.score = 0.95 - mock_result.text = "Ansible is an open-source automation tool..." - mock_result.attributes = { - "doc_url": "https://example.com/ansible-docs.txt", - "link": "https://example.com/ansible-docs.txt", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-1", - "filename": "ansible-docs.txt", - "score": 0.95, - "text": "Ansible is an open-source automation tool...", - "attributes": { - "doc_url": "https://example.com/ansible-docs.txt", - "link": "https://example.com/ansible-docs.txt", - }, - } + tool_run_result = create_file_search_agent_run_result( + mocker, + content="Based on the documentation, Ansible is...", + response_id="response-789", + queries=["What is Ansible"], + results=[ + { + "text": "Ansible is an open-source automation tool...", + "score": 0.95, + "attributes": { + "doc_url": "https://example.com/ansible-docs.txt", + "title": "ansible-docs.txt", + "document_id": "doc-1", + }, + } + ], ) - mock_tool_output.results = [mock_result] - - mock_message_output = mocker.MagicMock() - mock_message_output.type = "message" - mock_message_output.role = "assistant" - mock_message_output.content = "Based on the documentation, Ansible is..." - - mock_response.output = [mock_tool_output, mock_message_output] - mock_response.stop_reason = "end_turn" - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 10 - mock_usage.output_tokens = 5 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = tool_run_result query_request = QueryRequest(query="What is Ansible?") @@ -537,38 +513,23 @@ async def test_query_v2_endpoint_with_mcp_list_tools( """ _ = test_config - mock_response = mocker.MagicMock() - mock_response.id = "response-mcplist" - - mock_tool1 = mocker.MagicMock() - mock_tool1.name = "list_pods" - mock_tool1.description = "List Kubernetes pods" - mock_tool1.input_schema = {"type": "object", "properties": {}} - - mock_tool2 = mocker.MagicMock() - mock_tool2.name = "get_deployment" - mock_tool2.description = "Get Kubernetes deployment" - mock_tool2.input_schema = {"type": "object", "properties": {}} - - mock_mcp_list = mocker.MagicMock() - mock_mcp_list.type = "mcp_list_tools" - mock_mcp_list.id = "mcplist-101" - mock_mcp_list.server_label = "kubernetes-server" - mock_mcp_list.tools = [mock_tool1, mock_tool2] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Available tools: list_pods, get_deployment" - - mock_response.output = [mock_mcp_list, mock_message] - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 15 - mock_usage.output_tokens = 20 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + mcp_run_result = create_mcp_list_tools_agent_run_result( + mocker, + content="Available tools: list_pods, get_deployment", + tools=[ + { + "name": "list_pods", + "description": "List Kubernetes pods", + "input_schema": {"type": "object", "properties": {}}, + }, + { + "name": "get_deployment", + "description": "Get Kubernetes deployment", + "input_schema": {"type": "object", "properties": {}}, + }, + ], + ) + mock_llama_stack_client.query_agent.run.return_value = mcp_run_result query_request = QueryRequest(query="What tools are available?") @@ -609,37 +570,9 @@ async def test_query_v2_endpoint_with_multiple_tool_types( """ _ = test_config - mock_response = mocker.MagicMock() - mock_response.id = "response-multi" - - mock_file_search = mocker.MagicMock() - mock_file_search.type = "file_search_call" - mock_file_search.id = "search-1" - mock_file_search.queries = ["Kubernetes deployment"] - mock_file_search.status = "completed" - mock_file_search.results = [] - - mock_function = mocker.MagicMock() - mock_function.type = "function_call" - mock_function.id = "func-2" - mock_function.call_id = "func-2" - mock_function.name = "calculate" - mock_function.arguments = '{"operation": "sum"}' - mock_function.status = "completed" - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Based on documentation and calculations..." - - mock_response.output = [mock_file_search, mock_function, mock_message] - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 40 - mock_usage.output_tokens = 60 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = ( + create_multi_tool_agent_run_result(mocker) + ) query_request = QueryRequest(query="Search docs and calculate deployment replicas") @@ -711,9 +644,9 @@ async def test_query_v2_endpoint_bypasses_tools_when_no_tools_true( assert response.conversation_id is not None assert response.response is not None - # Verify NO tools were passed to Llama Stack (despite vector stores being available) - call_kwargs = mock_llama_stack_client.responses.create.call_args.kwargs - assert call_kwargs.get("tools") is None + # Verify NO tools were passed to the agent (despite vector stores being available) + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + assert responses_params.tools is None @pytest.mark.asyncio @@ -770,11 +703,11 @@ async def test_query_v2_endpoint_uses_tools_when_available( assert response.conversation_id is not None assert response.response is not None - # Verify tools were passed to Llama Stack (real tool preparation logic ran) - call_kwargs = mock_llama_stack_client.responses.create.call_args_list[0].kwargs - assert call_kwargs.get("tools") is not None - assert len(call_kwargs["tools"]) > 0 - assert any(tool.get("type") == "file_search" for tool in call_kwargs["tools"]) + # Verify tools were passed to the agent (real tool preparation logic ran) + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + assert responses_params.tools is not None + assert len(responses_params.tools) > 0 + assert any(tool.type == "file_search" for tool in responses_params.tools) # ========================================== @@ -876,16 +809,14 @@ async def test_query_v2_endpoint_updates_existing_conversation( original_topic = existing_conversation.topic_summary original_count = existing_conversation.message_count - # Create a proper mock response with all required attributes - mock_response = create_mock_llm_response( + set_query_agent_run( + mock_llama_stack_client, mocker, content="", + response_id=EXISTING_CONV_ID, input_tokens=10, output_tokens=5, ) - mock_response.id = EXISTING_CONV_ID - mock_response.output = [] # Override to empty for this test - mock_llama_stack_client.responses.create.return_value = mock_response query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID) @@ -1110,17 +1041,14 @@ async def test_query_v2_endpoint_with_shield_violation( """ _ = test_config - # Configure Llama Stack mock to return response with violation - mock_response = create_mock_llm_response( + set_query_agent_run( + mock_llama_stack_client, mocker, content="I cannot respond to this request", - refusal="Content violates safety policy", + response_id="response-violation", input_tokens=10, output_tokens=5, ) - mock_response.id = "response-violation" - - mock_llama_stack_client.responses.create.return_value = mock_response query_request = QueryRequest(query="Inappropriate query") @@ -1182,13 +1110,10 @@ async def test_query_v2_endpoint_without_shields( assert response.conversation_id is not None assert response.response is not None - # Verify extra_body was not included (or guardrails is empty) - call_kwargs = mock_llama_stack_client.responses.create.call_args.kwargs - if "extra_body" in call_kwargs: - assert ( - "guardrails" not in call_kwargs["extra_body"] - or not call_kwargs["extra_body"]["guardrails"] - ) + # Verify responses params passed to the agent do not include guardrails + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + dumped_params = responses_params.model_dump(exclude_none=True) + assert "guardrails" not in dumped_params @pytest.mark.asyncio @@ -1217,17 +1142,14 @@ async def test_query_v2_endpoint_handles_empty_llm_response( """ _ = test_config - # Create a response with truly empty output array (no assistant messages) - mock_response = create_mock_llm_response( + set_query_agent_run( + mock_llama_stack_client, mocker, content="", + response_id="response-empty", input_tokens=10, output_tokens=0, ) - mock_response.id = "response-empty" - mock_response.output = [] # Override to test truly empty response - - mock_llama_stack_client.responses.create.return_value = mock_response query_request = QueryRequest(query="What is Ansible?") @@ -1276,16 +1198,14 @@ async def test_query_v2_endpoint_quota_integration( _ = test_config _ = patch_db_session - mock_response = create_mock_llm_response( + set_query_agent_run( + mock_llama_stack_client, mocker, content="", + response_id="response-quota", input_tokens=100, output_tokens=50, ) - mock_response.id = "response-quota" - mock_response.output = [] # Override to empty for this test - - mock_llama_stack_client.responses.create.return_value = mock_response mock_consume = mocker.spy(app.endpoints.query, "consume_query_tokens") _ = mocker.spy(app.endpoints.query, "get_available_quotas") @@ -1513,15 +1433,14 @@ async def test_query_v2_endpoint_uses_conversation_history_model( patch_db_session.add(existing_conv) patch_db_session.commit() - mock_response = create_mock_llm_response( + set_query_agent_run( + mock_llama_stack_client, mocker, content="", + response_id=EXISTING_CONV_ID, input_tokens=10, output_tokens=5, ) - mock_response.id = EXISTING_CONV_ID - mock_response.output = [] # Override to empty for this test - mock_llama_stack_client.responses.create.return_value = mock_response query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 5a6b43684..4c67fe64a 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -127,7 +127,7 @@ async def test_successful_query_no_conversation( return_value=mock_client_holder, ) mocker.patch( - "app.endpoints.query.get_topic_summary", + "app.endpoints.query.maybe_get_topic_summary", new=mocker.AsyncMock(return_value=None), ) mocker.patch( @@ -153,11 +153,14 @@ async def test_successful_query_no_conversation( "Kubernetes is a container orchestration platform" ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return mock_turn_summary mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( @@ -245,7 +248,7 @@ async def test_query_merges_inline_and_tool_rag_chunks_and_documents( mock_turn_summary.referenced_documents = [tool_doc] mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=mock_turn_summary), ) mocker.patch("app.endpoints.query.store_query_results") @@ -317,7 +320,7 @@ async def test_successful_query_with_conversation( new=mocker.AsyncMock(return_value=ShieldModerationPassed()), ) mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=TurnSummary()), ) mocker.patch("app.endpoints.query.store_query_results") @@ -373,7 +376,7 @@ async def test_query_with_attachments( return_value=mock_client_holder, ) mocker.patch( - "app.endpoints.query.get_topic_summary", + "app.endpoints.query.maybe_get_topic_summary", new=mocker.AsyncMock(return_value=None), ) mocker.patch( @@ -394,11 +397,14 @@ async def test_query_with_attachments( new=mocker.AsyncMock(return_value=mock_responses_params), ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return TurnSummary() mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( "app.endpoints.query.normalize_conversation_id", return_value="123" @@ -459,11 +465,11 @@ async def test_query_with_topic_summary( ) mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=TurnSummary()), ) - mock_get_topic_summary = mocker.patch( - "app.endpoints.query.get_topic_summary", + mock_maybe_get_topic_summary = mocker.patch( + "app.endpoints.query.maybe_get_topic_summary", new=mocker.AsyncMock(return_value="Topic: Kubernetes"), ) mocker.patch( @@ -480,7 +486,7 @@ async def test_query_with_topic_summary( mcp_headers={}, ) - mock_get_topic_summary.assert_called_once() + mock_maybe_get_topic_summary.assert_called_once() @pytest.mark.asyncio async def test_query_azure_token_refresh( @@ -511,7 +517,7 @@ async def test_query_azure_token_refresh( return_value=mock_client_holder, ) mocker.patch( - "app.endpoints.query.get_topic_summary", + "app.endpoints.query.maybe_get_topic_summary", new=mocker.AsyncMock(return_value=None), ) mocker.patch( @@ -545,11 +551,14 @@ async def test_query_azure_token_refresh( return_value=mock_updated_client ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return TurnSummary() mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( "app.endpoints.query.normalize_conversation_id", return_value="123" diff --git a/tests/unit/utils/agents/test_tool_processor.py b/tests/unit/utils/agents/test_tool_processor.py index 4bc33484a..2fb83a9cf 100644 --- a/tests/unit/utils/agents/test_tool_processor.py +++ b/tests/unit/utils/agents/test_tool_processor.py @@ -430,7 +430,8 @@ def test_list_tools_success(self) -> None: content={ "tools": [ {"name": "tool_a", "description": "does things"}, - ] + ], + "error": None, }, ) @@ -447,7 +448,7 @@ def test_list_tools_error(self) -> None: part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-list-err", - content={"error": "unavailable"}, + content={"tools": [], "error": "unavailable"}, ) result = summarize_mcp_list_tools_result(part, tool_round=1) @@ -460,12 +461,12 @@ def test_mcp_call_success_and_error(self) -> None: success_part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-call-ok", - content={"output": "done"}, + content={"output": "done", "error": None}, ) error_part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-call-err", - content={"error": "failed"}, + content={"output": None, "error": "failed"}, ) success = summarize_mcp_call_result(success_part, tool_round=2) @@ -477,16 +478,16 @@ def test_mcp_call_success_and_error(self) -> None: assert error.content == "failed" def test_mcp_tool_result_dispatches_by_shape(self) -> None: - """Test summarize_mcp_tool_result routes list-tools vs call payloads.""" + """Test summarize_mcp_tool_result routes pydantic-ai MCP return shapes.""" list_part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="dispatch-list", - content={"tools": []}, + content={"tools": [], "error": None}, ) call_part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="dispatch-call", - content={"output": "ok"}, + content={"output": "ok", "error": None}, ) list_result = summarize_mcp_tool_result(list_part, tool_round=1) @@ -495,6 +496,19 @@ def test_mcp_tool_result_dispatches_by_shape(self) -> None: assert list_result.type == "mcp_list_tools" assert call_result.type == "mcp_call" + def test_mcp_call_with_error_field_not_routed_to_list_tools(self) -> None: + """Test MCP call returns are not misrouted when error is always present.""" + call_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="dispatch-call-only-error", + content={"output": "ok", "error": None}, + ) + + result = summarize_mcp_tool_result(call_part, tool_round=1) + + assert result.type == "mcp_call" + assert result.content == "ok" + class TestSummarizeFileSearchResult: """Tests for summarize_file_search_result.""" @@ -571,7 +585,7 @@ def test_records_labeled_mcp_result(self, turn_state: AgentTurnAccumulator) -> N part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-labeled", - content={"output": "labeled-output"}, + content={"output": "labeled-output", "error": None}, ) result = process_native_tool_result(turn_state, part) @@ -591,7 +605,7 @@ def test_records_web_search_and_mcp_results( mcp_part = NativeToolReturnPart( tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-process", - content={"output": "mcp-output"}, + content={"output": "mcp-output", "error": None}, ) web_result = process_native_tool_result(turn_state, web_part)