diff --git a/google/genai/tests/interactions/__init__.py b/google/genai/tests/interactions/__init__.py index e69de29bb..77e065fee 100644 --- a/google/genai/tests/interactions/__init__.py +++ b/google/genai/tests/interactions/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/google/genai/tests/interactions/api_resources/__init__.py b/google/genai/tests/interactions/api_resources/__init__.py new file mode 100644 index 000000000..77e065fee --- /dev/null +++ b/google/genai/tests/interactions/api_resources/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/google/genai/tests/interactions/api_resources/test_interactions.py b/google/genai/tests/interactions/api_resources/test_interactions.py new file mode 100644 index 000000000..5342a3a88 --- /dev/null +++ b/google/genai/tests/interactions/api_resources/test_interactions.py @@ -0,0 +1,1485 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from google.genai._interactions import GeminiNextGenAPIClient, AsyncGeminiNextGenAPIClient +from google.genai._interactions.types import ( + Interaction, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestInteractions: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_overload_1(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_with_all_params_overload_1(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + model="gemini-2.5-computer-use-preview-10-2025", + background=True, + generation_config={ + "image_config": { + "aspect_ratio": "1:1", + "image_size": "1K", + }, + "max_output_tokens": 0, + "seed": 0, + "speech_config": [ + { + "language": "language", + "speaker": "speaker", + "voice": "voice", + } + ], + "stop_sequences": ["string"], + "temperature": 0, + "thinking_level": "minimal", + "thinking_summaries": "auto", + "tool_choice": "auto", + "top_p": 0, + }, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + stream=False, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_create_overload_1(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + assert response.is_closed is True + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_create_overload_1(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) as response: + assert not response.is_closed + + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_create_overload_1(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.create( + api_version="", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_overload_2(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_with_all_params_overload_2(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + background=True, + generation_config={ + "image_config": { + "aspect_ratio": "1:1", + "image_size": "1K", + }, + "max_output_tokens": 0, + "seed": 0, + "speech_config": [ + { + "language": "language", + "speaker": "speaker", + "voice": "voice", + } + ], + "stop_sequences": ["string"], + "temperature": 0, + "thinking_level": "minimal", + "thinking_summaries": "auto", + "tool_choice": "auto", + "top_p": 0, + }, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_create_overload_2(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + + stream = response.parse() + stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_create_overload_2(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) as response: + assert not response.is_closed + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_create_overload_2(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.create( + api_version="", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_overload_3(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_with_all_params_overload_3(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + agent_config={"type": "dynamic"}, + background=True, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + stream=False, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_create_overload_3(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + + assert response.is_closed is True + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_create_overload_3(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) as response: + assert not response.is_closed + + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_create_overload_3(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.create( + api_version="", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_overload_4(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_with_all_params_overload_4(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + stream=True, + agent_config={"type": "dynamic"}, + background=True, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_create_overload_4(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + + stream = response.parse() + stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_create_overload_4(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) as response: + assert not response.is_closed + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_create_overload_4(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.create( + api_version="", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_delete(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.delete( + id="id", + api_version="api_version", + ) + assert_matches_type(object, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_delete(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.delete( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = response.parse() + assert_matches_type(object, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_delete(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.delete( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = response.parse() + assert_matches_type(object, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_delete(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.delete( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.interactions.with_raw_response.delete( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_cancel(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.cancel( + id="id", + api_version="api_version", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_cancel(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.cancel( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_cancel(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.cancel( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_cancel(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.cancel( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.interactions.with_raw_response.cancel( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_get_overload_1(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.get( + id="id", + api_version="api_version", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_get_with_all_params_overload_1(self, client: GeminiNextGenAPIClient) -> None: + interaction = client.interactions.get( + id="id", + api_version="api_version", + include_input=True, + last_event_id="last_event_id", + stream=False, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_get_overload_1(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.get( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_get_overload_1(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.get( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_get_overload_1(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.get( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.interactions.with_raw_response.get( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_get_overload_2(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.get( + id="id", + api_version="api_version", + stream=True, + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_get_with_all_params_overload_2(self, client: GeminiNextGenAPIClient) -> None: + interaction_stream = client.interactions.get( + id="id", + api_version="api_version", + stream=True, + include_input=True, + last_event_id="last_event_id", + ) + interaction_stream.response.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_get_overload_2(self, client: GeminiNextGenAPIClient) -> None: + response = client.interactions.with_raw_response.get( + id="id", + api_version="api_version", + stream=True, + ) + + stream = response.parse() + stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_get_overload_2(self, client: GeminiNextGenAPIClient) -> None: + with client.interactions.with_streaming_response.get( + id="id", + api_version="api_version", + stream=True, + ) as response: + assert not response.is_closed + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_get_overload_2(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.interactions.with_raw_response.get( + id="id", + api_version="", + stream=True, + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.interactions.with_raw_response.get( + id="", + api_version="api_version", + stream=True, + ) + + +class TestAsyncInteractions: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_with_all_params_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + model="gemini-2.5-computer-use-preview-10-2025", + background=True, + generation_config={ + "image_config": { + "aspect_ratio": "1:1", + "image_size": "1K", + }, + "max_output_tokens": 0, + "seed": 0, + "speech_config": [ + { + "language": "language", + "speaker": "speaker", + "voice": "voice", + } + ], + "stop_sequences": ["string"], + "temperature": 0, + "thinking_level": "minimal", + "thinking_summaries": "auto", + "tool_choice": "auto", + "top_p": 0, + }, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + stream=False, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_create_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + assert response.is_closed is True + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_create_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) as response: + assert not response.is_closed + + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_create_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.create( + api_version="", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_with_all_params_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + background=True, + generation_config={ + "image_config": { + "aspect_ratio": "1:1", + "image_size": "1K", + }, + "max_output_tokens": 0, + "seed": 0, + "speech_config": [ + { + "language": "language", + "speaker": "speaker", + "voice": "voice", + } + ], + "stop_sequences": ["string"], + "temperature": 0, + "thinking_level": "minimal", + "thinking_summaries": "auto", + "tool_choice": "auto", + "top_p": 0, + }, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_create_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + + stream = await response.parse() + await stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_create_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) as response: + assert not response.is_closed + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_create_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.create( + api_version="", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + stream=True, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_overload_3(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_with_all_params_overload_3(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + agent_config={"type": "dynamic"}, + background=True, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + stream=False, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_create_overload_3(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + + assert response.is_closed is True + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_create_overload_3(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) as response: + assert not response.is_closed + + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_create_overload_3(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.create( + api_version="", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_overload_4(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_with_all_params_overload_4(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + "annotations": [ + { + "type": "url_citation", + "end_index": 0, + "start_index": 0, + "title": "title", + "url": "url", + } + ], + }, + stream=True, + agent_config={"type": "dynamic"}, + background=True, + previous_interaction_id="previous_interaction_id", + response_format=[ + { + "type": "audio", + "bit_rate": 0, + "delivery": "inline", + "mime_type": "audio/mp3", + "sample_rate": 0, + } + ], + response_mime_type="response_mime_type", + response_modalities=["text"], + service_tier="flex", + store=True, + system_instruction="system_instruction", + tools=[ + { + "type": "function", + "description": "description", + "name": "name", + "parameters": {}, + } + ], + webhook_config={ + "uris": ["string"], + "user_metadata": {"foo": "bar"}, + }, + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_create_overload_4(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + + stream = await response.parse() + await stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_create_overload_4(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.create( + api_version="api_version", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) as response: + assert not response.is_closed + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_create_overload_4(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.create( + api_version="", + agent="deep-research-pro-preview-12-2025", + input={ + "text": "text", + "type": "text", + }, + stream=True, + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.delete( + id="id", + api_version="api_version", + ) + assert_matches_type(object, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.delete( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = await response.parse() + assert_matches_type(object, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.delete( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = await response.parse() + assert_matches_type(object, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.delete( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.interactions.with_raw_response.delete( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_cancel(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.cancel( + id="id", + api_version="api_version", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_cancel(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.cancel( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_cancel(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.cancel( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_cancel(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.cancel( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.interactions.with_raw_response.cancel( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_get_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.get( + id="id", + api_version="api_version", + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_get_with_all_params_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction = await async_client.interactions.get( + id="id", + api_version="api_version", + include_input=True, + last_event_id="last_event_id", + stream=False, + ) + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_get_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.get( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_get_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.get( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + interaction = await response.parse() + assert_matches_type(Interaction, interaction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_get_overload_1(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.get( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.interactions.with_raw_response.get( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_get_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.get( + id="id", + api_version="api_version", + stream=True, + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_get_with_all_params_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + interaction_stream = await async_client.interactions.get( + id="id", + api_version="api_version", + stream=True, + include_input=True, + last_event_id="last_event_id", + ) + await interaction_stream.response.aclose() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_get_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.interactions.with_raw_response.get( + id="id", + api_version="api_version", + stream=True, + ) + + stream = await response.parse() + await stream.close() + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_get_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.interactions.with_streaming_response.get( + id="id", + api_version="api_version", + stream=True, + ) as response: + assert not response.is_closed + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_get_overload_2(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.interactions.with_raw_response.get( + id="id", + api_version="", + stream=True, + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.interactions.with_raw_response.get( + id="", + api_version="api_version", + stream=True, + ) diff --git a/google/genai/tests/interactions/api_resources/test_webhooks.py b/google/genai/tests/interactions/api_resources/test_webhooks.py new file mode 100644 index 000000000..3b6552994 --- /dev/null +++ b/google/genai/tests/interactions/api_resources/test_webhooks.py @@ -0,0 +1,833 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from google.genai._interactions import GeminiNextGenAPIClient, AsyncGeminiNextGenAPIClient +from google.genai._interactions.types import ( + Webhook, + WebhookListResponse, + WebhookPingResponse, + WebhookDeleteResponse, + WebhookRotateSigningSecretResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestWebhooks: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_create_with_all_params(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + name="name", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_create(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_create(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_create(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.create( + api_version="", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_update(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.update( + id="id", + api_version="api_version", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_update_with_all_params(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.update( + id="id", + api_version="api_version", + update_mask="update_mask", + name="name", + state="enabled", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_update(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.update( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_update(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.update( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_update(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.update( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.webhooks.with_raw_response.update( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_list(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.list( + api_version="api_version", + ) + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_list_with_all_params(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.list( + api_version="api_version", + page_size=0, + page_token="page_token", + ) + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_list(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.list( + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_list(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.list( + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_list(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.list( + api_version="", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_delete(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.delete( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_delete(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.delete( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_delete(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.delete( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_delete(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.delete( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.webhooks.with_raw_response.delete( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_get(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.get( + id="id", + api_version="api_version", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_get(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.get( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_get(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.get( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_get(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.get( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.webhooks.with_raw_response.get( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_ping(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.ping( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_ping_with_all_params(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.ping( + id="id", + api_version="api_version", + body={}, + ) + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_ping(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.ping( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_ping(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.ping( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_ping(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.ping( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.webhooks.with_raw_response.ping( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_rotate_signing_secret(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.rotate_signing_secret( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_method_rotate_signing_secret_with_all_params(self, client: GeminiNextGenAPIClient) -> None: + webhook = client.webhooks.rotate_signing_secret( + id="id", + api_version="api_version", + revocation_behavior="revoke_previous_secrets_after_h24", + ) + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_raw_response_rotate_signing_secret(self, client: GeminiNextGenAPIClient) -> None: + response = client.webhooks.with_raw_response.rotate_signing_secret( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = response.parse() + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_streaming_response_rotate_signing_secret(self, client: GeminiNextGenAPIClient) -> None: + with client.webhooks.with_streaming_response.rotate_signing_secret( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = response.parse() + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + def test_path_params_rotate_signing_secret(self, client: GeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + client.webhooks.with_raw_response.rotate_signing_secret( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.webhooks.with_raw_response.rotate_signing_secret( + id="", + api_version="api_version", + ) + + +class TestAsyncWebhooks: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + name="name", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_create(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_create(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.create( + api_version="api_version", + subscribed_events=["batch.succeeded"], + uri="uri", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_create(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.create( + api_version="", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_update(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.update( + id="id", + api_version="api_version", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_update_with_all_params(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.update( + id="id", + api_version="api_version", + update_mask="update_mask", + name="name", + state="enabled", + subscribed_events=["batch.succeeded"], + uri="uri", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_update(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.update( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_update(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.update( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_update(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.update( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.webhooks.with_raw_response.update( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_list(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.list( + api_version="api_version", + ) + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.list( + api_version="api_version", + page_size=0, + page_token="page_token", + ) + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_list(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.list( + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_list(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.list( + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(WebhookListResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_list(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.list( + api_version="", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.delete( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.delete( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.delete( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(WebhookDeleteResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_delete(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.delete( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.webhooks.with_raw_response.delete( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_get(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.get( + id="id", + api_version="api_version", + ) + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_get(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.get( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_get(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.get( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(Webhook, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_get(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.get( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.webhooks.with_raw_response.get( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_ping(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.ping( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_ping_with_all_params(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.ping( + id="id", + api_version="api_version", + body={}, + ) + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_ping(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.ping( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_ping(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.ping( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(WebhookPingResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_ping(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.ping( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.webhooks.with_raw_response.ping( + id="", + api_version="api_version", + ) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_rotate_signing_secret(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + webhook = await async_client.webhooks.rotate_signing_secret( + id="id", + api_version="api_version", + ) + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_method_rotate_signing_secret_with_all_params( + self, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + webhook = await async_client.webhooks.rotate_signing_secret( + id="id", + api_version="api_version", + revocation_behavior="revoke_previous_secrets_after_h24", + ) + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_raw_response_rotate_signing_secret(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + response = await async_client.webhooks.with_raw_response.rotate_signing_secret( + id="id", + api_version="api_version", + ) + + assert response.is_closed is True + webhook = await response.parse() + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_streaming_response_rotate_signing_secret(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + async with async_client.webhooks.with_streaming_response.rotate_signing_secret( + id="id", + api_version="api_version", + ) as response: + assert not response.is_closed + + webhook = await response.parse() + assert_matches_type(WebhookRotateSigningSecretResponse, webhook, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="Mock server tests are disabled") + @parametrize + async def test_path_params_rotate_signing_secret(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `api_version` but received ''"): + await async_client.webhooks.with_raw_response.rotate_signing_secret( + id="id", + api_version="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.webhooks.with_raw_response.rotate_signing_secret( + id="", + api_version="api_version", + ) diff --git a/google/genai/tests/interactions/conftest.py b/google/genai/tests/interactions/conftest.py new file mode 100644 index 000000000..ab4cda288 --- /dev/null +++ b/google/genai/tests/interactions/conftest.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +import logging +from typing import TYPE_CHECKING, Iterator, AsyncIterator + +import httpx +import pytest +from pytest_asyncio import is_async_test + +from google.genai._interactions import DefaultAioHttpClient, GeminiNextGenAPIClient, AsyncGeminiNextGenAPIClient +from google.genai._interactions._utils import is_dict + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] + +pytest.register_assert_rewrite("tests.utils") + +logging.getLogger("google.genai._interactions").setLevel(logging.DEBUG) + + +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) + + # We skip tests that use both the aiohttp client and respx_mock as respx_mock + # doesn't support custom transports. + for item in items: + if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames: + continue + + if not hasattr(item, "callspec"): + continue + + async_client_param = item.callspec.params.get("async_client") + if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp": + item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock")) + + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + +api_key = "My API Key" + + +@pytest.fixture(scope="session") +def client(request: FixtureRequest) -> Iterator[GeminiNextGenAPIClient]: + strict = getattr(request, "param", True) + if not isinstance(strict, bool): + raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") + + with GeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: + yield client + + +@pytest.fixture(scope="session") +async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncGeminiNextGenAPIClient]: + param = getattr(request, "param", True) + + # defaults + strict = True + http_client: None | httpx.AsyncClient = None + + if isinstance(param, bool): + strict = param + elif is_dict(param): + strict = param.get("strict", True) + assert isinstance(strict, bool) + + http_client_type = param.get("http_client", "httpx") + if http_client_type == "aiohttp": + http_client = DefaultAioHttpClient() + else: + raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict") + + async with AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client + ) as client: + yield client diff --git a/google/genai/tests/interactions/sample_file.txt b/google/genai/tests/interactions/sample_file.txt new file mode 100644 index 000000000..af5626b4a --- /dev/null +++ b/google/genai/tests/interactions/sample_file.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/google/genai/tests/interactions/test_client.py b/google/genai/tests/interactions/test_client.py new file mode 100644 index 000000000..c5825c6bf --- /dev/null +++ b/google/genai/tests/interactions/test_client.py @@ -0,0 +1,1998 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import gc +import os +import sys +import json +import asyncio +import inspect +import dataclasses +import tracemalloc +from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast +from unittest import mock +from typing_extensions import Literal, AsyncIterator, override + +import httpx +import pytest +from respx import MockRouter +from pydantic import ValidationError + +from google.genai._interactions import GeminiNextGenAPIClient, APIResponseValidationError, AsyncGeminiNextGenAPIClient +from google.genai._interactions._types import Omit +from google.genai._interactions._utils import asyncify +from google.genai._interactions._models import BaseModel, FinalRequestOptions +from google.genai._interactions._streaming import Stream, AsyncStream +from google.genai._interactions._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError +from google.genai._interactions._base_client import ( + DEFAULT_TIMEOUT, + HTTPX_DEFAULT_TIMEOUT, + BaseClient, + OtherPlatform, + DefaultHttpxClient, + DefaultAsyncHttpxClient, + get_platform, + make_request_options, +) + +from .utils import update_env + +T = TypeVar("T") +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +api_key = "My API Key" + + +def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + return dict(url.params) + + +def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float: + return 0.1 + + +def mirror_request_content(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=request.content) + + +# note: we can't use the httpx.MockTransport class as it consumes the request +# body itself, which means we can't test that the body is read lazily +class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport): + def __init__( + self, + handler: Callable[[httpx.Request], httpx.Response] + | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]], + ) -> None: + self.handler = handler + + @override + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert not inspect.iscoroutinefunction(self.handler), "handler must not be a coroutine function" + assert inspect.isfunction(self.handler), "handler must be a function" + return self.handler(request) + + @override + async def handle_async_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert inspect.iscoroutinefunction(self.handler), "handler must be a coroutine function" + return await self.handler(request) + + +@dataclasses.dataclass +class Counter: + value: int = 0 + + +def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + +async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + +def _get_open_connections(client: GeminiNextGenAPIClient | AsyncGeminiNextGenAPIClient) -> int: + transport = client._client._transport + assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) + + pool = transport._pool + return len(pool._requests) + + +class TestGeminiNextGenAPIClient: + @pytest.mark.respx(base_url=base_url) + def test_raw_response(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + @pytest.mark.respx(base_url=base_url) + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) + + response = client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + def test_copy(self, client: GeminiNextGenAPIClient) -> None: + copied = client.copy() + assert id(copied) != id(client) + + copied = client.copy(api_key="another My API Key") + assert copied.api_key == "another My API Key" + assert client.api_key == "My API Key" + + def test_copy_default_options(self, client: GeminiNextGenAPIClient) -> None: + # options that have a default are overridden correctly + copied = client.copy(max_retries=7) + assert copied.max_retries == 7 + assert client.max_retries == 2 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(client.timeout, httpx.Timeout) + + def test_copy_default_headers(self) -> None: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() + + def test_copy_default_query(self) -> None: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} + ) + assert _get_params(client)["foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert _get_params(copied)["foo"] == "bar" + + # merges already given params + copied = client.copy(default_query={"bar": "stainless"}) + params = _get_params(copied) + assert params["foo"] == "bar" + assert params["bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_query={"foo": "stainless"}) + assert _get_params(copied)["foo"] == "stainless" + + # set_default_query + + # completely overrides already set values + copied = client.copy(set_default_query={}) + assert _get_params(copied) == {} + + copied = client.copy(set_default_query={"bar": "Robert"}) + assert _get_params(copied)["bar"] == "Robert" + + with pytest.raises( + ValueError, + # TODO: update + match="`default_query` and `set_default_query` arguments are mutually exclusive", + ): + client.copy(set_default_query={}, default_query={"foo": "Bar"}) + + client.close() + + def test_copy_signature(self, client: GeminiNextGenAPIClient) -> None: + # ensure the same parameters that can be passed to the client are defined in the `.copy()` method + init_signature = inspect.signature( + # mypy doesn't like that we access the `__init__` property. + client.__init__, # type: ignore[misc] + ) + copy_signature = inspect.signature(client.copy) + exclude_params = {"transport", "proxies", "_strict_response_validation"} + + for name in init_signature.parameters.keys(): + if name in exclude_params: + continue + + copy_param = copy_signature.parameters.get(name) + assert copy_param is not None, f"copy() signature is missing the {name} param" + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, client: GeminiNextGenAPIClient) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client_copy = client.copy() + client_copy._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "google.genai._interactions/_legacy_response.py", + "google.genai._interactions/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "google.genai._interactions/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + + def test_request_timeout(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(100.0) + + def test_client_timeout_option(self) -> None: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0) + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(0) + + client.close() + + def test_http_client_timeout_option(self) -> None: + # custom timeout given to the httpx client should be used + with httpx.Client(timeout=None) as http_client: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(None) + + client.close() + + # no timeout given to the httpx client should not use the httpx default + with httpx.Client() as http_client: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + client.close() + + # explicitly passing the default timeout currently results in it being ignored + with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT # our default + + client.close() + + async def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + async with httpx.AsyncClient() as http_client: + GeminiNextGenAPIClient( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + + def test_default_headers_option(self) -> None: + test_client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "bar" + + test_client.close() + + def test_validate_headers(self) -> None: + client = GeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-goog-api-key") == api_key + + with update_env(**{"GEMINI_API_KEY": Omit()}): + client2 = GeminiNextGenAPIClient(base_url=base_url, api_key=None, _strict_response_validation=True) + + with pytest.raises( + TypeError, + match="Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted", + ): + client2._build_request(FinalRequestOptions(method="get", url="/foo")) + + request2 = client2._build_request( + FinalRequestOptions(method="get", url="/foo", headers={"x-goog-api-key": Omit()}) + ) + assert request2.headers.get("x-goog-api-key") is None + + def test_default_query_option(self) -> None: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + assert dict(url.params) == {"query_param": "bar"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo", + params={"foo": "baz", "query_param": "overridden"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} + + client.close() + + def test_hardcoded_query_params_in_url(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true")) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo?beta=true", + params={"limit": "10", "page": "abc"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/files/a%2Fb?beta=true", + params={"limit": "10"}, + ) + ) + assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10" + + @pytest.mark.skip(reason="Mock server tests are disabled") + def test_api_version_client_params(self) -> None: + client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, api_version="My API Version" + ) + with client as c2: + c2.interactions.create( + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + def test_request_extra_json(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": False} + + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"baz": False} + + # `extra_json` takes priority over `json_data` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar", "baz": True}, + extra_json={"baz": None}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": None} + + def test_request_extra_headers(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options(extra_headers={"X-Foo": "Foo"}), + ), + ) + assert request.headers.get("X-Foo") == "Foo" + + # `extra_headers` takes priority over `default_headers` when keys clash + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_headers={"X-Bar": "false"}, + ), + ), + ) + assert request.headers.get("X-Bar") == "false" + + def test_request_extra_query(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_query={"my_query_param": "Foo"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"my_query_param": "Foo"} + + # if both `query` and `extra_query` are given, they are merged + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"bar": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"bar": "1", "foo": "2"} + + # `extra_query` takes priority over `query` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"foo": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"foo": "2"} + + def test_multipart_repeating_array(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions.construct( + method="post", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + def test_binary_content_upload_with_iterator(self) -> None: + file_content = b"Hello, this is a test file." + counter = Counter() + iterator = _make_sync_iterator([file_content], counter=counter) + + def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=request.read()) + + with GeminiNextGenAPIClient( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(transport=MockTransport(handler=mock_handler)), + ) as client: + response = client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload_with_body_is_deprecated( + self, respx_mock: MockRouter, client: GeminiNextGenAPIClient + ) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." + ): + response = client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + @pytest.mark.respx(base_url=base_url) + def test_basic_union_response(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + class Model1(BaseModel): + name: str + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + @pytest.mark.respx(base_url=base_url) + def test_union_response_different_types(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + """Union of objects with the same field name using a different type""" + + class Model1(BaseModel): + foo: int + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model1) + assert response.foo == 1 + + @pytest.mark.respx(base_url=base_url) + def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, client: GeminiNextGenAPIClient + ) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + + def test_base_url_setter(self) -> None: + client = GeminiNextGenAPIClient( + base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True + ) + assert client.base_url == "https://example.com/from_init/" + + client.base_url = "https://example.com/from_setter" # type: ignore[assignment] + + assert client.base_url == "https://example.com/from_setter/" + + client.close() + + def test_base_url_env(self) -> None: + with update_env(GEMINI_NEXT_GEN_API_BASE_URL="http://localhost:5000/from/env"): + client = GeminiNextGenAPIClient(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + @pytest.mark.parametrize( + "client", + [ + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_base_url_trailing_slash(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + client.close() + + @pytest.mark.parametrize( + "client", + [ + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_base_url_no_trailing_slash(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + client.close() + + @pytest.mark.parametrize( + "client", + [ + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + GeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_absolute_request_url(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="https://myapi.com/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "https://myapi.com/foo" + client.close() + + def test_copied_client_does_not_close_http(self) -> None: + test_client = GeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() + + copied = test_client.copy() + assert copied is not test_client + + del copied + + assert not test_client.is_closed() + + def test_client_context_manager(self) -> None: + test_client = GeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client + assert not c2.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() + + @pytest.mark.respx(base_url=base_url) + def test_client_response_validation_error(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + class Model(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) + + with pytest.raises(APIResponseValidationError) as exc: + client.get("/foo", cast_to=Model) + + assert isinstance(exc.value.__cause__, ValidationError) + + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + + @pytest.mark.respx(base_url=base_url) + def test_default_stream_cls(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) + assert isinstance(stream, Stream) + stream.response.close() + + @pytest.mark.respx(base_url=base_url) + def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) + + strict_client = GeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + + with pytest.raises(APIResponseValidationError): + strict_client.get("/foo", cast_to=Model) + + non_strict_client = GeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=False + ) + + response = non_strict_client.get("/foo", cast_to=Model) + assert isinstance(response, str) # type: ignore[unreachable] + + strict_client.close() + non_strict_client.close() + + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 0.5], + [3, "-10", 0.5], + [3, "60", 60], + [3, "61", 0.5], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5], + [3, "99999999999999999999999999999999999", 0.5], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5], + [3, "", 0.5], + [2, "", 0.5 * 2.0], + [1, "", 0.5 * 4.0], + [-1100, "", 8], # test large number potentially overflowing + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: GeminiNextGenAPIClient + ) -> None: + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=3) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + respx_mock.post("/api_version/interactions").mock(side_effect=httpx.TimeoutException("Test timeout error")) + + with pytest.raises(APITimeoutError): + client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ).__enter__() + + assert _get_open_connections(client) == 0 + + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + respx_mock.post("/api_version/interactions").mock(return_value=httpx.Response(500)) + + with pytest.raises(APIStatusError): + client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ).__enter__() + assert _get_open_connections(client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + def test_retries_taken( + self, + client: GeminiNextGenAPIClient, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/api_version/interactions").mock(side_effect=retry_handler) + + response = client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + assert response.retries_taken == failures_before_success + + def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + # Delete in case our environment has any proxy env vars set + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("ALL_PROXY", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + monkeypatch.delenv("all_proxy", raising=False) + monkeypatch.delenv("no_proxy", raising=False) + + client = DefaultHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: GeminiNextGenAPIClient) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + + with pytest.raises(APIStatusError) as exc_info: + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + +class TestAsyncGeminiNextGenAPIClient: + @pytest.mark.respx(base_url=base_url) + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + @pytest.mark.respx(base_url=base_url) + async def test_raw_response_for_binary( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) + + response = await async_client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + def test_copy(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) + + copied = async_client.copy(api_key="another My API Key") + assert copied.api_key == "another My API Key" + assert async_client.api_key == "My API Key" + + def test_copy_default_options(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + # options that have a default are overridden correctly + copied = async_client.copy(max_retries=7) + assert copied.max_retries == 7 + assert async_client.max_retries == 2 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(async_client.timeout, httpx.Timeout) + + async def test_copy_default_headers(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() + + async def test_copy_default_query(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} + ) + assert _get_params(client)["foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert _get_params(copied)["foo"] == "bar" + + # merges already given params + copied = client.copy(default_query={"bar": "stainless"}) + params = _get_params(copied) + assert params["foo"] == "bar" + assert params["bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_query={"foo": "stainless"}) + assert _get_params(copied)["foo"] == "stainless" + + # set_default_query + + # completely overrides already set values + copied = client.copy(set_default_query={}) + assert _get_params(copied) == {} + + copied = client.copy(set_default_query={"bar": "Robert"}) + assert _get_params(copied)["bar"] == "Robert" + + with pytest.raises( + ValueError, + # TODO: update + match="`default_query` and `set_default_query` arguments are mutually exclusive", + ): + client.copy(set_default_query={}, default_query={"foo": "Bar"}) + + await client.close() + + def test_copy_signature(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + # ensure the same parameters that can be passed to the client are defined in the `.copy()` method + init_signature = inspect.signature( + # mypy doesn't like that we access the `__init__` property. + async_client.__init__, # type: ignore[misc] + ) + copy_signature = inspect.signature(async_client.copy) + exclude_params = {"transport", "proxies", "_strict_response_validation"} + + for name in init_signature.parameters.keys(): + if name in exclude_params: + continue + + copy_param = copy_signature.parameters.get(name) + assert copy_param is not None, f"copy() signature is missing the {name} param" + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client_copy = async_client.copy() + client_copy._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "google.genai._interactions/_legacy_response.py", + "google.genai._interactions/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "google.genai._interactions/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + + async def test_request_timeout(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + request = async_client._build_request( + FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) + ) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(100.0) + + async def test_client_timeout_option(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0) + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(0) + + await client.close() + + async def test_http_client_timeout_option(self) -> None: + # custom timeout given to the httpx client should be used + async with httpx.AsyncClient(timeout=None) as http_client: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(None) + + await client.close() + + # no timeout given to the httpx client should not use the httpx default + async with httpx.AsyncClient() as http_client: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + await client.close() + + # explicitly passing the default timeout currently results in it being ignored + async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT # our default + + await client.close() + + def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + with httpx.Client() as http_client: + AsyncGeminiNextGenAPIClient( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + + async def test_default_headers_option(self) -> None: + test_client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "bar" + + await test_client.close() + + def test_validate_headers(self) -> None: + client = AsyncGeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-goog-api-key") == api_key + + with update_env(**{"GEMINI_API_KEY": Omit()}): + client2 = AsyncGeminiNextGenAPIClient(base_url=base_url, api_key=None, _strict_response_validation=True) + + with pytest.raises( + TypeError, + match="Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted", + ): + client2._build_request(FinalRequestOptions(method="get", url="/foo")) + + request2 = client2._build_request( + FinalRequestOptions(method="get", url="/foo", headers={"x-goog-api-key": Omit()}) + ) + assert request2.headers.get("x-goog-api-key") is None + + async def test_default_query_option(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + assert dict(url.params) == {"query_param": "bar"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo", + params={"foo": "baz", "query_param": "overridden"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} + + await client.close() + + async def test_hardcoded_query_params_in_url(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true")) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true"} + + request = async_client._build_request( + FinalRequestOptions( + method="get", + url="/foo?beta=true", + params={"limit": "10", "page": "abc"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"} + + request = async_client._build_request( + FinalRequestOptions( + method="get", + url="/files/a%2Fb?beta=true", + params={"limit": "10"}, + ) + ) + assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10" + + @pytest.mark.skip(reason="Mock server tests are disabled") + async def test_api_version_client_params(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, api_version="My API Version" + ) + async with client as c2: + await c2.interactions.create( + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + def test_request_extra_json(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": False} + + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"baz": False} + + # `extra_json` takes priority over `json_data` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar", "baz": True}, + extra_json={"baz": None}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": None} + + def test_request_extra_headers(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options(extra_headers={"X-Foo": "Foo"}), + ), + ) + assert request.headers.get("X-Foo") == "Foo" + + # `extra_headers` takes priority over `default_headers` when keys clash + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_headers={"X-Bar": "false"}, + ), + ), + ) + assert request.headers.get("X-Bar") == "false" + + def test_request_extra_query(self, client: GeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_query={"my_query_param": "Foo"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"my_query_param": "Foo"} + + # if both `query` and `extra_query` are given, they are merged + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"bar": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"bar": "1", "foo": "2"} + + # `extra_query` takes priority over `query` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"foo": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"foo": "2"} + + def test_multipart_repeating_array(self, async_client: AsyncGeminiNextGenAPIClient) -> None: + request = async_client._build_request( + FinalRequestOptions.construct( + method="post", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = await async_client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + async def test_binary_content_upload_with_asynciterator(self) -> None: + file_content = b"Hello, this is a test file." + counter = Counter() + iterator = _make_async_iterator([file_content], counter=counter) + + async def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=await request.aread()) + + async with AsyncGeminiNextGenAPIClient( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)), + ) as client: + response = await client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload_with_body_is_deprecated( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." + ): + response = await async_client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + @pytest.mark.respx(base_url=base_url) + async def test_basic_union_response( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + class Model1(BaseModel): + name: str + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + @pytest.mark.respx(base_url=base_url) + async def test_union_response_different_types( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + """Union of objects with the same field name using a different type""" + + class Model1(BaseModel): + foo: int + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model1) + assert response.foo == 1 + + @pytest.mark.respx(base_url=base_url) + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = await async_client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + + async def test_base_url_setter(self) -> None: + client = AsyncGeminiNextGenAPIClient( + base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True + ) + assert client.base_url == "https://example.com/from_init/" + + client.base_url = "https://example.com/from_setter" # type: ignore[assignment] + + assert client.base_url == "https://example.com/from_setter/" + + await client.close() + + async def test_base_url_env(self) -> None: + with update_env(GEMINI_NEXT_GEN_API_BASE_URL="http://localhost:5000/from/env"): + client = AsyncGeminiNextGenAPIClient(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + @pytest.mark.parametrize( + "client", + [ + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_base_url_trailing_slash(self, client: AsyncGeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() + + @pytest.mark.parametrize( + "client", + [ + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_base_url_no_trailing_slash(self, client: AsyncGeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() + + @pytest.mark.parametrize( + "client", + [ + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncGeminiNextGenAPIClient( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_absolute_request_url(self, client: AsyncGeminiNextGenAPIClient) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="https://myapi.com/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "https://myapi.com/foo" + await client.close() + + async def test_copied_client_does_not_close_http(self) -> None: + test_client = AsyncGeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() + + copied = test_client.copy() + assert copied is not test_client + + del copied + + await asyncio.sleep(0.2) + assert not test_client.is_closed() + + async def test_client_context_manager(self) -> None: + test_client = AsyncGeminiNextGenAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client + assert not c2.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() + + @pytest.mark.respx(base_url=base_url) + async def test_client_response_validation_error( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + class Model(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) + + with pytest.raises(APIResponseValidationError) as exc: + await async_client.get("/foo", cast_to=Model) + + assert isinstance(exc.value.__cause__, ValidationError) + + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + + @pytest.mark.respx(base_url=base_url) + async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) + assert isinstance(stream, AsyncStream) + await stream.response.aclose() + + @pytest.mark.respx(base_url=base_url) + async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) + + strict_client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=True + ) + + with pytest.raises(APIResponseValidationError): + await strict_client.get("/foo", cast_to=Model) + + non_strict_client = AsyncGeminiNextGenAPIClient( + base_url=base_url, api_key=api_key, _strict_response_validation=False + ) + + response = await non_strict_client.get("/foo", cast_to=Model) + assert isinstance(response, str) # type: ignore[unreachable] + + await strict_client.close() + await non_strict_client.close() + + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 0.5], + [3, "-10", 0.5], + [3, "60", 60], + [3, "61", 0.5], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5], + [3, "99999999999999999999999999999999999", 0.5], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5], + [3, "", 0.5], + [2, "", 0.5 * 2.0], + [1, "", 0.5 * 4.0], + [-1100, "", 8], # test large number potentially overflowing + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=3) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_retrying_timeout_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + respx_mock.post("/api_version/interactions").mock(side_effect=httpx.TimeoutException("Test timeout error")) + + with pytest.raises(APITimeoutError): + await async_client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ).__aenter__() + + assert _get_open_connections(async_client) == 0 + + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_retrying_status_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + respx_mock.post("/api_version/interactions").mock(return_value=httpx.Response(500)) + + with pytest.raises(APIStatusError): + await async_client.interactions.with_streaming_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ).__aenter__() + assert _get_open_connections(async_client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("google.genai._interactions._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + async def test_retries_taken( + self, + async_client: AsyncGeminiNextGenAPIClient, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/api_version/interactions").mock(side_effect=retry_handler) + + response = await client.interactions.with_raw_response.create( + api_version="api_version", + input={ + "text": "text", + "type": "text", + }, + model="gemini-2.5-computer-use-preview-10-2025", + ) + + assert response.retries_taken == failures_before_success + + async def test_get_platform(self) -> None: + platform = await asyncify(get_platform)() + assert isinstance(platform, (str, OtherPlatform)) + + async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + # Delete in case our environment has any proxy env vars set + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("ALL_PROXY", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + monkeypatch.delenv("all_proxy", raising=False) + monkeypatch.delenv("no_proxy", raising=False) + + client = DefaultAsyncHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + async def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultAsyncHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects_disabled( + self, respx_mock: MockRouter, async_client: AsyncGeminiNextGenAPIClient + ) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + + with pytest.raises(APIStatusError) as exc_info: + await async_client.post( + "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response + ) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" diff --git a/google/genai/tests/interactions/test_extract_files.py b/google/genai/tests/interactions/test_extract_files.py new file mode 100644 index 000000000..e75b3dffc --- /dev/null +++ b/google/genai/tests/interactions/test_extract_files.py @@ -0,0 +1,106 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import Sequence + +import pytest + +from google.genai._interactions._types import FileTypes, ArrayFormat +from google.genai._interactions._utils import extract_files + + +def test_removes_files_from_input() -> None: + query = {"foo": "bar"} + assert extract_files(query, paths=[]) == [] + assert query == {"foo": "bar"} + + query2 = {"foo": b"Bar", "hello": "world"} + assert extract_files(query2, paths=[["foo"]]) == [("foo", b"Bar")] + assert query2 == {"hello": "world"} + + query3 = {"foo": {"foo": {"bar": b"Bar"}}, "hello": "world"} + assert extract_files(query3, paths=[["foo", "foo", "bar"]]) == [("foo[foo][bar]", b"Bar")] + assert query3 == {"foo": {"foo": {}}, "hello": "world"} + + query4 = {"foo": {"bar": b"Bar", "baz": "foo"}, "hello": "world"} + assert extract_files(query4, paths=[["foo", "bar"]]) == [("foo[bar]", b"Bar")] + assert query4 == {"hello": "world", "foo": {"baz": "foo"}} + + +def test_multiple_files() -> None: + query = {"documents": [{"file": b"My first file"}, {"file": b"My second file"}]} + assert extract_files(query, paths=[["documents", "", "file"]]) == [ + ("documents[][file]", b"My first file"), + ("documents[][file]", b"My second file"), + ] + assert query == {"documents": [{}, {}]} + + +def test_top_level_file_array() -> None: + query = {"files": [b"file one", b"file two"], "title": "hello"} + assert extract_files(query, paths=[["files", ""]]) == [("files[]", b"file one"), ("files[]", b"file two")] + assert query == {"title": "hello"} + + +@pytest.mark.parametrize( + "query,paths,expected", + [ + [ + {"foo": {"bar": "baz"}}, + [["foo", "", "bar"]], + [], + ], + [ + {"foo": ["bar", "baz"]}, + [["foo", "bar"]], + [], + ], + [ + {"foo": {"bar": "baz"}}, + [["foo", "foo"]], + [], + ], + ], + ids=["dict expecting array", "array expecting dict", "unknown keys"], +) +def test_ignores_incorrect_paths( + query: dict[str, object], + paths: Sequence[Sequence[str]], + expected: list[tuple[str, FileTypes]], +) -> None: + assert extract_files(query, paths=paths) == expected + + +@pytest.mark.parametrize( + "array_format,expected_top_level,expected_nested", + [ + ("brackets", [("files[]", b"a"), ("files[]", b"b")], [("items[][file]", b"a"), ("items[][file]", b"b")]), + ("repeat", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("comma", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("indices", [("files[0]", b"a"), ("files[1]", b"b")], [("items[0][file]", b"a"), ("items[1][file]", b"b")]), + ], +) +def test_array_format_controls_file_field_names( + array_format: ArrayFormat, + expected_top_level: list[tuple[str, FileTypes]], + expected_nested: list[tuple[str, FileTypes]], +) -> None: + top_level = {"files": [b"a", b"b"]} + assert extract_files(top_level, paths=[["files", ""]], array_format=array_format) == expected_top_level + + nested = {"items": [{"file": b"a"}, {"file": b"b"}]} + assert extract_files(nested, paths=[["items", "", "file"]], array_format=array_format) == expected_nested diff --git a/google/genai/tests/interactions/test_files.py b/google/genai/tests/interactions/test_files.py new file mode 100644 index 000000000..ce665c405 --- /dev/null +++ b/google/genai/tests/interactions/test_files.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pathlib import Path + +import anyio +import pytest +from dirty_equals import IsDict, IsList, IsBytes, IsTuple + +from google.genai._interactions._files import to_httpx_files, deepcopy_with_paths, async_to_httpx_files +from google.genai._interactions._utils import extract_files + +readme_path = Path(__file__).parent.parent.joinpath("README.md") + + +def test_pathlib_includes_file_name() -> None: + result = to_httpx_files({"file": readme_path}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +def test_tuple_input() -> None: + result = to_httpx_files([("file", readme_path)]) + print(result) + assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) + + +@pytest.mark.asyncio +async def test_async_pathlib_includes_file_name() -> None: + result = await async_to_httpx_files({"file": readme_path}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +@pytest.mark.asyncio +async def test_async_supports_anyio_path() -> None: + result = await async_to_httpx_files({"file": anyio.Path(readme_path)}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +@pytest.mark.asyncio +async def test_async_tuple_input() -> None: + result = await async_to_httpx_files([("file", readme_path)]) + print(result) + assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) + + +def test_string_not_allowed() -> None: + with pytest.raises(TypeError, match="Expected file types input to be a FileContent type or to be a tuple"): + to_httpx_files( + { + "file": "foo", # type: ignore + } + ) + + +def assert_different_identities(obj1: object, obj2: object) -> None: + assert obj1 == obj2 + assert obj1 is not obj2 + + +class TestDeepcopyWithPaths: + def test_copies_top_level_dict(self) -> None: + original = {"file": b"data", "other": "value"} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + + def test_file_value_is_same_reference(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + assert result["file"] is file_bytes + + def test_list_popped_wholesale(self) -> None: + files = [b"f1", b"f2"] + original = {"files": files, "title": "t"} + result = deepcopy_with_paths(original, [["files", ""]]) + assert_different_identities(result, original) + result_files = result["files"] + assert isinstance(result_files, list) + assert_different_identities(result_files, files) + + def test_nested_array_path_copies_list_and_elements(self) -> None: + elem1 = {"file": b"f1", "extra": 1} + elem2 = {"file": b"f2", "extra": 2} + original = {"items": [elem1, elem2]} + result = deepcopy_with_paths(original, [["items", "", "file"]]) + assert_different_identities(result, original) + result_items = result["items"] + assert isinstance(result_items, list) + assert_different_identities(result_items, original["items"]) + assert_different_identities(result_items[0], elem1) + assert_different_identities(result_items[1], elem2) + + def test_empty_paths_returns_same_object(self) -> None: + original = {"foo": "bar"} + result = deepcopy_with_paths(original, []) + assert result is original + + def test_multiple_paths(self) -> None: + f1 = b"file1" + f2 = b"file2" + original = {"a": f1, "b": f2, "c": "unchanged"} + result = deepcopy_with_paths(original, [["a"], ["b"]]) + assert_different_identities(result, original) + assert result["a"] is f1 + assert result["b"] is f2 + assert result["c"] is original["c"] + + def test_extract_files_does_not_mutate_original_top_level(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes, "other": "value"} + + copied = deepcopy_with_paths(original, [["file"]]) + extracted = extract_files(copied, paths=[["file"]]) + + assert extracted == [("file", file_bytes)] + assert original == {"file": file_bytes, "other": "value"} + assert copied == {"other": "value"} + + def test_extract_files_does_not_mutate_original_nested_array_path(self) -> None: + file1 = b"f1" + file2 = b"f2" + original = { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + + copied = deepcopy_with_paths(original, [["items", "", "file"]]) + extracted = extract_files(copied, paths=[["items", "", "file"]]) + + assert [entry for _, entry in extracted] == [file1, file2] + assert original == { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + assert copied == { + "items": [ + {"extra": 1}, + {"extra": 2}, + ], + "title": "example", + } diff --git a/google/genai/tests/interactions/test_models.py b/google/genai/tests/interactions/test_models.py new file mode 100644 index 000000000..b86bf7bef --- /dev/null +++ b/google/genai/tests/interactions/test_models.py @@ -0,0 +1,978 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast +from datetime import datetime, timezone +from typing_extensions import Literal, Annotated, TypeAliasType + +import pytest +import pydantic +from pydantic import Field + +from google.genai._interactions._utils import PropertyInfo +from google.genai._interactions._compat import PYDANTIC_V1, parse_obj, model_dump, model_json +from google.genai._interactions._models import DISCRIMINATOR_CACHE, BaseModel, construct_type + + +class BasicModel(BaseModel): + foo: str + + +@pytest.mark.parametrize("value", ["hello", 1], ids=["correct type", "mismatched"]) +def test_basic(value: object) -> None: + m = BasicModel.construct(foo=value) + assert m.foo == value + + +def test_directly_nested_model() -> None: + class NestedModel(BaseModel): + nested: BasicModel + + m = NestedModel.construct(nested={"foo": "Foo!"}) + assert m.nested.foo == "Foo!" + + # mismatched types + m = NestedModel.construct(nested="hello!") + assert cast(Any, m.nested) == "hello!" + + +def test_optional_nested_model() -> None: + class NestedModel(BaseModel): + nested: Optional[BasicModel] + + m1 = NestedModel.construct(nested=None) + assert m1.nested is None + + m2 = NestedModel.construct(nested={"foo": "bar"}) + assert m2.nested is not None + assert m2.nested.foo == "bar" + + # mismatched types + m3 = NestedModel.construct(nested={"foo"}) + assert isinstance(cast(Any, m3.nested), set) + assert cast(Any, m3.nested) == {"foo"} + + +def test_list_nested_model() -> None: + class NestedModel(BaseModel): + nested: List[BasicModel] + + m = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}]) + assert m.nested is not None + assert isinstance(m.nested, list) + assert len(m.nested) == 2 + assert m.nested[0].foo == "bar" + assert m.nested[1].foo == "2" + + # mismatched types + m = NestedModel.construct(nested=True) + assert cast(Any, m.nested) is True + + m = NestedModel.construct(nested=[False]) + assert cast(Any, m.nested) == [False] + + +def test_optional_list_nested_model() -> None: + class NestedModel(BaseModel): + nested: Optional[List[BasicModel]] + + m1 = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}]) + assert m1.nested is not None + assert isinstance(m1.nested, list) + assert len(m1.nested) == 2 + assert m1.nested[0].foo == "bar" + assert m1.nested[1].foo == "2" + + m2 = NestedModel.construct(nested=None) + assert m2.nested is None + + # mismatched types + m3 = NestedModel.construct(nested={1}) + assert cast(Any, m3.nested) == {1} + + m4 = NestedModel.construct(nested=[False]) + assert cast(Any, m4.nested) == [False] + + +def test_list_optional_items_nested_model() -> None: + class NestedModel(BaseModel): + nested: List[Optional[BasicModel]] + + m = NestedModel.construct(nested=[None, {"foo": "bar"}]) + assert m.nested is not None + assert isinstance(m.nested, list) + assert len(m.nested) == 2 + assert m.nested[0] is None + assert m.nested[1] is not None + assert m.nested[1].foo == "bar" + + # mismatched types + m3 = NestedModel.construct(nested="foo") + assert cast(Any, m3.nested) == "foo" + + m4 = NestedModel.construct(nested=[False]) + assert cast(Any, m4.nested) == [False] + + +def test_list_mismatched_type() -> None: + class NestedModel(BaseModel): + nested: List[str] + + m = NestedModel.construct(nested=False) + assert cast(Any, m.nested) is False + + +def test_raw_dictionary() -> None: + class NestedModel(BaseModel): + nested: Dict[str, str] + + m = NestedModel.construct(nested={"hello": "world"}) + assert m.nested == {"hello": "world"} + + # mismatched types + m = NestedModel.construct(nested=False) + assert cast(Any, m.nested) is False + + +def test_nested_dictionary_model() -> None: + class NestedModel(BaseModel): + nested: Dict[str, BasicModel] + + m = NestedModel.construct(nested={"hello": {"foo": "bar"}}) + assert isinstance(m.nested, dict) + assert m.nested["hello"].foo == "bar" + + # mismatched types + m = NestedModel.construct(nested={"hello": False}) + assert cast(Any, m.nested["hello"]) is False + + +def test_unknown_fields() -> None: + m1 = BasicModel.construct(foo="foo", unknown=1) + assert m1.foo == "foo" + assert cast(Any, m1).unknown == 1 + + m2 = BasicModel.construct(foo="foo", unknown={"foo_bar": True}) + assert m2.foo == "foo" + assert cast(Any, m2).unknown == {"foo_bar": True} + + assert model_dump(m2) == {"foo": "foo", "unknown": {"foo_bar": True}} + + +def test_strict_validation_unknown_fields() -> None: + class Model(BaseModel): + foo: str + + model = parse_obj(Model, dict(foo="hello!", user="Robert")) + assert model.foo == "hello!" + assert cast(Any, model).user == "Robert" + + assert model_dump(model) == {"foo": "hello!", "user": "Robert"} + + +def test_aliases() -> None: + class Model(BaseModel): + my_field: int = Field(alias="myField") + + m = Model.construct(myField=1) + assert m.my_field == 1 + + # mismatched types + m = Model.construct(myField={"hello": False}) + assert cast(Any, m.my_field) == {"hello": False} + + +def test_repr() -> None: + model = BasicModel(foo="bar") + assert str(model) == "BasicModel(foo='bar')" + assert repr(model) == "BasicModel(foo='bar')" + + +def test_repr_nested_model() -> None: + class Child(BaseModel): + name: str + age: int + + class Parent(BaseModel): + name: str + child: Child + + model = Parent(name="Robert", child=Child(name="Foo", age=5)) + assert str(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))" + assert repr(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))" + + +def test_optional_list() -> None: + class Submodel(BaseModel): + name: str + + class Model(BaseModel): + items: Optional[List[Submodel]] + + m = Model.construct(items=None) + assert m.items is None + + m = Model.construct(items=[]) + assert m.items == [] + + m = Model.construct(items=[{"name": "Robert"}]) + assert m.items is not None + assert len(m.items) == 1 + assert m.items[0].name == "Robert" + + +def test_nested_union_of_models() -> None: + class Submodel1(BaseModel): + bar: bool + + class Submodel2(BaseModel): + thing: str + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2] + + m = Model.construct(foo={"thing": "hello"}) + assert isinstance(m.foo, Submodel2) + assert m.foo.thing == "hello" + + +def test_nested_union_of_mixed_types() -> None: + class Submodel1(BaseModel): + bar: bool + + class Model(BaseModel): + foo: Union[Submodel1, Literal[True], Literal["CARD_HOLDER"]] + + m = Model.construct(foo=True) + assert m.foo is True + + m = Model.construct(foo="CARD_HOLDER") + assert m.foo == "CARD_HOLDER" + + m = Model.construct(foo={"bar": False}) + assert isinstance(m.foo, Submodel1) + assert m.foo.bar is False + + +def test_nested_union_multiple_variants() -> None: + class Submodel1(BaseModel): + bar: bool + + class Submodel2(BaseModel): + thing: str + + class Submodel3(BaseModel): + foo: int + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2, None, Submodel3] + + m = Model.construct(foo={"thing": "hello"}) + assert isinstance(m.foo, Submodel2) + assert m.foo.thing == "hello" + + m = Model.construct(foo=None) + assert m.foo is None + + m = Model.construct() + assert m.foo is None + + m = Model.construct(foo={"foo": "1"}) + assert isinstance(m.foo, Submodel3) + assert m.foo.foo == 1 + + +def test_nested_union_invalid_data() -> None: + class Submodel1(BaseModel): + level: int + + class Submodel2(BaseModel): + name: str + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2] + + m = Model.construct(foo=True) + assert cast(bool, m.foo) is True + + m = Model.construct(foo={"name": 3}) + if PYDANTIC_V1: + assert isinstance(m.foo, Submodel2) + assert m.foo.name == "3" + else: + assert isinstance(m.foo, Submodel1) + assert m.foo.name == 3 # type: ignore + + +def test_list_of_unions() -> None: + class Submodel1(BaseModel): + level: int + + class Submodel2(BaseModel): + name: str + + class Model(BaseModel): + items: List[Union[Submodel1, Submodel2]] + + m = Model.construct(items=[{"level": 1}, {"name": "Robert"}]) + assert len(m.items) == 2 + assert isinstance(m.items[0], Submodel1) + assert m.items[0].level == 1 + assert isinstance(m.items[1], Submodel2) + assert m.items[1].name == "Robert" + + m = Model.construct(items=[{"level": -1}, 156]) + assert len(m.items) == 2 + assert isinstance(m.items[0], Submodel1) + assert m.items[0].level == -1 + assert cast(Any, m.items[1]) == 156 + + +def test_union_of_lists() -> None: + class SubModel1(BaseModel): + level: int + + class SubModel2(BaseModel): + name: str + + class Model(BaseModel): + items: Union[List[SubModel1], List[SubModel2]] + + # with one valid entry + m = Model.construct(items=[{"name": "Robert"}]) + assert len(m.items) == 1 + assert isinstance(m.items[0], SubModel2) + assert m.items[0].name == "Robert" + + # with two entries pointing to different types + m = Model.construct(items=[{"level": 1}, {"name": "Robert"}]) + assert len(m.items) == 2 + assert isinstance(m.items[0], SubModel1) + assert m.items[0].level == 1 + assert isinstance(m.items[1], SubModel1) + assert cast(Any, m.items[1]).name == "Robert" + + # with two entries pointing to *completely* different types + m = Model.construct(items=[{"level": -1}, 156]) + assert len(m.items) == 2 + assert isinstance(m.items[0], SubModel1) + assert m.items[0].level == -1 + assert cast(Any, m.items[1]) == 156 + + +def test_dict_of_union() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + foo: str + + class Model(BaseModel): + data: Dict[str, Union[SubModel1, SubModel2]] + + m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}}) + assert len(list(m.data.keys())) == 2 + assert isinstance(m.data["hello"], SubModel1) + assert m.data["hello"].name == "there" + assert isinstance(m.data["foo"], SubModel2) + assert m.data["foo"].foo == "bar" + + # TODO: test mismatched type + + +def test_double_nested_union() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + bar: str + + class Model(BaseModel): + data: Dict[str, List[Union[SubModel1, SubModel2]]] + + m = Model.construct(data={"foo": [{"bar": "baz"}, {"name": "Robert"}]}) + assert len(m.data["foo"]) == 2 + + entry1 = m.data["foo"][0] + assert isinstance(entry1, SubModel2) + assert entry1.bar == "baz" + + entry2 = m.data["foo"][1] + assert isinstance(entry2, SubModel1) + assert entry2.name == "Robert" + + # TODO: test mismatched type + + +def test_union_of_dict() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + foo: str + + class Model(BaseModel): + data: Union[Dict[str, SubModel1], Dict[str, SubModel2]] + + m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}}) + assert len(list(m.data.keys())) == 2 + assert isinstance(m.data["hello"], SubModel1) + assert m.data["hello"].name == "there" + assert isinstance(m.data["foo"], SubModel1) + assert cast(Any, m.data["foo"]).foo == "bar" + + +def test_iso8601_datetime() -> None: + class Model(BaseModel): + created_at: datetime + + expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc) + + if PYDANTIC_V1: + expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}' + else: + expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' + + model = Model.construct(created_at="2019-12-27T18:11:19.117Z") + assert model.created_at == expected + assert model_json(model) == expected_json + + model = parse_obj(Model, dict(created_at="2019-12-27T18:11:19.117Z")) + assert model.created_at == expected + assert model_json(model) == expected_json + + +def test_does_not_coerce_int() -> None: + class Model(BaseModel): + bar: int + + assert Model.construct(bar=1).bar == 1 + assert Model.construct(bar=10.9).bar == 10.9 + assert Model.construct(bar="19").bar == "19" # type: ignore[comparison-overlap] + assert Model.construct(bar=False).bar is False + + +def test_int_to_float_safe_conversion() -> None: + class Model(BaseModel): + float_field: float + + m = Model.construct(float_field=10) + assert m.float_field == 10.0 + assert isinstance(m.float_field, float) + + m = Model.construct(float_field=10.12) + assert m.float_field == 10.12 + assert isinstance(m.float_field, float) + + # number too big + m = Model.construct(float_field=2**53 + 1) + assert m.float_field == 2**53 + 1 + assert isinstance(m.float_field, int) + + +def test_deprecated_alias() -> None: + class Model(BaseModel): + resource_id: str = Field(alias="model_id") + + @property + def model_id(self) -> str: + return self.resource_id + + m = Model.construct(model_id="id") + assert m.model_id == "id" + assert m.resource_id == "id" + assert m.resource_id is m.model_id + + m = parse_obj(Model, {"model_id": "id"}) + assert m.model_id == "id" + assert m.resource_id == "id" + assert m.resource_id is m.model_id + + +def test_omitted_fields() -> None: + class Model(BaseModel): + resource_id: Optional[str] = None + + m = Model.construct() + assert m.resource_id is None + assert "resource_id" not in m.model_fields_set + + m = Model.construct(resource_id=None) + assert m.resource_id is None + assert "resource_id" in m.model_fields_set + + m = Model.construct(resource_id="foo") + assert m.resource_id == "foo" + assert "resource_id" in m.model_fields_set + + +def test_to_dict() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert m.to_dict() == {"FOO": "hello"} + assert m.to_dict(use_api_names=False) == {"foo": "hello"} + + m2 = Model() + assert m2.to_dict() == {} + assert m2.to_dict(exclude_unset=False) == {"FOO": None} + assert m2.to_dict(exclude_unset=False, exclude_none=True) == {} + assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {} + + m3 = Model(FOO=None) + assert m3.to_dict() == {"FOO": None} + assert m3.to_dict(exclude_none=True) == {} + assert m3.to_dict(exclude_defaults=True) == {} + + class Model2(BaseModel): + created_at: datetime + + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_dict(warnings=False) + + +def test_forwards_compat_model_dump_method() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert m.model_dump() == {"foo": "hello"} + assert m.model_dump(include={"bar"}) == {} + assert m.model_dump(exclude={"foo"}) == {} + assert m.model_dump(by_alias=True) == {"FOO": "hello"} + + m2 = Model() + assert m2.model_dump() == {"foo": None} + assert m2.model_dump(exclude_unset=True) == {} + assert m2.model_dump(exclude_none=True) == {} + assert m2.model_dump(exclude_defaults=True) == {} + + m3 = Model(FOO=None) + assert m3.model_dump() == {"foo": None} + assert m3.model_dump(exclude_none=True) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): + m.model_dump(round_trip=True) + + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.model_dump(warnings=False) + + +def test_compat_method_no_error_for_warnings() -> None: + class Model(BaseModel): + foo: Optional[str] + + m = Model(foo="hello") + assert isinstance(model_dump(m, warnings=False), dict) + + +def test_to_json() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert json.loads(m.to_json()) == {"FOO": "hello"} + assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} + + if PYDANTIC_V1: + assert m.to_json(indent=None) == '{"FOO": "hello"}' + else: + assert m.to_json(indent=None) == '{"FOO":"hello"}' + + m2 = Model() + assert json.loads(m2.to_json()) == {} + assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None} + assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {} + assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {} + + m3 = Model(FOO=None) + assert json.loads(m3.to_json()) == {"FOO": None} + assert json.loads(m3.to_json(exclude_none=True)) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_json(warnings=False) + + +def test_forwards_compat_model_dump_json_method() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert json.loads(m.model_dump_json()) == {"foo": "hello"} + assert json.loads(m.model_dump_json(include={"bar"})) == {} + assert json.loads(m.model_dump_json(include={"foo"})) == {"foo": "hello"} + assert json.loads(m.model_dump_json(by_alias=True)) == {"FOO": "hello"} + + assert m.model_dump_json(indent=2) == '{\n "foo": "hello"\n}' + + m2 = Model() + assert json.loads(m2.model_dump_json()) == {"foo": None} + assert json.loads(m2.model_dump_json(exclude_unset=True)) == {} + assert json.loads(m2.model_dump_json(exclude_none=True)) == {} + assert json.loads(m2.model_dump_json(exclude_defaults=True)) == {} + + m3 = Model(FOO=None) + assert json.loads(m3.model_dump_json()) == {"foo": None} + assert json.loads(m3.model_dump_json(exclude_none=True)) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): + m.model_dump_json(round_trip=True) + + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.model_dump_json(warnings=False) + + +def test_type_compat() -> None: + # our model type can be assigned to Pydantic's model type + + def takes_pydantic(model: pydantic.BaseModel) -> None: # noqa: ARG001 + ... + + class OurModel(BaseModel): + foo: Optional[str] = None + + takes_pydantic(OurModel()) + + +def test_annotated_types() -> None: + class Model(BaseModel): + value: str + + m = construct_type( + value={"value": "foo"}, + type_=cast(Any, Annotated[Model, "random metadata"]), + ) + assert isinstance(m, Model) + assert m.value == "foo" + + +def test_discriminated_unions_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, A) + assert m.type == "a" + if PYDANTIC_V1: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] + + +def test_discriminated_unions_unknown_variant() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "c", "data": None, "new_thing": "bar"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + + # just chooses the first variant + assert isinstance(m, A) + assert m.type == "c" # type: ignore[comparison-overlap] + assert m.data == None # type: ignore[unreachable] + assert m.new_thing == "bar" + + +def test_discriminated_unions_invalid_data_nested_unions() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + class C(BaseModel): + type: Literal["c"] + + data: bool + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "c", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, C) + assert m.type == "c" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_with_aliases_invalid_data() -> None: + class A(BaseModel): + foo_type: Literal["a"] = Field(alias="type") + + data: str + + class B(BaseModel): + foo_type: Literal["b"] = Field(alias="type") + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, B) + assert m.foo_type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, A) + assert m.foo_type == "a" + if PYDANTIC_V1: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] + + +def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["a"] + + data: int + + m = construct_type( + value={"type": "a", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "a" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_invalid_data_uses_cache() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + UnionType = cast(Any, Union[A, B]) + + assert not DISCRIMINATOR_CACHE.get(UnionType) + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + discriminator = DISCRIMINATOR_CACHE.get(UnionType) + assert discriminator is not None + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + # if the discriminator details object stays the same between invocations then + # we hit the cache + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator + + +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") +def test_type_alias_type() -> None: + Alias = TypeAliasType("Alias", str) # pyright: ignore + + class Model(BaseModel): + alias: Alias + union: Union[int, Alias] + + m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model) + assert isinstance(m, Model) + assert isinstance(m.alias, str) + assert m.alias == "foo" + assert isinstance(m.union, str) + assert m.union == "bar" + + +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") +def test_field_named_cls() -> None: + class Model(BaseModel): + cls: str + + m = construct_type(value={"cls": "foo"}, type_=Model) + assert isinstance(m, Model) + assert isinstance(m.cls, str) + + +def test_discriminated_union_case() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["b"] + + data: List[Union[A, object]] + + class ModelA(BaseModel): + type: Literal["modelA"] + + data: int + + class ModelB(BaseModel): + type: Literal["modelB"] + + required: str + + data: Union[A, B] + + # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required` + m = construct_type( + value={"type": "modelB", "data": {"type": "a", "data": True}}, + type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]), + ) + + assert isinstance(m, ModelB) + + +def test_nested_discriminated_union() -> None: + class InnerType1(BaseModel): + type: Literal["type_1"] + + class InnerModel(BaseModel): + inner_value: str + + class InnerType2(BaseModel): + type: Literal["type_2"] + some_inner_model: InnerModel + + class Type1(BaseModel): + base_type: Literal["base_type_1"] + value: Annotated[ + Union[ + InnerType1, + InnerType2, + ], + PropertyInfo(discriminator="type"), + ] + + class Type2(BaseModel): + base_type: Literal["base_type_2"] + + T = Annotated[ + Union[ + Type1, + Type2, + ], + PropertyInfo(discriminator="base_type"), + ] + + model = construct_type( + type_=T, + value={ + "base_type": "base_type_1", + "value": { + "type": "type_2", + }, + }, + ) + assert isinstance(model, Type1) + assert isinstance(model.value, InnerType2) + + +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now") +def test_extra_properties() -> None: + class Item(BaseModel): + prop: int + + class Model(BaseModel): + __pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + other: str + + if TYPE_CHECKING: + + def __getattr__(self, attr: str) -> Item: ... + + model = construct_type( + type_=Model, + value={ + "a": {"prop": 1}, + "other": "foo", + }, + ) + assert isinstance(model, Model) + assert model.a.prop == 1 + assert isinstance(model.a, Item) + assert model.other == "foo" diff --git a/google/genai/tests/interactions/test_qs.py b/google/genai/tests/interactions/test_qs.py new file mode 100644 index 000000000..95ebb7124 --- /dev/null +++ b/google/genai/tests/interactions/test_qs.py @@ -0,0 +1,93 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, cast +from functools import partial +from urllib.parse import unquote + +import pytest + +from google.genai._interactions._qs import Querystring, stringify + + +def test_empty() -> None: + assert stringify({}) == "" + assert stringify({"a": {}}) == "" + assert stringify({"a": {"b": {"c": {}}}}) == "" + + +def test_basic() -> None: + assert stringify({"a": 1}) == "a=1" + assert stringify({"a": "b"}) == "a=b" + assert stringify({"a": True}) == "a=true" + assert stringify({"a": False}) == "a=false" + assert stringify({"a": 1.23456}) == "a=1.23456" + assert stringify({"a": None}) == "" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_nested_dotted(method: str) -> None: + if method == "class": + serialise = Querystring(nested_format="dots").stringify + else: + serialise = partial(stringify, nested_format="dots") + + assert unquote(serialise({"a": {"b": "c"}})) == "a.b=c" + assert unquote(serialise({"a": {"b": "c", "d": "e", "f": "g"}})) == "a.b=c&a.d=e&a.f=g" + assert unquote(serialise({"a": {"b": {"c": {"d": "e"}}}})) == "a.b.c.d=e" + assert unquote(serialise({"a": {"b": True}})) == "a.b=true" + + +def test_nested_brackets() -> None: + assert unquote(stringify({"a": {"b": "c"}})) == "a[b]=c" + assert unquote(stringify({"a": {"b": "c", "d": "e", "f": "g"}})) == "a[b]=c&a[d]=e&a[f]=g" + assert unquote(stringify({"a": {"b": {"c": {"d": "e"}}}})) == "a[b][c][d]=e" + assert unquote(stringify({"a": {"b": True}})) == "a[b]=true" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_array_comma(method: str) -> None: + if method == "class": + serialise = Querystring(array_format="comma").stringify + else: + serialise = partial(stringify, array_format="comma") + + assert unquote(serialise({"in": ["foo", "bar"]})) == "in=foo,bar" + assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b]=true,false" + assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b]=true,false,true" + + +def test_array_repeat() -> None: + assert unquote(stringify({"in": ["foo", "bar"]})) == "in=foo&in=bar" + assert unquote(stringify({"a": {"b": [True, False]}})) == "a[b]=true&a[b]=false" + assert unquote(stringify({"a": {"b": [True, False, None, True]}})) == "a[b]=true&a[b]=false&a[b]=true" + assert unquote(stringify({"in": ["foo", {"b": {"c": ["d", "e"]}}]})) == "in=foo&in[b][c]=d&in[b][c]=e" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_array_brackets(method: str) -> None: + if method == "class": + serialise = Querystring(array_format="brackets").stringify + else: + serialise = partial(stringify, array_format="brackets") + + assert unquote(serialise({"in": ["foo", "bar"]})) == "in[]=foo&in[]=bar" + assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b][]=true&a[b][]=false" + assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b][]=true&a[b][]=false&a[b][]=true" + + +def test_unknown_array_format() -> None: + with pytest.raises(NotImplementedError, match="Unknown array_format value: foo, choose from comma, repeat"): + stringify({"a": ["foo", "bar"]}, array_format=cast(Any, "foo")) diff --git a/google/genai/tests/interactions/test_required_args.py b/google/genai/tests/interactions/test_required_args.py new file mode 100644 index 000000000..fc03d4131 --- /dev/null +++ b/google/genai/tests/interactions/test_required_args.py @@ -0,0 +1,126 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import pytest + +from google.genai._interactions._utils import required_args + + +def test_too_many_positional_params() -> None: + @required_args(["a"]) + def foo(a: str | None = None) -> str | None: + return a + + with pytest.raises(TypeError, match=r"foo\(\) takes 1 argument\(s\) but 2 were given"): + foo("a", "b") # type: ignore + + +def test_positional_param() -> None: + @required_args(["a"]) + def foo(a: str | None = None) -> str | None: + return a + + assert foo("a") == "a" + assert foo(None) is None + assert foo(a="b") == "b" + + with pytest.raises(TypeError, match="Missing required argument: 'a'"): + foo() + + +def test_keyword_only_param() -> None: + @required_args(["a"]) + def foo(*, a: str | None = None) -> str | None: + return a + + assert foo(a="a") == "a" + assert foo(a=None) is None + assert foo(a="b") == "b" + + with pytest.raises(TypeError, match="Missing required argument: 'a'"): + foo() + + +def test_multiple_params() -> None: + @required_args(["a", "b", "c"]) + def foo(a: str = "", *, b: str = "", c: str = "") -> str | None: + return f"{a} {b} {c}" + + assert foo(a="a", b="b", c="c") == "a b c" + + error_message = r"Missing required arguments.*" + + with pytest.raises(TypeError, match=error_message): + foo() + + with pytest.raises(TypeError, match=error_message): + foo(a="a") + + with pytest.raises(TypeError, match=error_message): + foo(b="b") + + with pytest.raises(TypeError, match=error_message): + foo(c="c") + + with pytest.raises(TypeError, match=r"Missing required argument: 'a'"): + foo(b="a", c="c") + + with pytest.raises(TypeError, match=r"Missing required argument: 'b'"): + foo("a", c="c") + + +def test_multiple_variants() -> None: + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: str | None = None) -> str | None: + return a if a is not None else b + + assert foo(a="foo") == "foo" + assert foo(b="bar") == "bar" + assert foo(a=None) is None + assert foo(b=None) is None + + # TODO: this error message could probably be improved + with pytest.raises( + TypeError, + match=r"Missing required arguments; Expected either \('a'\) or \('b'\) arguments to be given", + ): + foo() + + +def test_multiple_params_multiple_variants() -> None: + @required_args(["a", "b"], ["c"]) + def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> str | None: + if a is not None: + return a + if b is not None: + return b + return c + + error_message = r"Missing required arguments; Expected either \('a' and 'b'\) or \('c'\) arguments to be given" + + with pytest.raises(TypeError, match=error_message): + foo(a="foo") + + with pytest.raises(TypeError, match=error_message): + foo(b="bar") + + with pytest.raises(TypeError, match=error_message): + foo() + + assert foo(a=None, b="bar") == "bar" + assert foo(c=None) is None + assert foo(c="foo") == "foo" diff --git a/google/genai/tests/interactions/test_response.py b/google/genai/tests/interactions/test_response.py new file mode 100644 index 000000000..27189e199 --- /dev/null +++ b/google/genai/tests/interactions/test_response.py @@ -0,0 +1,294 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from typing import Any, List, Union, cast +from typing_extensions import Annotated + +import httpx +import pytest +import pydantic + +from google.genai._interactions import BaseModel, GeminiNextGenAPIClient, AsyncGeminiNextGenAPIClient +from google.genai._interactions._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + BinaryAPIResponse, + AsyncBinaryAPIResponse, + extract_response_type, +) +from google.genai._interactions._streaming import Stream +from google.genai._interactions._base_client import FinalRequestOptions + + +class ConcreteBaseAPIResponse(APIResponse[bytes]): ... + + +class ConcreteAPIResponse(APIResponse[List[str]]): ... + + +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ... + + +def test_extract_response_type_direct_classes() -> None: + assert extract_response_type(BaseAPIResponse[str]) == str + assert extract_response_type(APIResponse[str]) == str + assert extract_response_type(AsyncAPIResponse[str]) == str + + +def test_extract_response_type_direct_class_missing_type_arg() -> None: + with pytest.raises( + RuntimeError, + match="Expected type to have a type argument at index 0 but it did not", + ): + extract_response_type(AsyncAPIResponse) + + +def test_extract_response_type_concrete_subclasses() -> None: + assert extract_response_type(ConcreteBaseAPIResponse) == bytes + assert extract_response_type(ConcreteAPIResponse) == List[str] + assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response + + +def test_extract_response_type_binary_response() -> None: + assert extract_response_type(BinaryAPIResponse) == bytes + assert extract_response_type(AsyncBinaryAPIResponse) == bytes + + +class PydanticModel(pydantic.BaseModel): ... + + +def test_response_parse_mismatched_basemodel(client: GeminiNextGenAPIClient) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from google.genai._interactions import BaseModel`", + ): + response.parse(to=PydanticModel) + + +@pytest.mark.asyncio +async def test_async_response_parse_mismatched_basemodel(async_client: AsyncGeminiNextGenAPIClient) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from google.genai._interactions import BaseModel`", + ): + await response.parse(to=PydanticModel) + + +def test_response_parse_custom_stream(client: GeminiNextGenAPIClient) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = response.parse(to=Stream[int]) + assert stream._cast_to == int + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_stream(async_client: AsyncGeminiNextGenAPIClient) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = await response.parse(to=Stream[int]) + assert stream._cast_to == int + + +class CustomModel(BaseModel): + foo: str + bar: int + + +def test_response_parse_custom_model(client: GeminiNextGenAPIClient) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_model(async_client: AsyncGeminiNextGenAPIClient) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +def test_response_parse_annotated_type(client: GeminiNextGenAPIClient) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +async def test_async_response_parse_annotated_type(async_client: AsyncGeminiNextGenAPIClient) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +def test_response_parse_bool(client: GeminiNextGenAPIClient, content: str, expected: bool) -> None: + response = APIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = response.parse(to=bool) + assert result is expected + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +async def test_async_response_parse_bool(client: AsyncGeminiNextGenAPIClient, content: str, expected: bool) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = await response.parse(to=bool) + assert result is expected + + +class OtherModel(BaseModel): + a: str + + +@pytest.mark.parametrize("client", [False], indirect=True) # loose validation +def test_response_parse_expect_model_union_non_json_content(client: GeminiNextGenAPIClient) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation +async def test_async_response_parse_expect_model_union_non_json_content( + async_client: AsyncGeminiNextGenAPIClient, +) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" diff --git a/google/genai/tests/interactions/test_streaming.py b/google/genai/tests/interactions/test_streaming.py new file mode 100644 index 000000000..add41f29a --- /dev/null +++ b/google/genai/tests/interactions/test_streaming.py @@ -0,0 +1,277 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import Iterator, AsyncIterator + +import httpx +import pytest + +from google.genai._interactions import GeminiNextGenAPIClient, AsyncGeminiNextGenAPIClient +from google.genai._interactions._streaming import Stream, AsyncStream, ServerSentEvent + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_basic(sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient) -> None: + def body() -> Iterator[bytes]: + yield b"event: completion\n" + yield b'data: {"foo":true}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_missing_event( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"foo":true}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_event_missing_data( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.data == "" + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + yield b"event: completion\n" + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.data == "" + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.data == "" + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events_with_data( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo":true}\n' + yield b"\n" + yield b"event: completion\n" + yield b'data: {"bar":false}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.json() == {"bar": False} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines_with_empty_line( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: \n" + yield b"data:\n" + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + assert sse.data == '{\n"foo":\n\n\ntrue}' + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_json_escaped_double_new_line( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo": "my long\\n\\ncontent"}' + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": "my long\n\ncontent"} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines( + sync: bool, client: GeminiNextGenAPIClient, async_client: AsyncGeminiNextGenAPIClient +) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_special_new_line_character( + sync: bool, + client: GeminiNextGenAPIClient, + async_client: AsyncGeminiNextGenAPIClient, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":" culpa"}\n' + yield b"\n" + yield b'data: {"content":" \xe2\x80\xa8"}\n' + yield b"\n" + yield b'data: {"content":"foo"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " culpa"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " 
"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "foo"} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multi_byte_character_multiple_chunks( + sync: bool, + client: GeminiNextGenAPIClient, + async_client: AsyncGeminiNextGenAPIClient, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":"' + # bytes taken from the string 'известни' and arbitrarily split + # so that some multi-byte characters span multiple chunks + yield b"\xd0" + yield b"\xb8\xd0\xb7\xd0" + yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" + yield b'"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "известни"} + + +async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: + for chunk in iter: + yield chunk + + +async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent: + if isinstance(iter, AsyncIterator): + return await iter.__anext__() + + return next(iter) + + +async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None: + with pytest.raises((StopAsyncIteration, RuntimeError)): + await iter_next(iter) + + +def make_event_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: GeminiNextGenAPIClient, + async_client: AsyncGeminiNextGenAPIClient, +) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: + if sync: + return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events() + + return AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) + )._iter_events() diff --git a/google/genai/tests/interactions/test_transform.py b/google/genai/tests/interactions/test_transform.py new file mode 100644 index 000000000..285584b10 --- /dev/null +++ b/google/genai/tests/interactions/test_transform.py @@ -0,0 +1,475 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import io +import pathlib +from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast +from datetime import date, datetime +from typing_extensions import Required, Annotated, TypedDict + +import pytest + +from google.genai._interactions._types import Base64FileInput, omit, not_given +from google.genai._interactions._utils import ( + PropertyInfo, + transform as _transform, + parse_datetime, + async_transform as _async_transform, +) +from google.genai._interactions._compat import PYDANTIC_V1 +from google.genai._interactions._models import BaseModel + +_T = TypeVar("_T") + +SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt") + + +async def transform( + data: _T, + expected_type: object, + use_async: bool, +) -> _T: + if use_async: + return await _async_transform(data, expected_type=expected_type) + + return _transform(data, expected_type=expected_type) + + +parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + + +class Foo1(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +@parametrize +@pytest.mark.asyncio +async def test_top_level_alias(use_async: bool) -> None: + assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"} + + +class Foo2(TypedDict): + bar: Bar2 + + +class Bar2(TypedDict): + this_thing: Annotated[int, PropertyInfo(alias="this__thing")] + baz: Annotated[Baz2, PropertyInfo(alias="Baz")] + + +class Baz2(TypedDict): + my_baz: Annotated[str, PropertyInfo(alias="myBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_recursive_typeddict(use_async: bool) -> None: + assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}} + assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}} + + +class Foo3(TypedDict): + things: List[Bar3] + + +class Bar3(TypedDict): + my_field: Annotated[str, PropertyInfo(alias="myField")] + + +@parametrize +@pytest.mark.asyncio +async def test_list_of_typeddict(use_async: bool) -> None: + result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async) + assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]} + + +class Foo4(TypedDict): + foo: Union[Bar4, Baz4] + + +class Bar4(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz4(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_union_of_typeddict(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}} + assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}} + assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == { + "foo": {"fooBaz": "baz", "fooBar": "bar"} + } + + +class Foo5(TypedDict): + foo: Annotated[Union[Bar4, List[Baz4]], PropertyInfo(alias="FOO")] + + +class Bar5(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz5(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_union_of_list(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async) == {"FOO": {"fooBar": "bar"}} + assert await transform( + { + "foo": [ + {"foo_baz": "baz"}, + {"foo_baz": "baz"}, + ] + }, + Foo5, + use_async, + ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]} + + +class Foo6(TypedDict): + bar: Annotated[str, PropertyInfo(alias="Bar")] + + +@parametrize +@pytest.mark.asyncio +async def test_includes_unknown_keys(use_async: bool) -> None: + assert await transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6, use_async) == { + "Bar": "bar", + "baz_": {"FOO": 1}, + } + + +class Foo7(TypedDict): + bar: Annotated[List[Bar7], PropertyInfo(alias="bAr")] + foo: Bar7 + + +class Bar7(TypedDict): + foo: str + + +@parametrize +@pytest.mark.asyncio +async def test_ignores_invalid_input(use_async: bool) -> None: + assert await transform({"bar": ""}, Foo7, use_async) == {"bAr": ""} + assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""} + + +class DatetimeDict(TypedDict, total=False): + foo: Annotated[datetime, PropertyInfo(format="iso8601")] + + bar: Annotated[Optional[datetime], PropertyInfo(format="iso8601")] + + required: Required[Annotated[Optional[datetime], PropertyInfo(format="iso8601")]] + + list_: Required[Annotated[Optional[List[datetime]], PropertyInfo(format="iso8601")]] + + union: Annotated[Union[int, datetime], PropertyInfo(format="iso8601")] + + +class DateDict(TypedDict, total=False): + foo: Annotated[date, PropertyInfo(format="iso8601")] + + +class DatetimeModel(BaseModel): + foo: datetime + + +class DateModel(BaseModel): + foo: Optional[date] + + +@parametrize +@pytest.mark.asyncio +async def test_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + tz = "+00:00" if PYDANTIC_V1 else "Z" + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] + + dt = dt.replace(tzinfo=None) + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + + assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore + assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == { + "foo": "2023-02-23" + } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_optional_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"bar": dt}, DatetimeDict, use_async) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + + assert await transform({"bar": None}, DatetimeDict, use_async) == {"bar": None} + + +@parametrize +@pytest.mark.asyncio +async def test_required_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"required": dt}, DatetimeDict, use_async) == { + "required": "2023-02-23T14:16:36.337692+00:00" + } # type: ignore[comparison-overlap] + + assert await transform({"required": None}, DatetimeDict, use_async) == {"required": None} + + +@parametrize +@pytest.mark.asyncio +async def test_union_datetime(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"union": dt}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] + "union": "2023-02-23T14:16:36.337692+00:00" + } + + assert await transform({"union": "foo"}, DatetimeDict, use_async) == {"union": "foo"} + + +@parametrize +@pytest.mark.asyncio +async def test_nested_list_iso6801_format(use_async: bool) -> None: + dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + dt2 = parse_datetime("2022-01-15T06:34:23Z") + assert await transform({"list_": [dt1, dt2]}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] + "list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"] + } + + +@parametrize +@pytest.mark.asyncio +async def test_datetime_custom_format(use_async: bool) -> None: + dt = parse_datetime("2022-01-15T06:34:23Z") + + result = await transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")], use_async) + assert result == "06" # type: ignore[comparison-overlap] + + +class DateDictWithRequiredAlias(TypedDict, total=False): + required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]] + + +@parametrize +@pytest.mark.asyncio +async def test_datetime_with_alias(use_async: bool) -> None: + assert await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async) == {"prop": None} # type: ignore[comparison-overlap] + assert await transform( + {"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias, use_async + ) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap] + + +class MyModel(BaseModel): + foo: str + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_model_to_dictionary(use_async: bool) -> None: + assert cast(Any, await transform(MyModel(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel.construct(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_empty_model(use_async: bool) -> None: + assert cast(Any, await transform(MyModel.construct(), Any, use_async)) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_unknown_field(use_async: bool) -> None: + assert cast(Any, await transform(MyModel.construct(my_untyped_field=True), Any, use_async)) == { + "my_untyped_field": True + } + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_types(use_async: bool) -> None: + model = MyModel.construct(foo=True) + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: + with pytest.warns(UserWarning): + params = await transform(model, Any, use_async) + assert cast(Any, params) == {"foo": True} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_object_type(use_async: bool) -> None: + model = MyModel.construct(foo=MyModel.construct(hello="world")) + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: + with pytest.warns(UserWarning): + params = await transform(model, Any, use_async) + assert cast(Any, params) == {"foo": {"hello": "world"}} + + +class ModelNestedObjects(BaseModel): + nested: MyModel + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_nested_objects(use_async: bool) -> None: + model = ModelNestedObjects.construct(nested={"foo": "stainless"}) + assert isinstance(model.nested, MyModel) + assert cast(Any, await transform(model, Any, use_async)) == {"nested": {"foo": "stainless"}} + + +class ModelWithDefaultField(BaseModel): + foo: str + with_none_default: Union[str, None] = None + with_str_default: str = "foo" + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_default_field(use_async: bool) -> None: + # should be excluded when defaults are used + model = ModelWithDefaultField.construct() + assert model.with_none_default is None + assert model.with_str_default == "foo" + assert cast(Any, await transform(model, Any, use_async)) == {} + + # should be included when the default value is explicitly given + model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") + assert model.with_none_default is None + assert model.with_str_default == "foo" + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": None, "with_str_default": "foo"} + + # should be included when a non-default value is explicitly given + model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") + assert model.with_none_default == "bar" + assert model.with_str_default == "baz" + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": "bar", "with_str_default": "baz"} + + +class TypedDictIterableUnion(TypedDict): + foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +class Bar8(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz8(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_iterable_of_dictionaries(use_async: bool) -> None: + assert await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "bar"}] + } + assert cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async)) == { + "FOO": [{"fooBaz": "bar"}] + } + + def my_iter() -> Iterable[Baz8]: + yield {"foo_baz": "hello"} + yield {"foo_baz": "world"} + + assert await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}] + } + + +@parametrize +@pytest.mark.asyncio +async def test_dictionary_items(use_async: bool) -> None: + class DictItems(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}} + + +class TypedDictIterableUnionStr(TypedDict): + foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +@parametrize +@pytest.mark.asyncio +async def test_iterable_union_str(use_async: bool) -> None: + assert await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async) == {"FOO": "bar"} + assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [ + {"fooBaz": "bar"} + ] + + +class TypedDictBase64Input(TypedDict): + foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")] + + +@parametrize +@pytest.mark.asyncio +async def test_base64_file_input(use_async: bool) -> None: + # strings are left as-is + assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"} + + # pathlib.Path is automatically converted to base64 + assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQo=" + } # type: ignore[comparison-overlap] + + # io instances are automatically converted to base64 + assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_transform_skipping(use_async: bool) -> None: + # lists of ints are left as-is + data = [1, 2, 3] + assert await transform(data, List[int], use_async) is data + + # iterables of ints are converted to a list + data = iter([1, 2, 3]) + assert await transform(data, Iterable[int], use_async) == [1, 2, 3] + + +@parametrize +@pytest.mark.asyncio +async def test_strips_notgiven(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {} diff --git a/google/genai/tests/interactions/test_utils/test_datetime_parse.py b/google/genai/tests/interactions/test_utils/test_datetime_parse.py new file mode 100644 index 000000000..81fed9502 --- /dev/null +++ b/google/genai/tests/interactions/test_utils/test_datetime_parse.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py +with modifications so it works without pydantic v1 imports. +""" + +from typing import Type, Union +from datetime import date, datetime, timezone, timedelta + +import pytest + +from google.genai._interactions._utils import parse_date, parse_datetime + + +def create_tz(minutes: int) -> timezone: + return timezone(timedelta(minutes=minutes)) + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + ("1494012444.883309", date(2017, 5, 5)), + (b"1494012444.883309", date(2017, 5, 5)), + (1_494_012_444.883_309, date(2017, 5, 5)), + ("1494012444", date(2017, 5, 5)), + (1_494_012_444, date(2017, 5, 5)), + (0, date(1970, 1, 1)), + ("2012-04-23", date(2012, 4, 23)), + (b"2012-04-23", date(2012, 4, 23)), + ("2012-4-9", date(2012, 4, 9)), + (date(2012, 4, 9), date(2012, 4, 9)), + (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)), + # Invalid inputs + ("x20120423", ValueError), + ("2012-04-56", ValueError), + (19_999_999_999, date(2603, 10, 11)), # just before watershed + (20_000_000_001, date(1970, 8, 20)), # just after watershed + (1_549_316_052, date(2019, 2, 4)), # nowish in s + (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms + (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs + (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns + ("infinity", date(9999, 12, 31)), + ("inf", date(9999, 12, 31)), + (float("inf"), date(9999, 12, 31)), + ("infinity ", date(9999, 12, 31)), + (int("1" + "0" * 100), date(9999, 12, 31)), + (1e1000, date(9999, 12, 31)), + ("-infinity", date(1, 1, 1)), + ("-inf", date(1, 1, 1)), + ("nan", ValueError), + ], +) +def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_date(value) + else: + assert parse_date(value) == result + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + # values in seconds + ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + # values in ms + ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)), + ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)), + (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)), + ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)), + ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)), + ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))), + ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))), + ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))), + ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (datetime(2017, 5, 5), datetime(2017, 5, 5)), + (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), + # Invalid inputs + ("x20120423091500", ValueError), + ("2012-04-56T09:15:90", ValueError), + ("2012-04-23T11:05:00-25:00", ValueError), + (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed + (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed + (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s + (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms + (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in μs + (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns + ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)), + (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)), + (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("-infinity", datetime(1, 1, 1, 0, 0)), + ("-inf", datetime(1, 1, 1, 0, 0)), + ("nan", ValueError), + ], +) +def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_datetime(value) + else: + assert parse_datetime(value) == result diff --git a/google/genai/tests/interactions/test_utils/test_json.py b/google/genai/tests/interactions/test_utils/test_json.py new file mode 100644 index 000000000..adf180c2b --- /dev/null +++ b/google/genai/tests/interactions/test_utils/test_json.py @@ -0,0 +1,141 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import datetime +from typing import Union + +import pydantic + +from google.genai._interactions import _compat +from google.genai._interactions._utils._json import openapi_dumps + + +class TestOpenapiDumps: + def test_basic(self) -> None: + data = {"key": "value", "number": 42} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"key":"value","number":42}' + + def test_datetime_serialization(self) -> None: + dt = datetime.datetime(2023, 1, 1, 12, 0, 0) + data = {"datetime": dt} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"datetime":"2023-01-01T12:00:00"}' + + def test_pydantic_model_serialization(self) -> None: + class User(pydantic.BaseModel): + first_name: str + last_name: str + age: int + + model_instance = User(first_name="John", last_name="Kramer", age=83) + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"first_name":"John","last_name":"Kramer","age":83}}' + + def test_pydantic_model_with_default_values(self) -> None: + class User(pydantic.BaseModel): + name: str + role: str = "user" + active: bool = True + score: int = 0 + + model_instance = User(name="Alice") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Alice"}}' + + def test_pydantic_model_with_default_values_overridden(self) -> None: + class User(pydantic.BaseModel): + name: str + role: str = "user" + active: bool = True + + model_instance = User(name="Bob", role="admin", active=False) + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Bob","role":"admin","active":false}}' + + def test_pydantic_model_with_alias(self) -> None: + class User(pydantic.BaseModel): + first_name: str = pydantic.Field(alias="firstName") + last_name: str = pydantic.Field(alias="lastName") + + model_instance = User(firstName="John", lastName="Doe") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"firstName":"John","lastName":"Doe"}}' + + def test_pydantic_model_with_alias_and_default(self) -> None: + class User(pydantic.BaseModel): + user_name: str = pydantic.Field(alias="userName") + user_role: str = pydantic.Field(default="member", alias="userRole") + is_active: bool = pydantic.Field(default=True, alias="isActive") + + model_instance = User(userName="charlie") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"userName":"charlie"}}' + + model_with_overrides = User(userName="diana", userRole="admin", isActive=False) + data = {"model": model_with_overrides} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"userName":"diana","userRole":"admin","isActive":false}}' + + def test_pydantic_model_with_nested_models_and_defaults(self) -> None: + class Address(pydantic.BaseModel): + street: str + city: str = "Unknown" + + class User(pydantic.BaseModel): + name: str + address: Address + verified: bool = False + + if _compat.PYDANTIC_V1: + # to handle forward references in Pydantic v1 + User.update_forward_refs(**locals()) # type: ignore[reportDeprecated] + + address = Address(street="123 Main St") + user = User(name="Diana", address=address) + data = {"user": user} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"user":{"name":"Diana","address":{"street":"123 Main St"}}}' + + address_with_city = Address(street="456 Oak Ave", city="Boston") + user_verified = User(name="Eve", address=address_with_city, verified=True) + data = {"user": user_verified} + json_bytes = openapi_dumps(data) + assert ( + json_bytes == b'{"user":{"name":"Eve","address":{"street":"456 Oak Ave","city":"Boston"},"verified":true}}' + ) + + def test_pydantic_model_with_optional_fields(self) -> None: + class User(pydantic.BaseModel): + name: str + email: Union[str, None] + phone: Union[str, None] + + model_with_none = User(name="Eve", email=None, phone=None) + data = {"model": model_with_none} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Eve","email":null,"phone":null}}' + + model_with_values = User(name="Frank", email="frank@example.com", phone=None) + data = {"model": model_with_values} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Frank","email":"frank@example.com","phone":null}}' diff --git a/google/genai/tests/interactions/test_utils/test_path.py b/google/genai/tests/interactions/test_utils/test_path.py new file mode 100644 index 000000000..eb273f8bd --- /dev/null +++ b/google/genai/tests/interactions/test_utils/test_path.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import Any + +import pytest + +from google.genai._interactions._utils._path import path_template + + +@pytest.mark.parametrize( + "template, kwargs, expected", + [ + ("/v1/{id}", dict(id="abc"), "/v1/abc"), + ("/v1/{a}/{b}", dict(a="x", b="y"), "/v1/x/y"), + ("/v1/{a}{b}/path/{c}?val={d}#{e}", dict(a="x", b="y", c="z", d="u", e="v"), "/v1/xy/path/z?val=u#v"), + ("/{w}/{w}", dict(w="echo"), "/echo/echo"), + ("/v1/static", {}, "/v1/static"), + ("", {}, ""), + ("/v1/?q={n}&count=10", dict(n=42), "/v1/?q=42&count=10"), + ("/v1/{v}", dict(v=None), "/v1/null"), + ("/v1/{v}", dict(v=True), "/v1/true"), + ("/v1/{v}", dict(v=False), "/v1/false"), + ("/v1/{v}", dict(v=".hidden"), "/v1/.hidden"), # dot prefix ok + ("/v1/{v}", dict(v="file.txt"), "/v1/file.txt"), # dot in middle ok + ("/v1/{v}", dict(v="..."), "/v1/..."), # triple dot ok + ("/v1/{a}{b}", dict(a=".", b="txt"), "/v1/.txt"), # dot var combining with adjacent to be ok + ("/items?q={v}#{f}", dict(v=".", f=".."), "/items?q=.#.."), # dots in query/fragment are fine + ( + "/v1/{a}?query={b}", + dict(a="../../other/endpoint", b="a&bad=true"), + "/v1/..%2F..%2Fother%2Fendpoint?query=a%26bad%3Dtrue", + ), + ("/v1/{val}", dict(val="a/b/c"), "/v1/a%2Fb%2Fc"), + ("/v1/{val}", dict(val="a/b/c?query=value"), "/v1/a%2Fb%2Fc%3Fquery=value"), + ("/v1/{val}", dict(val="a/b/c?query=value&bad=true"), "/v1/a%2Fb%2Fc%3Fquery=value&bad=true"), + ("/v1/{val}", dict(val="%20"), "/v1/%2520"), # escapes escape sequences in input + # Query: slash and ? are safe, # is not + ("/items?q={v}", dict(v="a/b"), "/items?q=a/b"), + ("/items?q={v}", dict(v="a?b"), "/items?q=a?b"), + ("/items?q={v}", dict(v="a#b"), "/items?q=a%23b"), + ("/items?q={v}", dict(v="a b"), "/items?q=a%20b"), + # Fragment: slash and ? are safe + ("/docs#{v}", dict(v="a/b"), "/docs#a/b"), + ("/docs#{v}", dict(v="a?b"), "/docs#a?b"), + # Path: slash, ? and # are all encoded + ("/v1/{v}", dict(v="a/b"), "/v1/a%2Fb"), + ("/v1/{v}", dict(v="a?b"), "/v1/a%3Fb"), + ("/v1/{v}", dict(v="a#b"), "/v1/a%23b"), + # same var encoded differently by component + ( + "/v1/{v}?q={v}#{v}", + dict(v="a/b?c#d"), + "/v1/a%2Fb%3Fc%23d?q=a/b?c%23d#a/b?c%23d", + ), + ("/v1/{val}", dict(val="x?admin=true"), "/v1/x%3Fadmin=true"), # query injection + ("/v1/{val}", dict(val="x#admin"), "/v1/x%23admin"), # fragment injection + ], +) +def test_interpolation(template: str, kwargs: dict[str, Any], expected: str) -> None: + assert path_template(template, **kwargs) == expected + + +def test_missing_kwarg_raises_key_error() -> None: + with pytest.raises(KeyError, match="org_id"): + path_template("/v1/{org_id}") + + +@pytest.mark.parametrize( + "template, kwargs", + [ + ("{a}/path", dict(a=".")), + ("{a}/path", dict(a="..")), + ("/v1/{a}", dict(a=".")), + ("/v1/{a}", dict(a="..")), + ("/v1/{a}/path", dict(a=".")), + ("/v1/{a}/path", dict(a="..")), + ("/v1/{a}{b}", dict(a=".", b=".")), # adjacent vars → ".." + ("/v1/{a}.", dict(a=".")), # var + static → ".." + ("/v1/{a}{b}", dict(a="", b=".")), # empty + dot → "." + ("/v1/%2e/{x}", dict(x="ok")), # encoded dot in static text + ("/v1/%2e./{x}", dict(x="ok")), # mixed encoded ".." in static + ("/v1/.%2E/{x}", dict(x="ok")), # mixed encoded ".." in static + ("/v1/{v}?q=1", dict(v="..")), + ("/v1/{v}#frag", dict(v="..")), + ], +) +def test_dot_segment_rejected(template: str, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValueError, match="dot-segment"): + path_template(template, **kwargs) diff --git a/google/genai/tests/interactions/test_utils/test_proxy.py b/google/genai/tests/interactions/test_utils/test_proxy.py new file mode 100644 index 000000000..9e7871ed4 --- /dev/null +++ b/google/genai/tests/interactions/test_utils/test_proxy.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import operator +from typing import Any +from typing_extensions import override + +from google.genai._interactions._utils import LazyProxy + + +class RecursiveLazyProxy(LazyProxy[Any]): + @override + def __load__(self) -> Any: + return self + + def __call__(self, *_args: Any, **_kwds: Any) -> Any: + raise RuntimeError("This should never be called!") + + +def test_recursive_proxy() -> None: + proxy = RecursiveLazyProxy() + assert repr(proxy) == "RecursiveLazyProxy" + assert str(proxy) == "RecursiveLazyProxy" + assert dir(proxy) == [] + assert type(proxy).__name__ == "RecursiveLazyProxy" + assert type(operator.attrgetter("name.foo.bar.baz")(proxy)).__name__ == "RecursiveLazyProxy" + + +def test_isinstance_does_not_error() -> None: + class AlwaysErrorProxy(LazyProxy[Any]): + @override + def __load__(self) -> Any: + raise RuntimeError("Mocking missing dependency") + + proxy = AlwaysErrorProxy() + assert not isinstance(proxy, dict) + assert isinstance(proxy, LazyProxy) diff --git a/google/genai/tests/interactions/test_utils/test_typing.py b/google/genai/tests/interactions/test_utils/test_typing.py new file mode 100644 index 000000000..291132fb3 --- /dev/null +++ b/google/genai/tests/interactions/test_utils/test_typing.py @@ -0,0 +1,88 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import Generic, TypeVar, cast + +from google.genai._interactions._utils import extract_type_var_from_base + +_T = TypeVar("_T") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") + + +class BaseGeneric(Generic[_T]): ... + + +class SubclassGeneric(BaseGeneric[_T]): ... + + +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): ... + + +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): ... + + +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): ... + + +def test_extract_type_var() -> None: + assert ( + extract_type_var_from_base( + BaseGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_generic_subclass() -> None: + assert ( + extract_type_var_from_base( + SubclassGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_multiple() -> None: + typ = BaseGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_multiple() -> None: + typ = SubclassGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None: + typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) diff --git a/google/genai/tests/interactions/utils.py b/google/genai/tests/interactions/utils.py new file mode 100644 index 000000000..4ae96e69b --- /dev/null +++ b/google/genai/tests/interactions/utils.py @@ -0,0 +1,182 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import os +import inspect +import traceback +import contextlib +from typing import Any, TypeVar, Iterator, Sequence, cast +from datetime import date, datetime +from typing_extensions import Literal, get_args, get_origin, assert_type + +from google.genai._interactions._types import Omit, NoneType +from google.genai._interactions._utils import ( + is_dict, + is_list, + is_list_type, + is_union_type, + extract_type_arg, + is_sequence_type, + is_annotated_type, + is_type_alias_type, +) +from google.genai._interactions._compat import PYDANTIC_V1, field_outer_type, get_model_fields +from google.genai._interactions._models import BaseModel + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: + for name, field in get_model_fields(model).items(): + field_value = getattr(value, name) + if PYDANTIC_V1: + # in v1 nullability was structured differently + # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields + allow_none = getattr(field, "allow_none", False) + else: + allow_none = False + + assert_matches_type( + field_outer_type(field), + field_value, + path=[*path, name], + allow_none=allow_none, + ) + + return True + + +# Note: the `path` argument is only used to improve error messages when `--showlocals` is used +def assert_matches_type( + type_: Any, + value: object, + *, + path: list[str], + allow_none: bool = False, +) -> None: + if is_type_alias_type(type_): + type_ = type_.__value__ + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + type_ = extract_type_arg(type_, 0) + + if allow_none and value is None: + return + + if type_ is None or type_ is NoneType: + assert value is None + return + + origin = get_origin(type_) or type_ + + if is_list_type(type_): + return _assert_list_type(type_, value) + + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + + if origin == str: + assert isinstance(value, str) + elif origin == int: + assert isinstance(value, int) + elif origin == bool: + assert isinstance(value, bool) + elif origin == float: + assert isinstance(value, float) + elif origin == bytes: + assert isinstance(value, bytes) + elif origin == datetime: + assert isinstance(value, datetime) + elif origin == date: + assert isinstance(value, date) + elif origin == object: + # nothing to do here, the expected type is unknown + pass + elif origin == Literal: + assert value in get_args(type_) + elif origin == dict: + assert is_dict(value) + + args = get_args(type_) + key_type = args[0] + items_type = args[1] + + for key, item in value.items(): + assert_matches_type(key_type, key, path=[*path, ""]) + assert_matches_type(items_type, item, path=[*path, ""]) + elif is_union_type(type_): + variants = get_args(type_) + + try: + none_index = variants.index(type(None)) + except ValueError: + pass + else: + # special case Optional[T] for better error messages + if len(variants) == 2: + if value is None: + # valid + return + + return assert_matches_type(type_=variants[not none_index], value=value, path=path) + + for i, variant in enumerate(variants): + try: + assert_matches_type(variant, value, path=[*path, f"variant {i}"]) + return + except AssertionError: + traceback.print_exc() + continue + + raise AssertionError("Did not match any variants") + elif issubclass(origin, BaseModel): + assert isinstance(value, type_) + assert assert_matches_model(type_, cast(Any, value), path=path) + elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent": + assert value.__class__.__name__ == "HttpxBinaryResponseContent" + else: + assert None, f"Unhandled field type: {type_}" + + +def _assert_list_type(type_: type[object], value: object) -> None: + assert is_list(value) + + inner_type = get_args(type_)[0] + for entry in value: + assert_type(inner_type, entry) # type: ignore + + +@contextlib.contextmanager +def update_env(**new_env: str | Omit) -> Iterator[None]: + old = os.environ.copy() + + try: + for name, value in new_env.items(): + if isinstance(value, Omit): + os.environ.pop(name, None) + else: + os.environ[name] = value + + yield None + finally: + os.environ.clear() + os.environ.update(old) diff --git a/google/genai/tests/interactions_manual/__init__.py b/google/genai/tests/interactions_manual/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/google/genai/tests/interactions_manual/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/google/genai/tests/interactions/test_auth.py b/google/genai/tests/interactions_manual/test_auth.py similarity index 100% rename from google/genai/tests/interactions/test_auth.py rename to google/genai/tests/interactions_manual/test_auth.py diff --git a/google/genai/tests/interactions/test_integration.py b/google/genai/tests/interactions_manual/test_integration.py similarity index 100% rename from google/genai/tests/interactions/test_integration.py rename to google/genai/tests/interactions_manual/test_integration.py diff --git a/google/genai/tests/interactions/test_paths.py b/google/genai/tests/interactions_manual/test_paths.py similarity index 100% rename from google/genai/tests/interactions/test_paths.py rename to google/genai/tests/interactions_manual/test_paths.py