Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,21 @@ def __init__(
)
self._http_options.api_version = 'v1beta1'
else: # Implicit initialization or missing arguments.
if not self.api_key:
if env_api_key and api_key:
# Explicit api_key takes precedence over implicit api_key.
logger.info(
'The client initializer api_key argument takes '
'precedence over the API key from the environment variable.'
)
if credentials:
if env_api_key:
logger.info(
'The user `credentials` argument will take precedence over the'
' api key from the environment variables.'
)
self.api_key = None

if not self.api_key and not credentials:
raise ValueError(
'No API key was provided. Please pass a valid API key. Learn how to'
' create an API key at'
Expand Down Expand Up @@ -1175,20 +1189,21 @@ def _request_once(
stream: bool = False,
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None
# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):

uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data)
else:
data = http_request.data

if stream:
httpx_request = self._httpx_client.build_request(
Expand Down Expand Up @@ -1241,22 +1256,22 @@ async def _async_request_once(
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None

# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):
uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = (
f'Bearer {await self._async_access_token()}'
)
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data)
else:
data = http_request.data

if stream:
if self._use_aiohttp():
Expand Down
21 changes: 20 additions & 1 deletion google/genai/_extra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
"""Extra utils depending on types that are shared between sync and async modules."""

import asyncio
from collections.abc import Callable, MutableMapping
import inspect
import io
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
from typing import Any, Optional, Union, get_args, get_origin
import mimetypes
import os
import pydantic

import google.auth.transport.requests


from . import _common
from . import _mcp_utils
from . import _transformers as t
Expand Down Expand Up @@ -677,3 +681,18 @@ def prepare_resumable_upload(
http_options.headers = {}
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
return http_options, size_bytes, mime_type


async def _maybe_update_and_insert_auth_token(
headers:MutableMapping[str, str],
creds: google.auth.credentials.Credentials) -> None:
# Refresh credentials to ensure token is valid
if not (creds.token and creds.valid):
try:
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
await asyncio.to_thread(creds.refresh, auth_req)
except Exception as e:
raise ConnectionError(f"Failed to refresh credentials") from e

if not headers.get('Authorization'):
headers['Authorization'] = f'Bearer {creds.token}'
4 changes: 2 additions & 2 deletions google/genai/_interactions/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
if not self.client_adapter:
return options

headers = options.headers or {}
Expand Down Expand Up @@ -400,7 +400,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
if not self.client_adapter:
return options

headers = options.headers or {}
Expand Down
18 changes: 7 additions & 11 deletions google/genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,12 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
# uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds.
timeout=http_opts.timeout / 1000 if http_opts.timeout else None,
max_retries=max_retries,
client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client)
client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client),
)

client = self._nextgen_client_instance
if self._api_client.vertexai:
client._is_vertex = True
client._vertex_project = self._api_client.project
client._vertex_location = self._api_client.location
self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False
self._nextgen_client_instance._vertex_project = self._api_client.project
self._nextgen_client_instance._vertex_location = self._api_client.location

return self._nextgen_client_instance

Expand Down Expand Up @@ -525,11 +523,9 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient:
client_adapter=GeminiNextGenAPIClientAdapter(self._api_client),
)

client = self._nextgen_client_instance
if self._api_client.vertexai:
client._is_vertex = True
client._vertex_project = self._api_client.project
client._vertex_location = self._api_client.location
self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False
self._nextgen_client_instance._vertex_project = self._api_client.project
self._nextgen_client_instance._vertex_location = self._api_client.location

return self._nextgen_client_instance

Expand Down
Loading