From 5b3872897a68d00f536b94bb11bb580c45e8cd55 Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Wed, 11 Mar 2026 14:26:08 -0700 Subject: [PATCH] feat: Allow custom endpoints for authentication with Vertex AI in Java PiperOrigin-RevId: 882197861 --- src/main/java/com/google/genai/ApiClient.java | 124 +++++++++++++++--- src/main/java/com/google/genai/AsyncLive.java | 44 +++++-- .../com/google/genai/types/HttpOptions.java | 45 +++++++ .../com/google/genai/types/ResourceScope.java | 112 ++++++++++++++++ .../java/com/google/genai/AsyncLiveTest.java | 121 +++++++++++++++++ .../java/com/google/genai/ClientTest.java | 32 +++++ .../com/google/genai/HttpApiClientTest.java | 94 ++++++++++++- 7 files changed, 534 insertions(+), 38 deletions(-) create mode 100644 src/main/java/com/google/genai/types/ResourceScope.java create mode 100644 src/test/java/com/google/genai/AsyncLiveTest.java diff --git a/src/main/java/com/google/genai/ApiClient.java b/src/main/java/com/google/genai/ApiClient.java index c5ace77458d..29f482a38dc 100644 --- a/src/main/java/com/google/genai/ApiClient.java +++ b/src/main/java/com/google/genai/ApiClient.java @@ -32,6 +32,7 @@ import com.google.genai.types.HttpRetryOptions; import com.google.genai.types.ProxyOptions; import com.google.genai.types.ProxyType; +import com.google.genai.types.ResourceScope.Known; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Proxy; @@ -72,6 +73,7 @@ public abstract class ApiClient implements AutoCloseable { HttpOptions httpOptions; final boolean vertexAI; final Optional clientOptions; + final Optional customBaseUrl; // For Google AI APIs final Optional apiKey; // For Vertex AI APIs @@ -102,8 +104,12 @@ protected ApiClient( this.credentials = Optional.empty(); this.vertexAI = false; this.clientOptions = clientOptions; + this.customBaseUrl = + customHttpOptions.flatMap(HttpOptions::baseUrl).map(url -> url.replaceAll("/$", "")); - this.httpOptions = defaultHttpOptions(/* vertexAI= */ false, this.location, this.apiKey); + this.httpOptions = + defaultHttpOptions( + /* vertexAI= */ false, this.location, this.apiKey, this.customBaseUrl, Optional.empty()); if (customHttpOptions.isPresent()) { this.httpOptions = mergeHttpOptions(customHttpOptions.get()); @@ -149,6 +155,9 @@ protected ApiClient( boolean hasProject = project != null && project.isPresent(); boolean hasLocation = location != null && location.isPresent(); + Optional customBaseUrl = + customHttpOptions.flatMap(HttpOptions::baseUrl).map(url -> url.replaceAll("/$", "")); + // Validate constructor arguments combinations. if (hasProject && hasApiKey) { throw new IllegalArgumentException( @@ -197,19 +206,41 @@ protected ApiClient( apiKeyValue = null; } - if (locationValue == null && apiKeyValue == null) { + if (locationValue == null && apiKeyValue == null && !customBaseUrl.isPresent()) { + locationValue = "global"; + } else if (locationValue == null + && apiKeyValue == null + && customBaseUrl.isPresent() + && customBaseUrl.get().endsWith(".googleapis.com")) { locationValue = "global"; } + boolean hasSufficientAuth = + (projectValue != null && locationValue != null) || apiKeyValue != null; + if (!hasSufficientAuth && !customBaseUrl.isPresent()) { + throw new IllegalArgumentException( + "Authentication is not set up. Please provide either a project and location, or an API" + + " key, or a custom base URL."); + } + + boolean hasConstructorAuth = (hasProject && hasLocation) || hasApiKey; + HttpOptions.Builder initHttpOptionsBuilder = HttpOptions.builder(); + if (customBaseUrl.isPresent() && !hasConstructorAuth) { + initHttpOptionsBuilder.baseUrl(customBaseUrl.get()); + projectValue = null; + locationValue = null; + } else if (apiKeyValue != null + || (locationValue != null && locationValue.equals("global") && !customBaseUrl.isPresent())) { + initHttpOptionsBuilder.baseUrl("https://aiplatform.googleapis.com"); + } else if (locationValue != null && !customBaseUrl.isPresent()) { + initHttpOptionsBuilder.baseUrl( + String.format("https://%s-aiplatform.googleapis.com", locationValue)); + } + this.apiKey = Optional.ofNullable(apiKeyValue); this.project = Optional.ofNullable(projectValue); this.location = Optional.ofNullable(locationValue); - - // Validate that either project and location or API key is set. - if (!(this.project.isPresent() || this.apiKey.isPresent())) { - throw new IllegalArgumentException( - "For Vertex AI APIs, either project or API key must be set."); - } + this.customBaseUrl = customBaseUrl; // Only set credentials if using project/location. this.credentials = @@ -219,7 +250,13 @@ protected ApiClient( this.clientOptions = clientOptions; - this.httpOptions = defaultHttpOptions(/* vertexAI= */ true, this.location, this.apiKey); + this.httpOptions = + defaultHttpOptions( + /* vertexAI= */ true, + this.location, + this.apiKey, + this.customBaseUrl, + initHttpOptionsBuilder.build().baseUrl()); if (customHttpOptions.isPresent()) { this.httpOptions = mergeHttpOptions(customHttpOptions.get()); @@ -324,19 +361,17 @@ protected Request buildRequest( String requestJson, Optional requestHttpOptions) { String capitalizedHttpMethod = Ascii.toUpperCase(httpMethod); - boolean queryBaseModel = - capitalizedHttpMethod.equals("GET") && path.startsWith("publishers/google/models"); - if (this.vertexAI() - && !this.apiKey.isPresent() - && !path.startsWith("projects/") - && !queryBaseModel) { + HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null)); + + boolean prependProjectLocation = + shouldPrependVertexProjectPath(capitalizedHttpMethod, path, mergedHttpOptions); + + if (prependProjectLocation) { path = String.format("projects/%s/locations/%s/", this.project.get(), this.location.get()) + path; } - HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null)); - String requestUrl; String baseUrl = @@ -408,6 +443,7 @@ protected Request buildRequest( byte[] requestBytes, Optional requestHttpOptions) { HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null)); + if (httpMethod.equalsIgnoreCase("POST")) { RequestBody body = RequestBody.create(requestBytes, MediaType.get("application/octet-stream")); @@ -437,9 +473,8 @@ private void setHeaders(Request.Builder request, HttpOptions requestHttpOptions) if (apiKey.isPresent()) { // Sets API key for Gemini Developer API or Vertex AI Express mode request.header("x-goog-api-key", apiKey.get()); - } else { - GoogleCredentials cred = - credentials.orElseThrow(() -> new IllegalStateException("credentials is required")); + } else if (credentials.isPresent()) { + GoogleCredentials cred = credentials.get(); try { cred.refreshIfExpired(); } catch (IOException e) { @@ -451,6 +486,8 @@ private void setHeaders(Request.Builder request, HttpOptions requestHttpOptions) if (cred.getQuotaProjectId() != null) { request.header("x-goog-user-project", cred.getQuotaProjectId()); } + } else if (!customBaseUrl.isPresent()) { + throw new IllegalStateException("credentials is required"); } } @@ -504,11 +541,21 @@ public boolean vertexAI() { return apiKey.orElse(null); } + /** Returns the custom base URL if provided. */ + public @Nullable String customBaseUrl() { + return customBaseUrl.orElse(null); + } + /** Returns the HttpClient for API calls. */ public OkHttpClient httpClient() { return httpClient; } + /** Returns the GoogleCredentials for Vertex AI APIs. */ + public @Nullable GoogleCredentials credentials() { + return credentials.orElse(null); + } + /** Returns the HTTP options for API calls. */ public HttpOptions httpOptions() { return httpOptions; @@ -577,6 +624,7 @@ HttpOptions mergeHttpOptions(HttpOptions httpOptionsToApply) { HttpOptions.Builder mergedHttpOptionsBuilder = this.httpOptions.toBuilder(); httpOptionsToApply.baseUrl().ifPresent(mergedHttpOptionsBuilder::baseUrl); + httpOptionsToApply.baseUrlResourceScope().ifPresent(mergedHttpOptionsBuilder::baseUrlResourceScope); httpOptionsToApply.apiVersion().ifPresent(mergedHttpOptionsBuilder::apiVersion); httpOptionsToApply.timeout().ifPresent(mergedHttpOptionsBuilder::timeout); httpOptionsToApply.extraBody().ifPresent(mergedHttpOptionsBuilder::extraBody); @@ -602,7 +650,11 @@ HttpOptions mergeHttpOptions(HttpOptions httpOptionsToApply) { } static HttpOptions defaultHttpOptions( - boolean vertexAI, Optional location, Optional apiKey) { + boolean vertexAI, + Optional location, + Optional apiKey, + Optional customBaseUrl, + Optional initBaseUrl) { ImmutableMap.Builder defaultHeaders = ImmutableMap.builder(); defaultHeaders.put("Content-Type", "application/json"); defaultHeaders.put("user-agent", libraryVersion()); @@ -618,6 +670,12 @@ static HttpOptions defaultHttpOptions( vertexBaseUrl.orElseGet(() -> defaultEnvironmentVariables.get("vertexBaseUrl")); if (defaultBaseUrl != null) { defaultHttpOptionsBuilder.baseUrl(defaultBaseUrl); + } else if (initBaseUrl.isPresent()) { + defaultHttpOptionsBuilder.baseUrl(initBaseUrl.get()); + } else if (customBaseUrl.isPresent() + && !customBaseUrl.get().endsWith(".googleapis.com") + && !(location.isPresent() || apiKey.isPresent())) { + defaultHttpOptionsBuilder.baseUrl(customBaseUrl.get()); } else if (apiKey.isPresent() || location.get().equalsIgnoreCase("global")) { defaultHttpOptionsBuilder.baseUrl("https://aiplatform.googleapis.com"); } else { @@ -638,6 +696,30 @@ static HttpOptions defaultHttpOptions( return defaultHttpOptionsBuilder.build(); } + boolean shouldPrependVertexProjectPath( + String httpMethod, String path, HttpOptions httpOptions) { + if (httpOptions.baseUrlResourceScope().isPresent() + && httpOptions.baseUrl().isPresent() + && httpOptions.baseUrlResourceScope().get().knownEnum() == Known.COLLECTION) { + return false; + } + if (this.apiKey.isPresent()) { + return false; + } + if (!this.vertexAI) { + return false; + } + if (path.startsWith("projects/")) { + return false; + } + // These paths are used by Vertex's models.get and models.list calls. For base models Vertex + // does not accept a project/location prefix (for tuned model the prefix is required). + if (httpMethod.equals("GET") && path.startsWith("publishers/google/models")) { + return false; + } + return true; + } + GoogleCredentials defaultCredentials() { try { return GoogleCredentials.getApplicationDefault() diff --git a/src/main/java/com/google/genai/AsyncLive.java b/src/main/java/com/google/genai/AsyncLive.java index 1681b39f7e6..0c1b577d297 100644 --- a/src/main/java/com/google/genai/AsyncLive.java +++ b/src/main/java/com/google/genai/AsyncLive.java @@ -87,7 +87,7 @@ public CompletableFuture connect(String model, LiveConnectConfig c /** Gets the URI for the websocket connection. */ private URI getWebSocketUri() { - String baseUrl = apiClient.httpOptions.baseUrl().orElse(null); + String baseUrl = apiClient.httpOptions().baseUrl().orElse(null); if (baseUrl == null) { throw new IllegalArgumentException("No base URL provided in the client."); } @@ -102,13 +102,20 @@ private URI getWebSocketUri() { baseUri.getFragment()) .toString(); + boolean hasStandardAuth = + (apiClient.project() != null && apiClient.location() != null) + || apiClient.apiKey() != null; + if (apiClient.customBaseUrl() != null && !hasStandardAuth) { + return new URI(wsBaseUrl); + } + if (!apiClient.vertexAI()) { String method; - if (apiClient.apiKey().startsWith("auth_tokens/")) { + if (apiClient.apiKey() != null && apiClient.apiKey().startsWith("auth_tokens/")) { logger.warning( "Warning: Ephemeral token support is experimental and may change in future" + " versions."); - if (!apiClient.httpOptions.apiVersion().orElse("v1beta").equals("v1alpha")) { + if (!apiClient.httpOptions().apiVersion().orElse("v1beta").equals("v1alpha")) { logger.warning( "Warning: The SDK's ephemeral token support is in v1alpha only. Please use client" + " = Client.builder().httpOptions(HttpOptions.builder().apiVersion(\"v1alpha\").build()).build()" @@ -121,12 +128,12 @@ private URI getWebSocketUri() { return new URI( String.format( "%s/ws/google.ai.generativelanguage.%s.GenerativeService.%s", - wsBaseUrl, apiClient.httpOptions.apiVersion().orElse("v1beta"), method)); + wsBaseUrl, apiClient.httpOptions().apiVersion().orElse("v1beta"), method)); } else { return new URI( String.format( "%s/ws/google.cloud.aiplatform.%s.LlmBidiService/BidiGenerateContent", - wsBaseUrl, apiClient.httpOptions.apiVersion().orElse("v1beta1"))); + wsBaseUrl, apiClient.httpOptions().apiVersion().orElse("v1beta1"))); } } catch (URISyntaxException e) { throw new IllegalStateException("Failed to parse URL.", e); @@ -136,16 +143,22 @@ private URI getWebSocketUri() { /** Gets the headers for the websocket connection. */ private Map getWebSocketHeaders() { Map headers = new HashMap<>(); - apiClient.httpOptions.headers().ifPresent(headers::putAll); + apiClient.httpOptions().headers().ifPresent(headers::putAll); if (apiClient.vertexAI()) { - try { - GoogleCredentials credentials = - apiClient.credentials.orElseGet(() -> apiClient.defaultCredentials()); - credentials.refreshIfExpired(); - headers.put("Authorization", "Bearer " + credentials.getAccessToken().getTokenValue()); - } catch (IOException e) { - throw new GenAiIOException("Failed to refresh credentials for Vertex AI.", e); + if (apiClient.credentials() != null) { + try { + GoogleCredentials credentials = apiClient.credentials(); + credentials.refreshIfExpired(); + headers.put("Authorization", "Bearer " + credentials.getAccessToken().getTokenValue()); + if (credentials.getQuotaProjectId() != null) { + headers.put("x-goog-user-project", credentials.getQuotaProjectId()); + } + } catch (IOException e) { + throw new GenAiIOException("Failed to refresh credentials for Vertex AI.", e); + } + } else if (apiClient.apiKey() != null) { + headers.put("x-goog-api-key", apiClient.apiKey()); } } else { @Nullable String apiKey = apiClient.apiKey(); @@ -165,7 +178,10 @@ private String getSetupRequest(String model, LiveConnectConfig config) { String transformedModel = Transformers.tModel(apiClient, model); // Vertex requires the full resource path for the model. - if (apiClient.vertexAI() && transformedModel.startsWith("publishers/")) { + if (apiClient.vertexAI() + && transformedModel.startsWith("publishers/") + && apiClient.project() != null + && apiClient.location() != null) { model = String.format( "projects/%s/locations/%s/%s", diff --git a/src/main/java/com/google/genai/types/HttpOptions.java b/src/main/java/com/google/genai/types/HttpOptions.java index c9dad40d4aa..91e6c0266b8 100644 --- a/src/main/java/com/google/genai/types/HttpOptions.java +++ b/src/main/java/com/google/genai/types/HttpOptions.java @@ -35,6 +35,10 @@ public abstract class HttpOptions extends JsonSerializable { @JsonProperty("baseUrl") public abstract Optional baseUrl(); + /** The resource scope used to constructing the resource name when base_url is set */ + @JsonProperty("baseUrlResourceScope") + public abstract Optional baseUrlResourceScope(); + /** Specifies the version of the API to use. */ @JsonProperty("apiVersion") public abstract Optional apiVersion(); @@ -95,6 +99,47 @@ public Builder clearBaseUrl() { return baseUrl(Optional.empty()); } + /** + * Setter for baseUrlResourceScope. + * + *

baseUrlResourceScope: The resource scope used to constructing the resource name when + * base_url is set + */ + @JsonProperty("baseUrlResourceScope") + public abstract Builder baseUrlResourceScope(ResourceScope baseUrlResourceScope); + + @ExcludeFromGeneratedCoverageReport + abstract Builder baseUrlResourceScope(Optional baseUrlResourceScope); + + /** Clears the value of baseUrlResourceScope field. */ + @ExcludeFromGeneratedCoverageReport + @CanIgnoreReturnValue + public Builder clearBaseUrlResourceScope() { + return baseUrlResourceScope(Optional.empty()); + } + + /** + * Setter for baseUrlResourceScope given a known enum. + * + *

baseUrlResourceScope: The resource scope used to constructing the resource name when + * base_url is set + */ + @CanIgnoreReturnValue + public Builder baseUrlResourceScope(ResourceScope.Known knownType) { + return baseUrlResourceScope(new ResourceScope(knownType)); + } + + /** + * Setter for baseUrlResourceScope given a string. + * + *

baseUrlResourceScope: The resource scope used to constructing the resource name when + * base_url is set + */ + @CanIgnoreReturnValue + public Builder baseUrlResourceScope(String baseUrlResourceScope) { + return baseUrlResourceScope(new ResourceScope(baseUrlResourceScope)); + } + /** * Setter for apiVersion. * diff --git a/src/main/java/com/google/genai/types/ResourceScope.java b/src/main/java/com/google/genai/types/ResourceScope.java new file mode 100644 index 00000000000..9e2f64bcb1e --- /dev/null +++ b/src/main/java/com/google/genai/types/ResourceScope.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Auto-generated code. Do not edit. + +package com.google.genai.types; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.base.Ascii; +import java.util.Objects; + +/** Resource scope. */ +public class ResourceScope { + + /** Enum representing the known values for ResourceScope. */ + public enum Known { + /** + * When setting base_url, this value configures resource scope to be the collection. The + * resource name will not include api version, project, or location. For example, if base_url is + * set to "https://aiplatform.googleapis.com", then the resource name for a Model would be + * "https://aiplatform.googleapis.com/publishers/google/models/gemini-3-pro-preview + */ + COLLECTION, + + RESOURCE_SCOPE_UNSPECIFIED + } + + private Known resourceScopeEnum; + private final String value; + + @JsonCreator + public ResourceScope(String value) { + this.value = value; + for (Known resourceScopeEnum : Known.values()) { + if (Ascii.equalsIgnoreCase(resourceScopeEnum.toString(), value)) { + this.resourceScopeEnum = resourceScopeEnum; + break; + } + } + if (this.resourceScopeEnum == null) { + this.resourceScopeEnum = Known.RESOURCE_SCOPE_UNSPECIFIED; + } + } + + public ResourceScope(Known knownValue) { + this.resourceScopeEnum = knownValue; + this.value = knownValue.toString(); + } + + @ExcludeFromGeneratedCoverageReport + @Override + @JsonValue + public String toString() { + return this.value; + } + + @ExcludeFromGeneratedCoverageReport + @SuppressWarnings("PatternMatchingInstanceof") + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null) { + return false; + } + + if (!(o instanceof ResourceScope)) { + return false; + } + + ResourceScope other = (ResourceScope) o; + + if (this.resourceScopeEnum != Known.RESOURCE_SCOPE_UNSPECIFIED + && other.resourceScopeEnum != Known.RESOURCE_SCOPE_UNSPECIFIED) { + return this.resourceScopeEnum == other.resourceScopeEnum; + } else if (this.resourceScopeEnum == Known.RESOURCE_SCOPE_UNSPECIFIED + && other.resourceScopeEnum == Known.RESOURCE_SCOPE_UNSPECIFIED) { + return this.value.equals(other.value); + } + return false; + } + + @ExcludeFromGeneratedCoverageReport + @Override + public int hashCode() { + if (this.resourceScopeEnum != Known.RESOURCE_SCOPE_UNSPECIFIED) { + return this.resourceScopeEnum.hashCode(); + } else { + return Objects.hashCode(this.value); + } + } + + @ExcludeFromGeneratedCoverageReport + public Known knownEnum() { + return this.resourceScopeEnum; + } +} diff --git a/src/test/java/com/google/genai/AsyncLiveTest.java b/src/test/java/com/google/genai/AsyncLiveTest.java new file mode 100644 index 00000000000..01907b1df5f --- /dev/null +++ b/src/test/java/com/google/genai/AsyncLiveTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.genai; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.genai.types.HttpOptions; +import java.lang.reflect.Method; +import java.net.URI; +import java.util.Map; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class AsyncLiveTest { + private ApiClient apiClient; + private AsyncLive asyncLive; + + @BeforeEach + public void setUp() { + apiClient = mock(ApiClient.class); + asyncLive = new AsyncLive(apiClient); + } + + @Test + public void testGetWebSocketUri_VertexCustomBaseUrlNoAuth() throws Exception { + when(apiClient.vertexAI()).thenReturn(true); + when(apiClient.customBaseUrl()).thenReturn("https://my-custom-endpoint.com"); + when(apiClient.project()).thenReturn(null); + when(apiClient.location()).thenReturn(null); + when(apiClient.apiKey()).thenReturn(null); + when(apiClient.httpOptions()).thenReturn(HttpOptions.builder().baseUrl("https://my-custom-endpoint.com").build()); + + Method getWebSocketUri = AsyncLive.class.getDeclaredMethod("getWebSocketUri"); + getWebSocketUri.setAccessible(true); + URI uri = (URI) getWebSocketUri.invoke(asyncLive); + + assertEquals("wss://my-custom-endpoint.com", uri.toString()); + } + + @Test + public void testGetWebSocketHeaders_VertexQuotaProject() throws Exception { + GoogleCredentials credentials = mock(GoogleCredentials.class); + when(credentials.getAccessToken()).thenReturn(AccessToken.newBuilder().setTokenValue("test-token").build()); + when(credentials.getQuotaProjectId()).thenReturn("test-quota-project"); + + when(apiClient.vertexAI()).thenReturn(true); + when(apiClient.credentials()).thenReturn(credentials); + when(apiClient.httpOptions()).thenReturn(HttpOptions.builder().build()); + + Method getWebSocketHeaders = AsyncLive.class.getDeclaredMethod("getWebSocketHeaders"); + getWebSocketHeaders.setAccessible(true); + Map headers = (Map) getWebSocketHeaders.invoke(asyncLive); + + assertEquals("Bearer test-token", headers.get("Authorization")); + assertEquals("test-quota-project", headers.get("x-goog-user-project")); + } + + @Test + public void testGetWebSocketHeaders_VertexApiKey() throws Exception { + when(apiClient.vertexAI()).thenReturn(true); + when(apiClient.apiKey()).thenReturn("vertex-api-key"); + when(apiClient.credentials()).thenReturn(null); + when(apiClient.httpOptions()).thenReturn(HttpOptions.builder().build()); + + Method getWebSocketHeaders = AsyncLive.class.getDeclaredMethod("getWebSocketHeaders"); + getWebSocketHeaders.setAccessible(true); + Map headers = (Map) getWebSocketHeaders.invoke(asyncLive); + + assertEquals("vertex-api-key", headers.get("x-goog-api-key")); + } + + @Test + public void testGetWebSocketUri_GoogleAiEphemeralToken() throws Exception { + when(apiClient.vertexAI()).thenReturn(false); + when(apiClient.apiKey()).thenReturn("auth_tokens/ephemeral-token"); + when(apiClient.httpOptions()).thenReturn(HttpOptions.builder() + .baseUrl("https://generativelanguage.googleapis.com") + .apiVersion("v1alpha") + .build()); + + Method getWebSocketUri = AsyncLive.class.getDeclaredMethod("getWebSocketUri"); + getWebSocketUri.setAccessible(true); + URI uri = (URI) getWebSocketUri.invoke(asyncLive); + + assertTrue(uri.toString().contains("BidiGenerateContentConstrained")); + assertTrue(uri.toString().contains("v1alpha")); + } + + @Test + public void testGetWebSocketHeaders_GoogleAiEphemeralToken() throws Exception { + when(apiClient.vertexAI()).thenReturn(false); + when(apiClient.apiKey()).thenReturn("auth_tokens/ephemeral-token"); + when(apiClient.httpOptions()).thenReturn(HttpOptions.builder().build()); + + Method getWebSocketHeaders = AsyncLive.class.getDeclaredMethod("getWebSocketHeaders"); + getWebSocketHeaders.setAccessible(true); + Map headers = (Map) getWebSocketHeaders.invoke(asyncLive); + + assertEquals("Token auth_tokens/ephemeral-token", headers.get("Authorization")); + } +} diff --git a/src/test/java/com/google/genai/ClientTest.java b/src/test/java/com/google/genai/ClientTest.java index b84e8026314..6373ac1ff84 100644 --- a/src/test/java/com/google/genai/ClientTest.java +++ b/src/test/java/com/google/genai/ClientTest.java @@ -188,4 +188,36 @@ public void testSetDefaultBaseUrls() { // Reset the base URLs after the test. Client.setDefaultBaseUrls(Optional.empty(), Optional.empty()); } + + @Test + public void testInitClientFromBuilder_globalLocation() { + // Act + Client client = + Client.builder() + .project(PROJECT) + .location("global") + .credentials(CREDENTIALS) + .vertexAI(true) + .build(); + + // Assert + assertEquals("global", client.location()); + assertTrue(client.vertexAI()); + assertEquals("https://aiplatform.googleapis.com", client.baseUrl().orElse(null)); + } + + @Test + public void testInitClientFromBuilder_globalLocationWithCustomBaseUrl() { + // Act + Client client = + Client.builder() + .httpOptions(HttpOptions.builder().baseUrl("https://my-endpoint.com").build()) + .vertexAI(true) + .build(); + + // Assert + assertEquals(null, client.location()); + assertTrue(client.vertexAI()); + assertEquals("https://my-endpoint.com", client.baseUrl().orElse(null)); + } } diff --git a/src/test/java/com/google/genai/HttpApiClientTest.java b/src/test/java/com/google/genai/HttpApiClientTest.java index 63616accb9a..a9741556867 100644 --- a/src/test/java/com/google/genai/HttpApiClientTest.java +++ b/src/test/java/com/google/genai/HttpApiClientTest.java @@ -74,9 +74,11 @@ public class HttpApiClientTest { private static final String PROXY_USERNAME = "user"; private static final String PROXY_PASSWORD = "pass"; private static final HttpOptions defaultHttpOptionsMLDev = - HttpApiClient.defaultHttpOptions(false, Optional.empty(), Optional.of(API_KEY)); + HttpApiClient.defaultHttpOptions( + false, Optional.empty(), Optional.of(API_KEY), Optional.empty(), Optional.empty()); private static final HttpOptions defaultHttpOptionsVertex = - HttpApiClient.defaultHttpOptions(true, Optional.of(LOCATION), Optional.empty()); + HttpApiClient.defaultHttpOptions( + true, Optional.of(LOCATION), Optional.empty(), Optional.empty(), Optional.empty()); private static final Optional REQUEST_HTTP_OPTIONS = Optional.of( HttpOptions.builder() @@ -831,7 +833,8 @@ public void testHttpClientVertexWithNoApiKeyProject_throwsException() throws Exc Optional.empty(), Optional.empty())); assertEquals( - "For Vertex AI APIs, either project or API key must be set.", + "Authentication is not set up. Please provide either a project and location, or an API" + + " key, or a custom base URL.", exception.getMessage()); } @@ -1483,6 +1486,91 @@ public void testNoDefaultLocationWhenUsingApiKeyOnlyMode( assertTrue(client.vertexAI()); } + @Test + public void testRequestWithResourceScopeCollection() throws Exception { + HttpOptions httpOptions = + HttpOptions.builder() + .baseUrl("https://my-endpoint.com") + .baseUrlResourceScope(com.google.genai.types.ResourceScope.Known.COLLECTION) + .build(); + HttpApiClient client = + new HttpApiClient( + Optional.empty(), + Optional.of(PROJECT), + Optional.of(LOCATION), + Optional.of(CREDENTIALS), + Optional.of(httpOptions), + Optional.empty()); + + String path = "models/my-model"; + Request request = client.buildRequest("GET", path, (String) null, Optional.empty()); + + assertEquals("https://my-endpoint.com/v1beta1/models/my-model", request.url().toString()); + } + + @Test + public void testRequestWithVertexAiAndApiKey_noPrepend() throws Exception { + HttpApiClient client = + new HttpApiClient( + Optional.of(API_KEY), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + + String path = "models/my-model"; + Request request = client.buildRequest("GET", path, (String) null, Optional.empty()); + + // Should NOT prepend projects/... + assertTrue(request.url().toString().contains("v1beta1/models/my-model")); + assertFalse(request.url().toString().contains("projects/")); + } + + @Test + public void testRequestWithVertexAiAndPathStartingWithPublishers_noPrepend() throws Exception { + HttpApiClient client = + new HttpApiClient( + Optional.empty(), + Optional.of(PROJECT), + Optional.of(LOCATION), + Optional.of(CREDENTIALS), + Optional.empty(), + Optional.empty()); + + String path = "publishers/google/models/gemini-pro"; + Request request = client.buildRequest("GET", path, (String) null, Optional.empty()); + + // Should NOT prepend projects/... for publishers/google/models + assertEquals( + String.format( + "https://%s-aiplatform.googleapis.com/v1beta1/publishers/google/models/gemini-pro", + LOCATION), + request.url().toString()); + } + + @Test + public void testRequestWithVertexAiAndPathStartingWithPublishers_Post_shouldPrepend() throws Exception { + HttpApiClient client = + new HttpApiClient( + Optional.empty(), + Optional.of(PROJECT), + Optional.of(LOCATION), + Optional.of(CREDENTIALS), + Optional.empty(), + Optional.empty()); + + String path = "publishers/google/models/gemini-pro:predict"; + Request request = client.buildRequest("POST", path, "{}", Optional.empty()); + + // Should prepend projects/... for publishers/google/models if POST + assertEquals( + String.format( + "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/google/models/gemini-pro:predict", + LOCATION, PROJECT, LOCATION), + request.url().toString()); + } + @Test public void testCloseClient() { HttpApiClient client =