Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,21 @@ def __init__(
)
self._http_options.api_version = 'v1beta1'
else: # Implicit initialization or missing arguments.
if not self.api_key:
if env_api_key and api_key:
# Explicit api_key takes precedence over implicit api_key.
logger.info(
'The client initializer api_key argument takes '
'precedence over the API key from the environment variable.'
)
if credentials:
if env_api_key:
logger.info(
'The user `credentials` argument will take precedence over the'
' api key from the environment variables.'
)
self.api_key = None

if not self.api_key and not credentials:
raise ValueError(
'No API key was provided. Please pass a valid API key. Learn how to'
' create an API key at'
Expand Down Expand Up @@ -1175,20 +1189,21 @@ def _request_once(
stream: bool = False,
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None
# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):

uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data)
else:
data = http_request.data

if stream:
httpx_request = self._httpx_client.build_request(
Expand Down Expand Up @@ -1241,22 +1256,22 @@ async def _async_request_once(
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None

# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):
uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = (
f'Bearer {await self._async_access_token()}'
)
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data)
else:
data = http_request.data

if stream:
if self._use_aiohttp():
Expand Down
161 changes: 161 additions & 0 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@
)


class FakeCredentials(credentials.Credentials):
def __init__(self, token="fake_token", expired=False, quota_project_id=None):
super().__init__()
self.token = token
self._expired = expired
self._quota_project_id = quota_project_id
self.refresh_count = 0

@property
def expired(self):
return self._expired

@property
def quota_project_id(self):
return self._quota_project_id

def refresh(self, request):
self.refresh_count += 1
self.token = "refreshed_token"
self._expired = False


@pytest.fixture(autouse=True)
def reset_has_aiohttp():
yield
Expand Down Expand Up @@ -1721,3 +1743,142 @@ async def test_get_aiohttp_session():
assert initial_session is not None
session = await client._api_client._get_aiohttp_session()
assert session is initial_session


def test_missing_api_key_and_credentials(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "")
with pytest.raises(ValueError, match="No API key was provided"):
Client()


auth_precedence_test_cases = [
# client_args, env_vars, expected_headers
(
{"credentials": FakeCredentials()},
{"GOOGLE_API_KEY": "env_api_key"},
{"Authorization": "Bearer fake_token"}
),
(
{"credentials": FakeCredentials(quota_project_id="quota-proj")},
{"GOOGLE_API_KEY": "env_api_key"},
{
"Authorization": "Bearer fake_token",
"x-goog-user-project": "quota-proj"
}
),
(
{"api_key": "test_api_key"},
{"GOOGLE_API_KEY": "env_api_key"},
{"x-goog-api-key": "test_api_key"}
),
(
{},
{"GOOGLE_API_KEY": "env_api_key"},
{"x-goog-api-key": "env_api_key"}
),
]


@pytest.mark.parametrize(
["client_kwargs", "env_vars", "expected_headers"],
auth_precedence_test_cases,
)
@mock.patch.object(httpx.Client, "send", autospec=True)
def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers):
for key, value in env_vars.items():
monkeypatch.setenv(key, value)

client = Client(**client_kwargs)
mock_send.return_value = httpx.Response(
status_code=200,
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
)
client.models.generate_content(model="test", contents="hello?")
mock_send.assert_called_once()
request = mock_send.call_args[0][1]

for key, value in expected_headers.items():
assert key in request.headers
assert request.headers[key] == value

if "Authorization" in expected_headers:
assert "x-goog-api-key" not in request.headers
if "x-goog-api-key" in expected_headers:
assert "Authorization" not in request.headers
if "x-goog-user-project" not in expected_headers:
assert "x-goog-user-project" not in request.headers
assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}'

@pytest.mark.parametrize(
["client_kwargs", "env_vars", "expected_headers"],
auth_precedence_test_cases,
)
@pytest.mark.asyncio
@mock.patch.object(httpx.AsyncClient, "send", autospec=True)
async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers):
for key, value in env_vars.items():
monkeypatch.setenv(key, value)

client = Client(**client_kwargs)
mock_send.return_value = httpx.Response(
status_code=200,
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
)
await client.aio.models.generate_content(model="test", contents="hello?")
mock_send.assert_called_once()
request = mock_send.call_args[0][1]

for key, value in expected_headers.items():
assert key in request.headers
assert request.headers[key] == value

if "Authorization" in expected_headers:
assert "x-goog-api-key" not in request.headers
if "x-goog-api-key" in expected_headers:
assert "Authorization" not in request.headers
if "x-goog-user-project" not in expected_headers:
assert "x-goog-user-project" not in request.headers
assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}'


@pytest.mark.asyncio
async def test_both_credentials_mldev():
with pytest.raises(ValueError, match="mutually exclusive"):
creds = FakeCredentials(expired=True)
client = Client(credentials=creds, api_key="test-api-key")


@mock.patch.object(httpx.Client, "send", autospec=True)
def test_refresh_credentials_mldev(mock_send):
creds = FakeCredentials(expired=True)
client = Client(credentials=creds)
mock_send.return_value = httpx.Response(
status_code=200,
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]},
)
client.models.generate_content(model="test", contents="hello?")
mock_send.assert_called_once()
request = mock_send.call_args[0][1]
assert "Authorization" in request.headers
assert request.headers["Authorization"] == "Bearer refreshed_token"
assert "x-goog-api-key" not in request.headers
assert creds.refresh_count == 1


@requires_aiohttp
@pytest.mark.asyncio
@mock.patch.object(httpx.AsyncClient, "send", autospec=True)
async def test_async_refresh_credentials_mldev(mock_send):
creds = FakeCredentials(expired=True)
client = Client(credentials=creds)
mock_send.return_value = httpx.Response(
status_code=200,
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]},
)
await client.aio.models.generate_content(model="test", contents="hello?")
mock_send.assert_called_once()
request = mock_send.call_args[0][1]
assert "Authorization" in request.headers
assert request.headers["Authorization"] == "Bearer refreshed_token"
assert "x-goog-api-key" not in request.headers
assert creds.refresh_count == 1