Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator


class BaseModelWithConfig(BaseModel):
Expand Down Expand Up @@ -150,6 +151,21 @@ class ServiceAccount(BaseModelWithConfig):
service_account_credential: Optional[ServiceAccountCredential] = None
scopes: List[str]
use_default_credential: Optional[bool] = False
token_kind: Literal["access_token", "id_token"] = "access_token"
audience: Optional[str] = None

@model_validator(mode="before")
@classmethod
def _validate_before(cls, data: Any) -> Any:
if isinstance(data, dict):
token_kind = data.get("token_kind", "access_token")
audience = data.get("audience")
if token_kind == "id_token" and not audience:
raise ValueError(
"service_account.audience is required when"
" service_account.token_kind='id_token'"
)
return data


class AuthCredentialTypes(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import google.auth
from google.auth.transport.requests import Request
from google.oauth2 import id_token as google_id_token
from google.oauth2 import service_account
import google.oauth2.credentials

Expand Down Expand Up @@ -73,27 +74,55 @@ def exchange_credential(
)

try:
if auth_credential.service_account.use_default_credential:
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
config = auth_credential.service_account
token_kind = getattr(config, "token_kind", "access_token")
request = Request()

quota_project_id = None
token = None

if token_kind == "id_token":
audience = getattr(config, "audience", None)
if config.use_default_credential:
token = google_id_token.fetch_id_token(request, audience)
else:
if config.service_account_credential is None:
raise ValueError("service_account_credential is required when use_default_credential is False")

Comment on lines +89 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check for service_account_credential is redundant. A check at the beginning of the exchange_credential method (lines 62-69) already ensures service_account_credential is provided when use_default_credential is false. This check also raises a more specific AuthCredentialMissingError. Please remove this redundant check and the extra newline.

id_creds = (
service_account.IDTokenCredentials.from_service_account_info(
config.service_account_credential.model_dump(),
target_audience=audience,
)
)
id_creds.refresh(request)
token = id_creds.token
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
)
quota_project_id = None

credentials.refresh(Request())
if config.use_default_credential:
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
else:
if config.service_account_credential is None:
raise ValueError("service_account_credential is required when use_default_credential is False")
Comment on lines +109 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the id_token path, this check for service_account_credential is redundant. The check at the beginning of the method already covers this case. Please remove this redundant check for consistency and to avoid confusion.


credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(),
scopes=config.scopes,
)
quota_project_id = None

credentials.refresh(Request())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency and potential performance benefits, it's better to reuse the request object that was instantiated on line 79. The google-auth library is designed to allow reusing transport sessions. Please change this to credentials.refresh(request).

Suggested change
credentials.refresh(Request())
credentials.refresh(request)

token = credentials.token

updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
credentials=HttpCredentials(token=token),
additional_headers={
"x-goog-user-project": quota_project_id,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import google.auth
import google.oauth2.id_token
import pytest


Expand Down Expand Up @@ -218,3 +219,123 @@ def test_exchange_credential_exchange_failure(
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account token" in str(exc_info.value)
mock_from_service_account_info.assert_called_once()


def test_exchange_credential_use_default_credential_id_token_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange using ADC with an ID token (OIDC) for a target audience."""
mock_google_auth_default = MagicMock()
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)

mock_fetch_id_token = MagicMock(return_value="mock_id_token")
monkeypatch.setattr(
google.oauth2.id_token,
"fetch_id_token",
mock_fetch_id_token,
)
Comment on lines +232 to +236
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This monkeypatch.setattr call is redundant. The subsequent patch on lines 237-240 targets the fetch_id_token function where it's used within the service_account_exchanger module, which is the correct and sufficient way to patch it for this test. You can remove this patch.

monkeypatch.setattr(
"google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger.google_id_token.fetch_id_token",
mock_fetch_id_token,
)

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=[
"https://www.googleapis.com/auth/cloud-platform"
], # unused in id_token mode, but required by model today
token_kind="id_token",
audience="https://my-service-abc.a.run.app",
),
)

result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)

assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_id_token"
assert not result.http.additional_headers

mock_fetch_id_token.assert_called_once()
# Can we test this?
# mock_fetch_id_token.assert_called_once_with(ANY_REQUEST_OBJECT, "https://my-service-abc.a.run.app")
mock_google_auth_default.assert_not_called()


def test_exchange_credential_service_account_id_token_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange using SA JSON key with an ID token (OIDC) for a target audience."""
mock_id_creds = MagicMock()
mock_id_creds.token = "mock_id_token"
mock_id_creds.refresh = MagicMock()

mock_from_info = MagicMock(return_value=mock_id_creds)

# Patch IDTokenCredentials factory (NOT Credentials.from_service_account_info)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.IDTokenCredentials."
"from_service_account_info"
)
monkeypatch.setattr(target_path, mock_from_info)

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=[
"https://www.googleapis.com/auth/cloud-platform"
], # unused in id_token mode but required today
token_kind="id_token",
audience="https://my-service-abc.a.run.app",
),
)

result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)

assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_id_token"
assert not result.http.additional_headers

# Verify we used the IDTokenCredentials path with the correct target_audience
mock_from_info.assert_called_once()
_, kwargs = mock_from_info.call_args
assert kwargs["target_audience"] == "https://my-service-abc.a.run.app"

mock_id_creds.refresh.assert_called_once()


def test_service_account_id_token_requires_audience():
"""ServiceAccount validation: id_token requires audience."""
with pytest.raises(ValueError) as exc_info:
ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
token_kind="id_token",
audience=None,
)
assert "audience" in str(exc_info.value)
Loading