Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 103 additions & 21 deletions src/main/java/com/google/genai/ApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.genai.types.HttpRetryOptions;
import com.google.genai.types.ProxyOptions;
import com.google.genai.types.ProxyType;
import com.google.genai.types.ResourceScope.Known;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
Expand Down Expand Up @@ -72,6 +73,7 @@ public abstract class ApiClient implements AutoCloseable {
HttpOptions httpOptions;
final boolean vertexAI;
final Optional<ClientOptions> clientOptions;
final Optional<String> customBaseUrl;
// For Google AI APIs
final Optional<String> apiKey;
// For Vertex AI APIs
Expand Down Expand Up @@ -102,8 +104,12 @@ protected ApiClient(
this.credentials = Optional.empty();
this.vertexAI = false;
this.clientOptions = clientOptions;
this.customBaseUrl =
customHttpOptions.flatMap(HttpOptions::baseUrl).map(url -> url.replaceAll("/$", ""));

this.httpOptions = defaultHttpOptions(/* vertexAI= */ false, this.location, this.apiKey);
this.httpOptions =
defaultHttpOptions(
/* vertexAI= */ false, this.location, this.apiKey, this.customBaseUrl, Optional.empty());

if (customHttpOptions.isPresent()) {
this.httpOptions = mergeHttpOptions(customHttpOptions.get());
Expand Down Expand Up @@ -149,6 +155,9 @@ protected ApiClient(
boolean hasProject = project != null && project.isPresent();
boolean hasLocation = location != null && location.isPresent();

Optional<String> customBaseUrl =
customHttpOptions.flatMap(HttpOptions::baseUrl).map(url -> url.replaceAll("/$", ""));

// Validate constructor arguments combinations.
if (hasProject && hasApiKey) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -197,19 +206,41 @@ protected ApiClient(
apiKeyValue = null;
}

if (locationValue == null && apiKeyValue == null) {
if (locationValue == null && apiKeyValue == null && !customBaseUrl.isPresent()) {
locationValue = "global";
} else if (locationValue == null
&& apiKeyValue == null
&& customBaseUrl.isPresent()
&& customBaseUrl.get().endsWith(".googleapis.com")) {
locationValue = "global";
}

boolean hasSufficientAuth =
(projectValue != null && locationValue != null) || apiKeyValue != null;
if (!hasSufficientAuth && !customBaseUrl.isPresent()) {
throw new IllegalArgumentException(
"Authentication is not set up. Please provide either a project and location, or an API"
+ " key, or a custom base URL.");
}

boolean hasConstructorAuth = (hasProject && hasLocation) || hasApiKey;
HttpOptions.Builder initHttpOptionsBuilder = HttpOptions.builder();
if (customBaseUrl.isPresent() && !hasConstructorAuth) {
initHttpOptionsBuilder.baseUrl(customBaseUrl.get());
projectValue = null;
locationValue = null;
} else if (apiKeyValue != null
|| (locationValue != null && locationValue.equals("global") && !customBaseUrl.isPresent())) {
initHttpOptionsBuilder.baseUrl("https://aiplatform.googleapis.com");
} else if (locationValue != null && !customBaseUrl.isPresent()) {
initHttpOptionsBuilder.baseUrl(
String.format("https://%s-aiplatform.googleapis.com", locationValue));
}

this.apiKey = Optional.ofNullable(apiKeyValue);
this.project = Optional.ofNullable(projectValue);
this.location = Optional.ofNullable(locationValue);

// Validate that either project and location or API key is set.
if (!(this.project.isPresent() || this.apiKey.isPresent())) {
throw new IllegalArgumentException(
"For Vertex AI APIs, either project or API key must be set.");
}
this.customBaseUrl = customBaseUrl;

// Only set credentials if using project/location.
this.credentials =
Expand All @@ -219,7 +250,13 @@ protected ApiClient(

this.clientOptions = clientOptions;

this.httpOptions = defaultHttpOptions(/* vertexAI= */ true, this.location, this.apiKey);
this.httpOptions =
defaultHttpOptions(
/* vertexAI= */ true,
this.location,
this.apiKey,
this.customBaseUrl,
initHttpOptionsBuilder.build().baseUrl());

if (customHttpOptions.isPresent()) {
this.httpOptions = mergeHttpOptions(customHttpOptions.get());
Expand Down Expand Up @@ -324,19 +361,17 @@ protected Request buildRequest(
String requestJson,
Optional<HttpOptions> requestHttpOptions) {
String capitalizedHttpMethod = Ascii.toUpperCase(httpMethod);
boolean queryBaseModel =
capitalizedHttpMethod.equals("GET") && path.startsWith("publishers/google/models");
if (this.vertexAI()
&& !this.apiKey.isPresent()
&& !path.startsWith("projects/")
&& !queryBaseModel) {
HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null));

boolean prependProjectLocation =
shouldPrependVertexProjectPath(capitalizedHttpMethod, path, mergedHttpOptions);

if (prependProjectLocation) {
path =
String.format("projects/%s/locations/%s/", this.project.get(), this.location.get())
+ path;
}

HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null));

String requestUrl;

String baseUrl =
Expand Down Expand Up @@ -408,6 +443,7 @@ protected Request buildRequest(
byte[] requestBytes,
Optional<HttpOptions> requestHttpOptions) {
HttpOptions mergedHttpOptions = mergeHttpOptions(requestHttpOptions.orElse(null));

if (httpMethod.equalsIgnoreCase("POST")) {
RequestBody body =
RequestBody.create(requestBytes, MediaType.get("application/octet-stream"));
Expand Down Expand Up @@ -437,9 +473,8 @@ private void setHeaders(Request.Builder request, HttpOptions requestHttpOptions)
if (apiKey.isPresent()) {
// Sets API key for Gemini Developer API or Vertex AI Express mode
request.header("x-goog-api-key", apiKey.get());
} else {
GoogleCredentials cred =
credentials.orElseThrow(() -> new IllegalStateException("credentials is required"));
} else if (credentials.isPresent()) {
GoogleCredentials cred = credentials.get();
try {
cred.refreshIfExpired();
} catch (IOException e) {
Expand All @@ -451,6 +486,8 @@ private void setHeaders(Request.Builder request, HttpOptions requestHttpOptions)
if (cred.getQuotaProjectId() != null) {
request.header("x-goog-user-project", cred.getQuotaProjectId());
}
} else if (!customBaseUrl.isPresent()) {
throw new IllegalStateException("credentials is required");
}
}

Expand Down Expand Up @@ -504,11 +541,21 @@ public boolean vertexAI() {
return apiKey.orElse(null);
}

/** Returns the custom base URL if provided. */
public @Nullable String customBaseUrl() {
return customBaseUrl.orElse(null);
}

/** Returns the HttpClient for API calls. */
public OkHttpClient httpClient() {
return httpClient;
}

/** Returns the GoogleCredentials for Vertex AI APIs. */
public @Nullable GoogleCredentials credentials() {
return credentials.orElse(null);
}

/** Returns the HTTP options for API calls. */
public HttpOptions httpOptions() {
return httpOptions;
Expand Down Expand Up @@ -577,6 +624,7 @@ HttpOptions mergeHttpOptions(HttpOptions httpOptionsToApply) {
HttpOptions.Builder mergedHttpOptionsBuilder = this.httpOptions.toBuilder();

httpOptionsToApply.baseUrl().ifPresent(mergedHttpOptionsBuilder::baseUrl);
httpOptionsToApply.baseUrlResourceScope().ifPresent(mergedHttpOptionsBuilder::baseUrlResourceScope);
httpOptionsToApply.apiVersion().ifPresent(mergedHttpOptionsBuilder::apiVersion);
httpOptionsToApply.timeout().ifPresent(mergedHttpOptionsBuilder::timeout);
httpOptionsToApply.extraBody().ifPresent(mergedHttpOptionsBuilder::extraBody);
Expand All @@ -602,7 +650,11 @@ HttpOptions mergeHttpOptions(HttpOptions httpOptionsToApply) {
}

static HttpOptions defaultHttpOptions(
boolean vertexAI, Optional<String> location, Optional<String> apiKey) {
boolean vertexAI,
Optional<String> location,
Optional<String> apiKey,
Optional<String> customBaseUrl,
Optional<String> initBaseUrl) {
ImmutableMap.Builder<String, String> defaultHeaders = ImmutableMap.builder();
defaultHeaders.put("Content-Type", "application/json");
defaultHeaders.put("user-agent", libraryVersion());
Expand All @@ -618,6 +670,12 @@ static HttpOptions defaultHttpOptions(
vertexBaseUrl.orElseGet(() -> defaultEnvironmentVariables.get("vertexBaseUrl"));
if (defaultBaseUrl != null) {
defaultHttpOptionsBuilder.baseUrl(defaultBaseUrl);
} else if (initBaseUrl.isPresent()) {
defaultHttpOptionsBuilder.baseUrl(initBaseUrl.get());
} else if (customBaseUrl.isPresent()
&& !customBaseUrl.get().endsWith(".googleapis.com")
&& !(location.isPresent() || apiKey.isPresent())) {
defaultHttpOptionsBuilder.baseUrl(customBaseUrl.get());
} else if (apiKey.isPresent() || location.get().equalsIgnoreCase("global")) {
defaultHttpOptionsBuilder.baseUrl("https://aiplatform.googleapis.com");
} else {
Expand All @@ -638,6 +696,30 @@ static HttpOptions defaultHttpOptions(
return defaultHttpOptionsBuilder.build();
}

boolean shouldPrependVertexProjectPath(
String httpMethod, String path, HttpOptions httpOptions) {
if (httpOptions.baseUrlResourceScope().isPresent()
&& httpOptions.baseUrl().isPresent()
&& httpOptions.baseUrlResourceScope().get().knownEnum() == Known.COLLECTION) {
return false;
}
if (this.apiKey.isPresent()) {
return false;
}
if (!this.vertexAI) {
return false;
}
if (path.startsWith("projects/")) {
return false;
}
// These paths are used by Vertex's models.get and models.list calls. For base models Vertex
// does not accept a project/location prefix (for tuned model the prefix is required).
if (httpMethod.equals("GET") && path.startsWith("publishers/google/models")) {
return false;
}
return true;
}

GoogleCredentials defaultCredentials() {
try {
return GoogleCredentials.getApplicationDefault()
Expand Down
44 changes: 30 additions & 14 deletions src/main/java/com/google/genai/AsyncLive.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public CompletableFuture<AsyncSession> connect(String model, LiveConnectConfig c

/** Gets the URI for the websocket connection. */
private URI getWebSocketUri() {
String baseUrl = apiClient.httpOptions.baseUrl().orElse(null);
String baseUrl = apiClient.httpOptions().baseUrl().orElse(null);
if (baseUrl == null) {
throw new IllegalArgumentException("No base URL provided in the client.");
}
Expand All @@ -102,13 +102,20 @@ private URI getWebSocketUri() {
baseUri.getFragment())
.toString();

boolean hasStandardAuth =
(apiClient.project() != null && apiClient.location() != null)
|| apiClient.apiKey() != null;
if (apiClient.customBaseUrl() != null && !hasStandardAuth) {
return new URI(wsBaseUrl);
}

if (!apiClient.vertexAI()) {
String method;
if (apiClient.apiKey().startsWith("auth_tokens/")) {
if (apiClient.apiKey() != null && apiClient.apiKey().startsWith("auth_tokens/")) {
logger.warning(
"Warning: Ephemeral token support is experimental and may change in future"
+ " versions.");
if (!apiClient.httpOptions.apiVersion().orElse("v1beta").equals("v1alpha")) {
if (!apiClient.httpOptions().apiVersion().orElse("v1beta").equals("v1alpha")) {
logger.warning(
"Warning: The SDK's ephemeral token support is in v1alpha only. Please use client"
+ " = Client.builder().httpOptions(HttpOptions.builder().apiVersion(\"v1alpha\").build()).build()"
Expand All @@ -121,12 +128,12 @@ private URI getWebSocketUri() {
return new URI(
String.format(
"%s/ws/google.ai.generativelanguage.%s.GenerativeService.%s",
wsBaseUrl, apiClient.httpOptions.apiVersion().orElse("v1beta"), method));
wsBaseUrl, apiClient.httpOptions().apiVersion().orElse("v1beta"), method));
} else {
return new URI(
String.format(
"%s/ws/google.cloud.aiplatform.%s.LlmBidiService/BidiGenerateContent",
wsBaseUrl, apiClient.httpOptions.apiVersion().orElse("v1beta1")));
wsBaseUrl, apiClient.httpOptions().apiVersion().orElse("v1beta1")));
}
} catch (URISyntaxException e) {
throw new IllegalStateException("Failed to parse URL.", e);
Expand All @@ -136,16 +143,22 @@ private URI getWebSocketUri() {
/** Gets the headers for the websocket connection. */
private Map<String, String> getWebSocketHeaders() {
Map<String, String> headers = new HashMap<>();
apiClient.httpOptions.headers().ifPresent(headers::putAll);
apiClient.httpOptions().headers().ifPresent(headers::putAll);

if (apiClient.vertexAI()) {
try {
GoogleCredentials credentials =
apiClient.credentials.orElseGet(() -> apiClient.defaultCredentials());
credentials.refreshIfExpired();
headers.put("Authorization", "Bearer " + credentials.getAccessToken().getTokenValue());
} catch (IOException e) {
throw new GenAiIOException("Failed to refresh credentials for Vertex AI.", e);
if (apiClient.credentials() != null) {
try {
GoogleCredentials credentials = apiClient.credentials();
credentials.refreshIfExpired();
headers.put("Authorization", "Bearer " + credentials.getAccessToken().getTokenValue());
if (credentials.getQuotaProjectId() != null) {
headers.put("x-goog-user-project", credentials.getQuotaProjectId());
}
} catch (IOException e) {
throw new GenAiIOException("Failed to refresh credentials for Vertex AI.", e);
}
} else if (apiClient.apiKey() != null) {
headers.put("x-goog-api-key", apiClient.apiKey());
}
} else {
@Nullable String apiKey = apiClient.apiKey();
Expand All @@ -165,7 +178,10 @@ private String getSetupRequest(String model, LiveConnectConfig config) {

String transformedModel = Transformers.tModel(apiClient, model);
// Vertex requires the full resource path for the model.
if (apiClient.vertexAI() && transformedModel.startsWith("publishers/")) {
if (apiClient.vertexAI()
&& transformedModel.startsWith("publishers/")
&& apiClient.project() != null
&& apiClient.location() != null) {
model =
String.format(
"projects/%s/locations/%s/%s",
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/com/google/genai/types/HttpOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public abstract class HttpOptions extends JsonSerializable {
@JsonProperty("baseUrl")
public abstract Optional<String> baseUrl();

/** The resource scope used to constructing the resource name when base_url is set */
@JsonProperty("baseUrlResourceScope")
public abstract Optional<ResourceScope> baseUrlResourceScope();

/** Specifies the version of the API to use. */
@JsonProperty("apiVersion")
public abstract Optional<String> apiVersion();
Expand Down Expand Up @@ -95,6 +99,47 @@ public Builder clearBaseUrl() {
return baseUrl(Optional.empty());
}

/**
* Setter for baseUrlResourceScope.
*
* <p>baseUrlResourceScope: The resource scope used to constructing the resource name when
* base_url is set
*/
@JsonProperty("baseUrlResourceScope")
public abstract Builder baseUrlResourceScope(ResourceScope baseUrlResourceScope);

@ExcludeFromGeneratedCoverageReport
abstract Builder baseUrlResourceScope(Optional<ResourceScope> baseUrlResourceScope);

/** Clears the value of baseUrlResourceScope field. */
@ExcludeFromGeneratedCoverageReport
@CanIgnoreReturnValue
public Builder clearBaseUrlResourceScope() {
return baseUrlResourceScope(Optional.empty());
}

/**
* Setter for baseUrlResourceScope given a known enum.
*
* <p>baseUrlResourceScope: The resource scope used to constructing the resource name when
* base_url is set
*/
@CanIgnoreReturnValue
public Builder baseUrlResourceScope(ResourceScope.Known knownType) {
return baseUrlResourceScope(new ResourceScope(knownType));
}

/**
* Setter for baseUrlResourceScope given a string.
*
* <p>baseUrlResourceScope: The resource scope used to constructing the resource name when
* base_url is set
*/
@CanIgnoreReturnValue
public Builder baseUrlResourceScope(String baseUrlResourceScope) {
return baseUrlResourceScope(new ResourceScope(baseUrlResourceScope));
}

/**
* Setter for apiVersion.
*
Expand Down
Loading
Loading