From 09dd9bb4a93eaa4320579e93fdac89260ad52983 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 25 Feb 2026 14:48:51 -0800 Subject: [PATCH] Feat: add support for using provided credentials in non-Vertex mode. PiperOrigin-RevId: 875347183 --- google/genai/_api_client.py | 53 +++--- .../client/test_client_initialization.py | 161 ++++++++++++++++++ 2 files changed, 195 insertions(+), 19 deletions(-) diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 2dfc0366b..9f33de151 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -702,7 +702,21 @@ def __init__( ) self._http_options.api_version = 'v1beta1' else: # Implicit initialization or missing arguments. - if not self.api_key: + if env_api_key and api_key: + # Explicit api_key takes precedence over implicit api_key. + logger.info( + 'The client initializer api_key argument takes ' + 'precedence over the API key from the environment variable.' + ) + if credentials: + if env_api_key: + logger.info( + 'The user `credentials` argument will take precedence over the' + ' api key from the environment variables.' + ) + self.api_key = None + + if not self.api_key and not credentials: raise ValueError( 'No API key was provided. Please pass a valid API key. Learn how to' ' create an API key at' @@ -1175,20 +1189,21 @@ def _request_once( stream: bool = False, ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = f'Bearer {self._access_token()}' if self._credentials and self._credentials.quota_project_id: http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) + else: + data = http_request.data if stream: httpx_request = self._httpx_client.build_request( @@ -1241,8 +1256,9 @@ async def _async_request_once( ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = ( f'Bearer {await self._async_access_token()}' ) @@ -1250,13 +1266,12 @@ async def _async_request_once( http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) + else: + data = http_request.data if stream: if self._use_aiohttp(): diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index 7b0136044..9daf50d38 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -47,6 +47,28 @@ ) +class FakeCredentials(credentials.Credentials): + def __init__(self, token="fake_token", expired=False, quota_project_id=None): + super().__init__() + self.token = token + self._expired = expired + self._quota_project_id = quota_project_id + self.refresh_count = 0 + + @property + def expired(self): + return self._expired + + @property + def quota_project_id(self): + return self._quota_project_id + + def refresh(self, request): + self.refresh_count += 1 + self.token = "refreshed_token" + self._expired = False + + @pytest.fixture(autouse=True) def reset_has_aiohttp(): yield @@ -1721,3 +1743,142 @@ async def test_get_aiohttp_session(): assert initial_session is not None session = await client._api_client._get_aiohttp_session() assert session is initial_session + + +def test_missing_api_key_and_credentials(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "") + with pytest.raises(ValueError, match="No API key was provided"): + Client() + + +auth_precedence_test_cases = [ + # client_args, env_vars, expected_headers + ( + {"credentials": FakeCredentials()}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"Authorization": "Bearer fake_token"} + ), + ( + {"credentials": FakeCredentials(quota_project_id="quota-proj")}, + {"GOOGLE_API_KEY": "env_api_key"}, + { + "Authorization": "Bearer fake_token", + "x-goog-user-project": "quota-proj" + } + ), + ( + {"api_key": "test_api_key"}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "test_api_key"} + ), + ( + {}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "env_api_key"} + ), +] + + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + + +@pytest.mark.asyncio +async def test_both_credentials_mldev(): + with pytest.raises(ValueError, match="mutually exclusive"): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds, api_key="test-api-key") + + +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1 + + +@requires_aiohttp +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1