Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from openai._exceptions import (
APIStatusError as OpenAIAPIStatusError,
)
from typing_extensions import deprecated

from authentication import get_auth_dependency
from authentication.interface import AuthTuple
Expand Down Expand Up @@ -42,6 +43,7 @@
from models.common.responses.types import ResponseInput
from models.common.turn_summary import TurnSummary
from models.config import Action
from utils.agents.query import retrieve_agent_response
from utils.conversation_compaction import (
apply_compaction_blocking,
configured_conversation_cache,
Expand All @@ -68,7 +70,7 @@
build_turn_summary,
deduplicate_referenced_documents,
extract_vector_store_ids_from_tools,
get_topic_summary,
maybe_get_topic_summary,
prepare_responses_params,
)
from utils.shields import run_shield_moderation, validate_shield_ids_override
Expand Down Expand Up @@ -226,12 +228,12 @@ async def query_endpoint_handler(
client = await AsyncLlamaStackClientHolder().update_azure_token()

# Retrieve response using Responses API
turn_summary = await retrieve_response(
turn_summary = await retrieve_agent_response(
client,
responses_params,
moderation_result,
endpoint_path,
original_input=compaction.original_input if compaction.compacted else None,
compaction.original_input if compaction.compacted else None,
)

if moderation_result.decision == "passed":
Expand All @@ -249,13 +251,15 @@ async def query_endpoint_handler(
)

# Get topic summary for new conversation
if not user_conversation and query_request.generate_topic_summary:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary(
query_request.query, client, responses_params.model
)
else:
topic_summary = None
should_generate = not user_conversation and bool(
query_request.generate_topic_summary
)
topic_summary = await maybe_get_topic_summary(
generate_topic_summary=should_generate,
input_text=query_request.query,
client=client,
model_id=responses_params.model,
)

logger.info("Consuming tokens")
consume_query_tokens(
Expand Down Expand Up @@ -301,6 +305,10 @@ async def query_endpoint_handler(
)


@deprecated(
"Deprecated in favor of utils.agents.query.retrieve_agent_response.",
stacklevel=2,
)
async def retrieve_response(
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
Expand Down
5 changes: 4 additions & 1 deletion src/utils/agents/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from enum import Enum
from typing import TypeAlias, cast
from typing import Optional, TypeAlias, cast

from fastapi import HTTPException
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
Expand Down Expand Up @@ -33,6 +33,7 @@
from models.common.agents import AgentTurnAccumulator
from models.common.moderation import ShieldModerationResult
from models.common.responses.responses_api_params import ResponsesApiParams
from models.common.responses.types import ResponseInput
from models.common.turn_summary import TurnSummary
from utils.agents.tool_processor import (
process_function_tool_call,
Expand Down Expand Up @@ -281,6 +282,7 @@ async def retrieve_agent_response(
responses_params: ResponsesApiParams,
moderation_result: ShieldModerationResult,
endpoint_path: str,
_original_input: Optional[ResponseInput] = None,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compacted turn persistence is disconnected and will be part of a separate PR as it will require additional logic.

) -> TurnSummary:
"""Retrieve a turn summary from a blocking agent run.

Expand All @@ -291,6 +293,7 @@ async def retrieve_agent_response(
responses_params: Prepared Responses API parameters.
moderation_result: Shield moderation outcome for the turn.
endpoint_path: Endpoint path used for metric labeling.
_original_input: Original user input before the explicit-input rewrite.

Returns:
Turn summary for the completed agent run.
Expand Down
2 changes: 1 addition & 1 deletion src/utils/agents/tool_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def summarize_mcp_tool_result(
Tool result summary in LCS turn-summary format.
"""
content = cast(dict[str, Any], part.content)
if "tools" in content or "error" in content:
if "tools" in content:
return summarize_mcp_list_tools_result(part, tool_round)
return summarize_mcp_call_result(part, tool_round)

Expand Down
243 changes: 235 additions & 8 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from llama_stack_api.openai_responses import OpenAIResponseObject
from llama_stack_client.types import VersionInfo
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
NativeToolCallPart,
NativeToolReturnPart,
TextPart,
ToolCallPart,
ToolReturnPart,
)
from pydantic_ai.native_tools import FileSearchTool, MCPServerTool
from pydantic_ai.usage import RunUsage
from pytest_mock import AsyncMockType, MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
Expand Down Expand Up @@ -70,9 +83,6 @@ def create_mock_llm_response( # pylint: disable=too-many-arguments,too-many-pos
Returns:
Mock LLM response object with the specified configuration.
"""
# pylint: disable=import-outside-toplevel
from llama_stack_api.openai_responses import OpenAIResponseObject

mock_response = mocker.MagicMock(spec=OpenAIResponseObject)
mock_response.id = "response-123"

Expand Down Expand Up @@ -154,6 +164,223 @@ def create_mock_tool_call(
return mock_tool_call


def create_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments
mocker: MockerFixture,
*,
content: str = "This is a test response about Ansible.",
response_id: str = "response-123",
input_tokens: int = 10,
output_tokens: int = 5,
model_response: Any = None,
new_messages: Optional[list[Any]] = None,
) -> Any:
"""Create a mock AgentRunResult wired for retrieve_agent_response.

Uses real pydantic-ai message types so build_turn_summary_from_agent_run
exercises the same path as production agent runs.

Args:
mocker: pytest-mock fixture.
content: Assistant text content for the run.
response_id: Provider response identifier.
input_tokens: Input token count for the run.
output_tokens: Output token count for the run.
model_response: Optional pre-built ModelResponse.
new_messages: Optional message sequence returned by new_messages().

Returns:
Mock AgentRunResult compatible with build_turn_summary_from_agent_run.
"""
if model_response is None:
parts = [TextPart(content)] if content else []
model_response = ModelResponse(
parts=parts,
finish_reason="stop",
provider_response_id=response_id,
)

messages = new_messages if new_messages is not None else [model_response]
run_result = mocker.MagicMock()
run_result.response = model_response
run_result.usage = RunUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
requests=1,
)
run_result.new_messages.return_value = messages
return run_result


def create_file_search_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments
mocker: MockerFixture,
*,
content: str,
response_id: str = "response-tool-rag",
queries: Optional[list[str]] = None,
results: Optional[list[dict[str, Any]]] = None,
input_tokens: int = 10,
output_tokens: int = 5,
) -> Any:
"""Create an AgentRunResult containing a native file_search tool call."""
call = NativeToolCallPart(
tool_name=FileSearchTool.kind,
args={"queries": queries or ["test query"]},
tool_call_id="call-fs-1",
)
return_part = NativeToolReturnPart(
tool_name=FileSearchTool.kind,
tool_call_id="call-fs-1",
content={
"status": "success",
"results": results or [],
},
)
model_response = ModelResponse(
parts=[call, return_part, TextPart(content)],
finish_reason="stop",
provider_response_id=response_id,
)
return create_agent_run_result(
mocker,
content=content,
response_id=response_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
model_response=model_response,
)


def create_mcp_list_tools_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments
mocker: MockerFixture,
*,
content: str,
response_id: str = "response-mcplist",
server_label: str = "kubernetes-server",
tools: Optional[list[dict[str, Any]]] = None,
input_tokens: int = 15,
output_tokens: int = 20,
) -> Any:
"""Create an AgentRunResult containing an MCP list-tools native tool call."""
call = NativeToolCallPart(
tool_name=f"{MCPServerTool.kind}:{server_label}",
args={"action": "list_tools"},
tool_call_id="mcplist-101",
)
return_part = NativeToolReturnPart(
tool_name=f"{MCPServerTool.kind}:{server_label}",
tool_call_id="mcplist-101",
content={"tools": tools or []},
)
model_response = ModelResponse(
parts=[call, return_part, TextPart(content)],
finish_reason="stop",
provider_response_id=response_id,
)
return create_agent_run_result(
mocker,
content=content,
response_id=response_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
model_response=model_response,
)


def create_multi_tool_agent_run_result(
mocker: MockerFixture,
*,
content: str = "Based on documentation and calculations...",
response_id: str = "response-multi",
input_tokens: int = 40,
output_tokens: int = 60,
) -> Any:
"""Create an AgentRunResult with file_search and function tool calls."""
file_search_call = NativeToolCallPart(
tool_name=FileSearchTool.kind,
args={"queries": ["Kubernetes deployment"]},
tool_call_id="search-1",
)
file_search_return = NativeToolReturnPart(
tool_name=FileSearchTool.kind,
tool_call_id="search-1",
content={"status": "success", "results": []},
)
function_call = ToolCallPart(
tool_name="calculate",
args={"operation": "sum"},
tool_call_id="func-2",
)
function_return = ToolReturnPart(
tool_name="calculate",
content={"result": 2},
tool_call_id="func-2",
)
model_response = ModelResponse(
parts=[
file_search_call,
file_search_return,
function_call,
TextPart(content),
],
finish_reason="stop",
provider_response_id=response_id,
)
return create_agent_run_result(
mocker,
content=content,
response_id=response_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
model_response=model_response,
new_messages=[model_response, ModelRequest(parts=[function_return])],
)


def set_query_agent_run(
mock_llama_stack_client: AsyncMockType,
mocker: MockerFixture,
**kwargs: Any,
) -> None:
"""Configure mock agent.run return value for /query integration tests."""
mock_llama_stack_client.query_agent.run.return_value = create_agent_run_result(
mocker,
**kwargs,
)


def configure_query_agent_mock(
mocker: MockerFixture,
*,
run_result: Any = None,
run_side_effect: Any = None,
) -> Any:
"""Patch build_agent for /query integration tests and return the mock agent.

Args:
mocker: pytest-mock fixture.
run_result: AgentRunResult returned by agent.run().
run_side_effect: Optional exception side effect for agent.run().

Returns:
Mock agent exposing AsyncMock run().
"""
if run_result is None:
run_result = create_agent_run_result(mocker)

mock_agent = mocker.AsyncMock()
if run_side_effect is not None:
mock_agent.run = mocker.AsyncMock(side_effect=run_side_effect)
else:
mock_agent.run = mocker.AsyncMock(return_value=run_result)

build_agent_mock = mocker.patch(
"utils.agents.query.build_agent",
return_value=mock_agent,
)
mock_agent.build_agent_mock = build_agent_mock
return mock_agent


# ==========================================
# Fixtures
# ==========================================
Expand Down Expand Up @@ -448,10 +675,6 @@ def mock_llama_stack_client_fixture(
Yields:
mock_client: The mocked Llama Stack client instance.
"""
# pylint: disable=import-outside-toplevel
from llama_stack_api.openai_responses import OpenAIResponseObject
from llama_stack_client.types import VersionInfo

# Patch AsyncLlamaStackClientHolder at multiple import locations
# This ensures the mock is active both during app startup (app.main)
# and during endpoint execution (query, conversations_v1, responses, etc.)
Expand Down Expand Up @@ -484,6 +707,10 @@ def mock_llama_stack_client_fixture(

mock_client.responses.create.return_value = mock_response

mock_agent = configure_query_agent_mock(mocker)
mock_client.query_agent = mock_agent
mock_client.build_agent_mock = mock_agent.build_agent_mock

# Mock models.list
mock_model = mocker.MagicMock()
mock_model.id = "test-provider/test-model"
Expand Down
Loading
Loading