Skip to content

Commit 4ec5a1e

Browse files
committed
Fix typecheck errors and add coverage tests for file search
- Fix typecheck errors in test_google.py by properly checking event types
- Remove unnecessary comments
- Add tests to cover missing branches in openai.py:
  - Test openai_include_file_search_results setting (line 1334)
  - Test file_search with results (line 2503)
  - Test round-trip status update (line 1697)
1 parent 0d658dc commit 4ec5a1e

File tree

2 files changed

+110
-14
lines changed

2 files changed

+110
-14
lines changed

tests/models/test_google.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4115,7 +4115,6 @@ async def test_google_model_file_search_tool_stream(allow_model_requests: None,
41154115

41164116
async def test_google_model_file_search_empty_grounding_chunks():
41174117
"""Test that file search handles empty grounding_chunks gracefully."""
4118-
# Create a response with grounding_metadata but empty grounding_chunks
41194118
response = GenerateContentResponse.model_validate(
41204119
{
41214120
'response_id': 'test-123',
@@ -4132,7 +4131,7 @@ async def test_google_model_file_search_empty_grounding_chunks():
41324131
'parts': [{'text': 'Some text'}],
41334132
},
41344133
'grounding_metadata': {
4135-
'grounding_chunks': [], # Empty grounding_chunks
4134+
'grounding_chunks': [],
41364135
},
41374136
}
41384137
],
@@ -4150,18 +4149,17 @@ async def response_iterator() -> AsyncIterator[GenerateContentResponse]:
41504149
_provider_name='test-provider',
41514150
_provider_url='',
41524151
)
4153-
# Set _file_search_tool_call_id to trigger the code path
41544152
streamed_response._file_search_tool_call_id = 'test-tool-call-id' # pyright: ignore[reportPrivateUsage]
41554153

41564154
events = [event async for event in streamed_response._get_event_iterator()] # pyright: ignore[reportPrivateUsage]
4157-
# Should not crash and should emit text event
41584155
assert len(events) > 0
4159-
assert any(isinstance(event.part, TextPart) if hasattr(event, 'part') else False for event in events)
4156+
assert any(
4157+
isinstance(event, PartStartEvent | PartEndEvent) and isinstance(event.part, TextPart) for event in events
4158+
)
41604159

41614160

41624161
async def test_google_model_file_search_empty_retrieved_contexts():
41634162
"""Test that file search handles empty retrieved_contexts gracefully."""
4164-
# Create a response with grounding_chunks but no retrieved_context
41654163
response = GenerateContentResponse.model_validate(
41664164
{
41674165
'response_id': 'test-123',
@@ -4178,11 +4176,7 @@ async def test_google_model_file_search_empty_retrieved_contexts():
41784176
'parts': [{'text': 'Some text'}],
41794177
},
41804178
'grounding_metadata': {
4181-
'grounding_chunks': [
4182-
{
4183-
# Chunk without retrieved_context
4184-
}
4185-
],
4179+
'grounding_chunks': [{}],
41864180
},
41874181
}
41884182
],
@@ -4200,13 +4194,13 @@ async def response_iterator() -> AsyncIterator[GenerateContentResponse]:
42004194
_provider_name='test-provider',
42014195
_provider_url='',
42024196
)
4203-
# Set _file_search_tool_call_id to trigger the code path
42044197
streamed_response._file_search_tool_call_id = 'test-tool-call-id' # pyright: ignore[reportPrivateUsage]
42054198

42064199
events = [event async for event in streamed_response._get_event_iterator()] # pyright: ignore[reportPrivateUsage]
4207-
# Should not crash and should emit text event
42084200
assert len(events) > 0
4209-
assert any(isinstance(event.part, TextPart) if hasattr(event, 'part') else False for event in events)
4201+
assert any(
4202+
isinstance(event, PartStartEvent | PartEndEvent) and isinstance(event.part, TextPart) for event in events
4203+
)
42104204

42114205

42124206
async def test_cache_point_filtering():

tests/models/test_openai_responses.py

Lines changed: 102 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7759,3 +7759,105 @@ async def test_openai_responses_model_file_search_tool_stream(allow_model_reques
77597759
if vector_store is not None:
77607760
await async_client.vector_stores.delete(vector_store.id)
77617761
await async_client.close()
7762+
7763+
7764+
async def test_openai_responses_file_search_include_results_setting():
7765+
from pydantic_ai.providers.openai import OpenAIProvider
7766+
7767+
from .mock_openai import MockOpenAIResponses
7768+
7769+
mock_client = MockOpenAIResponses(response=response_message([]))
7770+
model = OpenAIResponsesModel(
7771+
'gpt-4o',
7772+
provider=OpenAIProvider(openai_client=cast(Any, mock_client)),
7773+
model_settings=OpenAIResponsesModelSettings(openai_include_file_search_results=True),
7774+
)
7775+
7776+
async with model.request_stream(
7777+
messages=[],
7778+
model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}),
7779+
model_request_parameters=ModelRequestParameters(),
7780+
):
7781+
pass
7782+
7783+
kwargs_list = get_mock_responses_kwargs(cast(Any, mock_client))
7784+
assert len(kwargs_list) > 0
7785+
kwargs = kwargs_list[0]
7786+
assert 'include' in kwargs
7787+
assert 'file_search_call.results' in kwargs['include']
7788+
7789+
7790+
async def test_openai_responses_file_search_with_results():
7791+
with try_import():
7792+
from openai.types.responses.response_file_search_tool_call import (
7793+
ResponseFileSearchToolCall,
7794+
ResponseFileSearchToolCallResult,
7795+
)
7796+
7797+
from pydantic_ai.models.openai import _map_file_search_tool_call
7798+
7799+
result_obj = ResponseFileSearchToolCallResult(
7800+
document_name='test.txt',
7801+
text='Test content',
7802+
score=0.9,
7803+
)
7804+
file_search_item = ResponseFileSearchToolCall(
7805+
id='fs_test',
7806+
type='file_search_call',
7807+
status='completed',
7808+
queries=['test query'],
7809+
results=[result_obj],
7810+
)
7811+
7812+
call_part, return_part = _map_file_search_tool_call(file_search_item, 'openai')
7813+
assert call_part.tool_name == 'file_search'
7814+
assert return_part.content['status'] == 'completed'
7815+
assert 'results' in return_part.content
7816+
assert len(return_part.content['results']) == 1
7817+
7818+
7819+
async def test_openai_responses_file_search_round_trip_status_update():
7820+
from pydantic_ai.builtin_tools import FileSearchTool
7821+
from pydantic_ai.providers.openai import OpenAIProvider
7822+
7823+
model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key='test'))
7824+
messages = [
7825+
ModelRequest(
7826+
parts=[UserPromptPart(content='test', timestamp=IsDatetime())],
7827+
run_id=IsStr(),
7828+
),
7829+
ModelResponse(
7830+
parts=[
7831+
BuiltinToolCallPart(
7832+
tool_name=FileSearchTool.kind,
7833+
args={'queries': ['test']},
7834+
tool_call_id='fs_123',
7835+
provider_name='openai',
7836+
),
7837+
BuiltinToolReturnPart(
7838+
tool_name=FileSearchTool.kind,
7839+
content={'status': 'in_progress'},
7840+
tool_call_id='fs_123',
7841+
timestamp=IsDatetime(),
7842+
provider_name='openai',
7843+
),
7844+
],
7845+
usage=RequestUsage(input_tokens=10, output_tokens=5),
7846+
model_name='gpt-4o',
7847+
timestamp=IsDatetime(),
7848+
provider_name='openai',
7849+
finish_reason='stop',
7850+
run_id=IsStr(),
7851+
),
7852+
]
7853+
7854+
_, openai_messages = await model._map_messages( # pyright: ignore[reportPrivateUsage]
7855+
messages,
7856+
model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}),
7857+
model_request_parameters=ModelRequestParameters(),
7858+
)
7859+
7860+
file_search_msgs = [msg for msg in openai_messages if isinstance(msg, dict) and msg.get('type') == 'file_search_call']
7861+
assert len(file_search_msgs) > 0
7862+
file_search_msg = file_search_msgs[0]
7863+
assert file_search_msg.get('status') == 'in_progress'

0 commit comments

Comments (0)