diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 2dfc0366b..9f33de151 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -702,7 +702,21 @@ def __init__( ) self._http_options.api_version = 'v1beta1' else: # Implicit initialization or missing arguments. - if not self.api_key: + if env_api_key and api_key: + # Explicit api_key takes precedence over implicit api_key. + logger.info( + 'The client initializer api_key argument takes ' + 'precedence over the API key from the environment variable.' + ) + if credentials: + if env_api_key: + logger.info( + 'The user `credentials` argument will take precedence over the' + ' api key from the environment variables.' + ) + self.api_key = None + + if not self.api_key and not credentials: raise ValueError( 'No API key was provided. Please pass a valid API key. Learn how to' ' create an API key at' @@ -1175,20 +1189,21 @@ def _request_once( stream: bool = False, ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = f'Bearer {self._access_token()}' if self._credentials and self._credentials.quota_project_id: http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) + else: + data = http_request.data if stream: httpx_request = self._httpx_client.build_request( @@ -1241,8 +1256,9 @@ async def _async_request_once( ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = ( f'Bearer {await self._async_access_token()}' ) @@ -1250,13 +1266,12 @@ async def _async_request_once( http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) + else: + data = http_request.data if stream: if self._use_aiohttp(): diff --git a/google/genai/_extra_utils.py b/google/genai/_extra_utils.py index 129c05f7d..19e1bee31 100644 --- a/google/genai/_extra_utils.py +++ b/google/genai/_extra_utils.py @@ -16,16 +16,20 @@ """Extra utils depending on types that are shared between sync and async modules.""" import asyncio +from collections.abc import Callable, MutableMapping import inspect import io import logging import sys import typing -from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin +from typing import Any, Optional, Union, get_args, get_origin import mimetypes import os import pydantic +import google.auth.transport.requests + + from . import _common from . import _mcp_utils from . import _transformers as t @@ -677,3 +681,18 @@ def prepare_resumable_upload( http_options.headers = {} http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file) return http_options, size_bytes, mime_type + + +async def _maybe_update_and_insert_auth_token( + headers:MutableMapping[str, str], + creds: google.auth.credentials.Credentials) -> None: + # Refresh credentials to ensure token is valid + if not (creds.token and creds.valid): + try: + auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call] + await asyncio.to_thread(creds.refresh, auth_req) + except Exception as e: + raise ConnectionError(f"Failed to refresh credentials") from e + + if not headers.get('Authorization'): + headers['Authorization'] = f'Bearer {creds.token}' diff --git a/google/genai/_interactions/_client.py b/google/genai/_interactions/_client.py index 64bc99587..d751df1be 100644 --- a/google/genai/_interactions/_client.py +++ b/google/genai/_interactions/_client.py @@ -178,7 +178,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if not self.client_adapter or not self.client_adapter.is_vertex_ai(): + if not self.client_adapter: return options headers = options.headers or {} @@ -400,7 +400,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if not self.client_adapter or not self.client_adapter.is_vertex_ai(): + if not self.client_adapter: return options headers = options.headers or {} diff --git a/google/genai/client.py b/google/genai/client.py index cec5cef4c..2ebf26e47 100644 --- a/google/genai/client.py +++ b/google/genai/client.py @@ -174,14 +174,12 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: # uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds. timeout=http_opts.timeout / 1000 if http_opts.timeout else None, max_retries=max_retries, - client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client) + client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client), ) - client = self._nextgen_client_instance - if self._api_client.vertexai: - client._is_vertex = True - client._vertex_project = self._api_client.project - client._vertex_location = self._api_client.location + self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False + self._nextgen_client_instance._vertex_project = self._api_client.project + self._nextgen_client_instance._vertex_location = self._api_client.location return self._nextgen_client_instance @@ -525,11 +523,9 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient: client_adapter=GeminiNextGenAPIClientAdapter(self._api_client), ) - client = self._nextgen_client_instance - if self._api_client.vertexai: - client._is_vertex = True - client._vertex_project = self._api_client.project - client._vertex_location = self._api_client.location + self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False + self._nextgen_client_instance._vertex_project = self._api_client.project + self._nextgen_client_instance._vertex_location = self._api_client.location return self._nextgen_client_instance diff --git a/google/genai/live.py b/google/genai/live.py index 93953a02f..08a2f3d78 100644 --- a/google/genai/live.py +++ b/google/genai/live.py @@ -29,6 +29,7 @@ import websockets from . import _api_module +from . import _extra_utils from . import _common from . import _live_converters as live_converters from . import _mcp_utils @@ -946,17 +947,103 @@ async def connect( base_url = self._api_client._websocket_base_url() if isinstance(base_url, bytes): base_url = base_url.decode('utf-8') - transformed_model = t.t_model(self._api_client, model) # type: ignore parameter_model = await _t_live_connect_config(self._api_client, config) - if self._api_client.api_key and not self._api_client.vertexai: - version = self._api_client._http_options.api_version - api_key = self._api_client.api_key - method = 'BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} + if self._api_client.vertexai: + uri, headers, request = await self._prepare_connection_vertex( + base_url=base_url, model=model, parameter_model=parameter_model + ) + else: + uri, headers, request = await self._prepare_connection_mldev( + base_url=base_url, model=model, parameter_model=parameter_model + ) + + if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( + parameter_model.tools + ): + if headers is None: + headers = {} + _mcp_utils.set_mcp_usage_header(headers) + + async with ws_connect( + uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx + ) as ws: + await ws.send(request) + try: + # websockets 14.0+ + raw_response = await ws.recv(decode=False) + except TypeError: + raw_response = await ws.recv() # type: ignore[assignment] + except ConnectionClosed as e: + if e.rcvd: + code = e.rcvd.code + reason = e.rcvd.reason + else: + code = 1006 + reason = 'Abnormal closure.' + errors.APIError.raise_error(code, reason, None) + if raw_response: + try: + response = json.loads(raw_response) + except json.decoder.JSONDecodeError as e: + raise ValueError(f'Failed to parse response: {raw_response!r}') from e + else: + response = {} + + if self._api_client.vertexai: + response_dict = live_converters._LiveServerMessage_from_vertex(response) + else: + response_dict = response + + setup_response = types.LiveServerMessage._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + if setup_response.setup_complete: + session_id = setup_response.setup_complete.session_id + else: + session_id = None + yield AsyncSession( + api_client=self._api_client, + websocket=ws, + session_id=session_id, + ) + + async def _prepare_connection_mldev( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the MLDev API. + + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the MLDev backend. + + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If an API key is not provided. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + method = 'BidiGenerateContent' + original_headers = self._api_client._http_options.headers + headers = original_headers.copy() if original_headers is not None else {} + + if api_key := self._api_client.api_key: if api_key.startswith('auth_tokens/'): + method = 'BidiGenerateContentConstrained' + headers['Authorization'] = f'Token {api_key}' warnings.warn( message=( "The SDK's ephemeral token support is experimental, and may" @@ -964,8 +1051,6 @@ async def connect( ), category=errors.ExperimentalWarning, ) - method = 'BidiGenerateContentConstrained' - headers['Authorization'] = f'Token {api_key}' if version != 'v1alpha': warnings.warn( message=( @@ -976,47 +1061,68 @@ async def connect( ), category=errors.ExperimentalWarning, ) - uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_mldev( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) + elif creds := self._api_client._credentials: + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) + else: + # this shouldn't happen. + raise ValueError('Genai live connection requires credentials or API key provided.') + + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' + + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_mldev( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] + request_dict = _common.encode_unserializable_types(request_dict) - del request_dict['config'] - request_dict = _common.encode_unserializable_types(request_dict) - setv(request_dict, ['setup', 'model'], transformed_model) + setv(request_dict, ['setup', 'model'], transformed_model) - request = json.dumps(request_dict) - elif self._api_client.api_key and self._api_client.vertexai: - # Headers already contains api key for express mode. - api_key = self._api_client.api_key - version = self._api_client._http_options.api_version - uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] - request_dict = _common.encode_unserializable_types(request_dict) - setv(request_dict, ['setup', 'model'], transformed_model) + return uri, headers, json.dumps(request_dict) - request = json.dumps(request_dict) + + async def _prepare_connection_vertex( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the Vertex AI API. + + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the Vertex AI backend. Handles + authentication using either an API key or default credentials. + + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If project and location are not provided when + default credentials are used. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + original_headers = self._api_client._http_options.headers + headers = ( + original_headers.copy() if original_headers is not None else {} + ) + if api_key := self._api_client.api_key: + # Headers already contains api key + uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' else: - version = self._api_client._http_options.api_version has_sufficient_auth = ( self._api_client.project and self._api_client.location ) @@ -1044,19 +1150,8 @@ async def connect( creds = self._api_client._credentials # creds.valid is False, and creds.token is None # Need to refresh credentials to populate those - if not (creds.token and creds.valid): - if requests is None: - raise ValueError('The requests module is required to refresh google-auth credentials. Please install with `pip install google-auth[requests]`') - auth_req = requests.Request() # type: ignore - creds.refresh(auth_req) # type: ignore[no-untyped-call] - bearer_token = creds.token + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) - original_headers = self._api_client._http_options.headers - headers = ( - original_headers.copy() if original_headers is not None else {} - ) - if not headers.get('Authorization'): - headers['Authorization'] = f'Bearer {bearer_token}' location = self._api_client.location project = self._api_client.project @@ -1064,17 +1159,23 @@ async def connect( transformed_model = ( f'projects/{project}/locations/{location}/' + transformed_model ) - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] - request_dict = _common.encode_unserializable_types(request_dict) + + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_vertex( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] + request_dict = _common.encode_unserializable_types(request_dict) + + if api_key is None: + # Refactor note: I'm surprised the two paths are different, you'd have + # to test every model to be sure. The goal of this refactor is to not + # change any behavior so leaving it as is. if ( getv( request_dict, ['setup', 'generationConfig', 'responseModalities'] @@ -1087,57 +1188,7 @@ async def connect( ['AUDIO'], ) - request = json.dumps(request_dict) - - if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( - parameter_model.tools - ): - if headers is None: - headers = {} - _mcp_utils.set_mcp_usage_header(headers) - - async with ws_connect( - uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx - ) as ws: - await ws.send(request) - try: - # websockets 14.0+ - raw_response = await ws.recv(decode=False) - except TypeError: - raw_response = await ws.recv() # type: ignore[assignment] - except ConnectionClosed as e: - if e.rcvd: - code = e.rcvd.code - reason = e.rcvd.reason - else: - code = 1006 - reason = 'Abnormal closure.' - errors.APIError.raise_error(code, reason, None) - if raw_response: - try: - response = json.loads(raw_response) - except json.decoder.JSONDecodeError: - raise ValueError(f'Failed to parse response: {raw_response!r}') - else: - response = {} - - if self._api_client.vertexai: - response_dict = live_converters._LiveServerMessage_from_vertex(response) - else: - response_dict = response - - setup_response = types.LiveServerMessage._from_response( - response=response_dict, kwargs=parameter_model.model_dump() - ) - if setup_response.setup_complete: - session_id = setup_response.setup_complete.session_id - else: - session_id = None - yield AsyncSession( - api_client=self._api_client, - websocket=ws, - session_id=session_id, - ) + return uri, headers, json.dumps(request_dict) async def _t_live_connect_config( diff --git a/google/genai/live_music.py b/google/genai/live_music.py index 8730e08f9..cdc2111c8 100644 --- a/google/genai/live_music.py +++ b/google/genai/live_music.py @@ -22,6 +22,7 @@ import websockets from . import _api_module +from . import _extra_utils from . import _common from . import _live_converters as live_converters from . import _transformers as t @@ -169,31 +170,42 @@ class AsyncLiveMusic(_api_module.BaseModule): @contextlib.asynccontextmanager async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]: """[Experimental] Connect to the live music server.""" + if self._api_client.vertexai: + raise NotImplementedError('Live music generation is not supported in Vertex AI.') + base_url = self._api_client._websocket_base_url() if isinstance(base_url, bytes): base_url = base_url.decode('utf-8') transformed_model = t.t_model(self._api_client, model) + version = self._api_client._http_options.api_version + original_headers = self._api_client._http_options.headers + headers = original_headers.copy() if original_headers is not None else {} + if self._api_client.api_key: - api_key = self._api_client.api_key - version = self._api_client._http_options.api_version - uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}' - headers = self._api_client._http_options.headers - - # Only mldev supported - request_dict = _common.convert_to_dict( - live_converters._LiveMusicConnectParameters_to_mldev( - from_object=types.LiveMusicConnectParameters( - model=transformed_model, - ).model_dump(exclude_none=True) - ) - ) + # API key is already included in headers. + pass + elif creds := self._api_client._credentials: + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) + else: + # This shouldn't happen. + raise ValueError('Genai live music connection requires credentials or API key provided.') + + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic' + + # Only mldev supported + request_dict = _common.convert_to_dict( + live_converters._LiveMusicConnectParameters_to_mldev( + from_object=types.LiveMusicConnectParameters( + model=transformed_model, + ).model_dump(exclude_none=True) + ) + ) - setv(request_dict, ['setup', 'model'], transformed_model) + setv(request_dict, ['setup', 'model'], transformed_model) + + request = json.dumps(request_dict) - request = json.dumps(request_dict) - else: - raise NotImplementedError('Live music generation is not supported in Vertex AI.') try: async with connect(uri, additional_headers=headers) as ws: diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index 7b0136044..9daf50d38 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -47,6 +47,28 @@ ) +class FakeCredentials(credentials.Credentials): + def __init__(self, token="fake_token", expired=False, quota_project_id=None): + super().__init__() + self.token = token + self._expired = expired + self._quota_project_id = quota_project_id + self.refresh_count = 0 + + @property + def expired(self): + return self._expired + + @property + def quota_project_id(self): + return self._quota_project_id + + def refresh(self, request): + self.refresh_count += 1 + self.token = "refreshed_token" + self._expired = False + + @pytest.fixture(autouse=True) def reset_has_aiohttp(): yield @@ -1721,3 +1743,142 @@ async def test_get_aiohttp_session(): assert initial_session is not None session = await client._api_client._get_aiohttp_session() assert session is initial_session + + +def test_missing_api_key_and_credentials(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "") + with pytest.raises(ValueError, match="No API key was provided"): + Client() + + +auth_precedence_test_cases = [ + # client_args, env_vars, expected_headers + ( + {"credentials": FakeCredentials()}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"Authorization": "Bearer fake_token"} + ), + ( + {"credentials": FakeCredentials(quota_project_id="quota-proj")}, + {"GOOGLE_API_KEY": "env_api_key"}, + { + "Authorization": "Bearer fake_token", + "x-goog-user-project": "quota-proj" + } + ), + ( + {"api_key": "test_api_key"}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "test_api_key"} + ), + ( + {}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "env_api_key"} + ), +] + + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + + +@pytest.mark.asyncio +async def test_both_credentials_mldev(): + with pytest.raises(ValueError, match="mutually exclusive"): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds, api_key="test-api-key") + + +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1 + + +@requires_aiohttp +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1 diff --git a/google/genai/tests/interactions/test_auth.py b/google/genai/tests/interactions/test_auth.py index 5c92a225e..ff5f5ce4d 100644 --- a/google/genai/tests/interactions/test_auth.py +++ b/google/genai/tests/interactions/test_auth.py @@ -474,3 +474,64 @@ async def test_async_interactions_vertex_extra_headers_override(): assert headers['x-goog-api-key'] == 'manual-key' assert 'authorization' not in headers mock_access_token.assert_not_called() + +def test_interactions_mldev_auth_header(): + from ..._api_client import BaseApiClient + from httpx import Client as HTTPClient + + creds = mock.Mock() + creds.quota_project_id = "test-quota-project" + client = Client(vertexai=False, credentials=creds) + + with ( + mock.patch.object( + BaseApiClient, "_access_token", return_value='fake-mldev-token' + ) as mock_access_token, + mock.patch.object( + HTTPClient, "send", + return_value=mock.Mock(), + ) as mock_send, + ): + + client.interactions.create( + model='gemini-3-flash-preview', + input='Hello', + ) + + mock_send.assert_called_once() + mock_access_token.assert_called_once() + args, kwargs = mock_send.call_args + headers = args[0].headers + assert headers['authorization'] == 'Bearer fake-mldev-token' + assert headers['x-goog-user-project'] == 'test-quota-project' + +@pytest.mark.asyncio +async def test_async_interactions_mldev_auth_header(): + from ..._api_client import BaseApiClient + from ..._api_client import AsyncHttpxClient + + creds = mock.Mock() + creds.quota_project_id = "test-quota-project" + client = Client(vertexai=False, credentials=creds) + + with ( + mock.patch.object( + BaseApiClient, "_async_access_token", return_value='fake-mldev-token' + ) as mock_access_token, + mock.patch.object( + AsyncHttpxClient, "send", + return_value=mock.Mock(), + ) as mock_send, + ): + + await client.aio.interactions.create( + model='gemini-3-flash-preview', + input='Hello', + ) + + mock_send.assert_called_once() + mock_access_token.assert_called_once() + args, kwargs = mock_send.call_args + headers = args[0].headers + assert headers['authorization'] == 'Bearer fake-mldev-token' + assert headers['x-goog-user-project'] == 'test-quota-project' diff --git a/google/genai/tests/live/test_live.py b/google/genai/tests/live/test_live.py index c0b60e595..1af139272 100644 --- a/google/genai/tests/live/test_live.py +++ b/google/genai/tests/live/test_live.py @@ -29,6 +29,8 @@ import warnings import certifi +import google.auth +from google.auth.transport import requests from google.oauth2.credentials import Credentials import pytest from websockets import client @@ -40,6 +42,9 @@ from ... import client as gl_client from ... import live from ... import types +from ... import _extra_utils +from google.auth import credentials + try: import aiohttp AIOHTTP_NOT_INSTALLED = False @@ -85,6 +90,23 @@ }] +class FakeCredentials(Credentials): + def __init__(self, token='fake_token', valid=True): + super().__init__(token='placeholder') + self.token = token + self._valid = valid + self.refresh_called = False + + def refresh(self, request): + self.token = 'refreshed_token' + self._valid = True + self.refresh_called = True + + @property + def valid(self): + return self._valid + + def get_current_weather(location: str, unit: str): """Get the current weather in a city.""" return 15 if unit == 'C' else 59 @@ -2141,3 +2163,120 @@ async def mock_connect(uri, additional_headers=None, **kwargs): assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY' assert 'BidiGenerateContent' in capture['uri'] + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_with_api_key(mock_websocket): + # Test the branch where api_key is present in vertexai + client = Client(vertexai=True, api_key="test_api_key") + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch.object(live, 'ws_connect', new=mock_ws_connect): + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert 'x-goog-api-key' in headers + assert headers['x-goog-api-key'] == "test_api_key" + # Authorization header should not be added by this method if api_key is used + assert 'Authorization' not in headers + assert "BidiGenerateContent" in uri + + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_refresh_creds(mock_websocket): + # Test the branch where credentials need refreshing + fake_creds = FakeCredentials(token=None, valid=False) + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with ( + patch.object(google.auth, 'default', return_value=(fake_creds, "test-project")), + patch.object(requests, 'Request', return_value=Mock()), + patch.object(live, 'ws_connect', new=mock_ws_connect) + ): + client = Client(vertexai=True, project="test-project", + location="us-central1") + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert fake_creds.refresh_called + assert 'Authorization' in headers + assert headers['Authorization'] == f'Bearer refreshed_token' + assert "BidiGenerateContent" in uri + + +@pytest.mark.asyncio +async def test_async_live_connect_with_api_key(mock_websocket): + client = api_client.BaseApiClient(api_key='test_api_key') + async_live = live.AsyncLive(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['headers'] = additional_headers + yield mock_websocket + + with mock.patch.object(live, 'ws_connect', new=mock_connect): + async with async_live.connect(model='models/test-model'): + pass + + assert 'headers' in capture + headers = capture['headers'] + + assert headers['x-goog-api-key'] == 'test_api_key' + + assert 'Authorization' not in headers + +@pytest.mark.parametrize( + "creds, existing_headers, expected_auth, expect_refresh", + [ + (FakeCredentials(), {}, 'Bearer fake_token', False), + (FakeCredentials(valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token=None, valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token='existing_token', valid=True), {}, 'Bearer existing_token', False), + (FakeCredentials(token='new_token', valid=True), {'Authorization': 'Bearer old_token'}, 'Bearer old_token', False), + ], +) +@pytest.mark.asyncio +async def test_async_live_connect_with_credentials( + mock_websocket, creds, existing_headers, expected_auth, expect_refresh +): + client = api_client.BaseApiClient(credentials=creds) + if existing_headers: + client._http_options.headers = existing_headers + async_live = live.AsyncLive(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['headers'] = additional_headers + yield mock_websocket + + with ( + mock.patch.object(live, 'ws_connect', new=mock_connect), + mock.patch.object(requests, 'Request', autospec=True) + ): + async with async_live.connect(model='models/test-model'): + pass + + assert 'headers' in capture + headers = capture['headers'] + assert headers.get('Authorization') == expected_auth + assert 'x-goog-api-key' not in headers + assert creds.refresh_called == expect_refresh diff --git a/google/genai/tests/live/test_live_music.py b/google/genai/tests/live/test_live_music.py index f51f8248d..b7a81509d 100644 --- a/google/genai/tests/live/test_live_music.py +++ b/google/genai/tests/live/test_live_music.py @@ -24,6 +24,8 @@ from unittest.mock import patch import warnings +import google.auth +import google.auth.transport.requests from google.oauth2.credentials import Credentials import pytest from websockets import client @@ -36,6 +38,7 @@ from ... import live_music from ... import types from .. import pytest_helper + try: import aiohttp AIOHTTP_NOT_INSTALLED = False @@ -49,10 +52,27 @@ ) -def mock_api_client(vertexai=False, credentials=None): +class FakeCredentials(Credentials): + def __init__(self, token='fake_token', valid=True): + super().__init__(token='placeholder') + self.token = token + self._valid = valid + self.refresh_called = False + + def refresh(self, request): + self.token = 'refreshed_token' + self._valid = True + self.refresh_called = True + + @property + def valid(self): + return self._valid + + +def mock_api_client(vertexai=False, credentials=None, api_key='TEST_API_KEY'): api_client = mock.MagicMock(spec=gl_client.BaseApiClient) if not vertexai: - api_client.api_key = 'TEST_API_KEY' + api_client.api_key = api_key api_client.location = None api_client.project = None else: @@ -67,9 +87,24 @@ def mock_api_client(vertexai=False, credentials=None): ) # Ensure headers exist api_client.vertexai = vertexai api_client._api_client = api_client + api_client._websocket_base_url = lambda: 'wss://test.com' return api_client +@pytest.fixture(autouse=True) +def clear_env(monkeypatch): + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False) + monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False) + monkeypatch.delenv("GOOGLE_GENAI_CLIENT_MODE", raising=False) + monkeypatch.delenv("GOOGLE_GENAI_REPLAY_ID", raising=False) + monkeypatch.delenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", raising=False) + monkeypatch.delenv("GOOGLE_GEMINI_BASE_URL", raising=False) + monkeypatch.delenv("GOOGLE_VERTEX_BASE_URL", raising=False) + + @pytest.fixture def mock_websocket(): websocket = AsyncMock(spec=client.ClientConnection) @@ -115,14 +150,15 @@ async def get_connect_message(api_client, model): mock_google_auth_default.return_value = (mock_creds, None) @contextlib.asynccontextmanager - async def mock_connect(uri, additional_headers=None): + async def mock_connect(uri, additional_headers=None, **kwargs): yield mock_ws - @patch('google.auth.default', new=mock_google_auth_default) - @patch.object(live_music, 'connect', new=mock_connect) - async def _test_connect(): - live_module = live.AsyncLive(api_client) - async with live_module.music.connect( + with ( + patch.object(google.auth, 'default', new=mock_google_auth_default), + patch.object(live_music, 'connect', new=mock_connect) + ): + live_module = live_music.AsyncLiveMusic(api_client) + async with live_module.connect( model=model, ): pass @@ -130,7 +166,6 @@ async def _test_connect(): mock_ws.send.assert_called_once() return json.loads(mock_ws.send.call_args[0][0]) - return await _test_connect() def test_mldev_from_env(monkeypatch): @@ -142,6 +177,7 @@ def test_mldev_from_env(monkeypatch): assert not client.aio.live.music._api_client.vertexai assert client.aio.live.music._api_client.api_key == api_key assert isinstance(client.aio.live._api_client, api_client.BaseApiClient) + assert client.aio.live._api_client._http_options.headers['x-goog-api-key'] == api_key @requires_aiohttp @@ -360,3 +396,70 @@ async def test_setup_to_api(vertexai): else: expected_result['setup']['model'] = 'models/test_model' assert result == expected_result + +@pytest.mark.asyncio +async def test_connect_with_api_key(mock_websocket): + client = Client(api_key='TEST_API_KEY', http_options={'api_version': 'v1test'}) + client._api_client._websocket_base_url = lambda: 'wss://test.com' + live_module = client.aio.live.music + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch.object(live_music, 'connect', new=mock_connect): + async with live_module.connect(model='test-model'): + pass + + assert capture['uri'] == 'wss://test.com/ws/google.ai.generativelanguage.v1test.GenerativeService.BidiGenerateMusic' + assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY' + assert 'Authorization' not in capture['headers'] + +@pytest.mark.parametrize( + "creds, existing_headers, expected_auth, expect_refresh", + [ + (FakeCredentials(), {}, 'Bearer fake_token', False), + (FakeCredentials(valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token=None, valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token='existing_token', valid=True), {}, 'Bearer existing_token', False), + (FakeCredentials(token='new_token', valid=True), {'Authorization': 'Bearer old_token'}, 'Bearer old_token', False), + ], +) +@pytest.mark.asyncio +async def test_connect_with_credentials( + mock_websocket, creds, existing_headers, expected_auth, expect_refresh +): + client = api_client.BaseApiClient(credentials=creds, http_options={'api_version': 'v1test'}) + if existing_headers: + client._http_options.headers = existing_headers + client._websocket_base_url = lambda: 'wss://test.com' + live_module = live_music.AsyncLiveMusic(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch.object(live_music, 'connect', new=mock_connect): + with patch.object(google.auth.transport.requests, 'Request', autospec=True): + async with live_module.connect(model='test-model'): + pass + + assert capture['uri'] == 'wss://test.com/ws/google.ai.generativelanguage.v1test.GenerativeService.BidiGenerateMusic' + headers = capture['headers'] + assert headers.get('Authorization') == expected_auth + assert 'x-goog-api-key' not in headers + assert creds.refresh_called == expect_refresh + +@pytest.mark.asyncio +async def test_connect_vertex_unsupported(mock_websocket): + client = Client(vertexai=True, project='test', location='us-central1') + live_module = client.aio.live.music + with pytest.raises(NotImplementedError): + async with live_module.connect(model='test-model'): + pass