Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 40 additions & 54 deletions core/src/main/java/com/google/adk/sessions/ApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

package com.google.adk.sessions;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.StandardSystemProperty.JAVA_VERSION;

import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.base.Ascii;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.genai.errors.GenAiIOException;
import com.google.genai.types.HttpOptions;
Expand All @@ -35,83 +35,69 @@
abstract class ApiClient {
OkHttpClient httpClient;
// For Google AI APIs
final Optional<String> apiKey;
final @Nullable String apiKey;
// For Vertex AI APIs
final Optional<String> project;
final Optional<String> location;
final Optional<GoogleCredentials> credentials;
final @Nullable String project;
final @Nullable String location;
final @Nullable GoogleCredentials credentials;
HttpOptions httpOptions;
final boolean vertexAI;

/** Constructs an ApiClient for Google AI APIs. */
ApiClient(Optional<String> apiKey, Optional<HttpOptions> customHttpOptions) {
checkNotNull(apiKey, "API Key cannot be null");
checkNotNull(customHttpOptions, "customHttpOptions cannot be null");
ApiClient(@Nullable String apiKey, @Nullable HttpOptions customHttpOptions) {

try {
this.apiKey = Optional.of(apiKey.orElseGet(() -> System.getenv("GOOGLE_API_KEY")));
} catch (NullPointerException e) {
this.apiKey = apiKey != null ? apiKey : System.getenv("GOOGLE_API_KEY");

if (Strings.isNullOrEmpty(this.apiKey)) {
throw new IllegalArgumentException(
"API key must either be provided or set in the environment variable" + " GOOGLE_API_KEY.",
e);
"API key must either be provided or set in the environment variable"
+ " GOOGLE_API_KEY.");
}

this.project = Optional.empty();
this.location = Optional.empty();
this.credentials = Optional.empty();
this.project = null;
this.location = null;
this.credentials = null;
this.vertexAI = false;

this.httpOptions = defaultHttpOptions(/* vertexAI= */ false, this.location);

if (customHttpOptions.isPresent()) {
applyHttpOptions(customHttpOptions.get());
if (customHttpOptions != null) {
applyHttpOptions(customHttpOptions);
}

this.httpClient = createHttpClient(httpOptions.timeout().orElse(null));
}

ApiClient(
Optional<String> project,
Optional<String> location,
Optional<GoogleCredentials> credentials,
Optional<HttpOptions> customHttpOptions) {
checkNotNull(project, "project cannot be null");
checkNotNull(location, "location cannot be null");
checkNotNull(credentials, "credentials cannot be null");
checkNotNull(customHttpOptions, "customHttpOptions cannot be null");
@Nullable String project,
@Nullable String location,
@Nullable GoogleCredentials credentials,
@Nullable HttpOptions customHttpOptions) {

try {
this.project = Optional.of(project.orElseGet(() -> System.getenv("GOOGLE_CLOUD_PROJECT")));
} catch (NullPointerException e) {
this.project = project != null ? project : System.getenv("GOOGLE_CLOUD_PROJECT");

if (Strings.isNullOrEmpty(this.project)) {
throw new IllegalArgumentException(
"Project must either be provided or set in the environment variable"
+ " GOOGLE_CLOUD_PROJECT.",
e);
}
if (this.project.get().isEmpty()) {
throw new IllegalArgumentException("Project must not be empty.");
+ " GOOGLE_CLOUD_PROJECT.");
}

try {
this.location = Optional.of(location.orElse(System.getenv("GOOGLE_CLOUD_LOCATION")));
} catch (NullPointerException e) {
this.location = location != null ? location : System.getenv("GOOGLE_CLOUD_LOCATION");

if (Strings.isNullOrEmpty(this.location)) {
throw new IllegalArgumentException(
"Location must either be provided or set in the environment variable"
+ " GOOGLE_CLOUD_LOCATION.",
e);
}
if (this.location.get().isEmpty()) {
throw new IllegalArgumentException("Location must not be empty.");
+ " GOOGLE_CLOUD_LOCATION.");
}

this.credentials = Optional.of(credentials.orElseGet(this::defaultCredentials));
this.credentials = credentials != null ? credentials : defaultCredentials();

this.httpOptions = defaultHttpOptions(/* vertexAI= */ true, this.location);

if (customHttpOptions.isPresent()) {
applyHttpOptions(customHttpOptions.get());
if (customHttpOptions != null) {
applyHttpOptions(customHttpOptions);
}
this.apiKey = Optional.empty();
this.apiKey = null;
this.vertexAI = true;
this.httpClient = createHttpClient(httpOptions.timeout().orElse(null));
}
Expand Down Expand Up @@ -142,17 +128,17 @@ public boolean vertexAI() {

/** Returns the project ID for Vertex AI APIs. */
public @Nullable String project() {
return project.orElse(null);
return project;
}

/** Returns the location for Vertex AI APIs. */
public @Nullable String location() {
return location.orElse(null);
return location;
}

/** Returns the API key for Google AI APIs. */
public @Nullable String apiKey() {
return apiKey.orElse(null);
return apiKey;
}

/** Returns the HttpClient for API calls. */
Expand Down Expand Up @@ -192,7 +178,7 @@ private void applyHttpOptions(HttpOptions httpOptionsToApply) {
this.httpOptions = mergedHttpOptionsBuilder.build();
}

static HttpOptions defaultHttpOptions(boolean vertexAI, Optional<String> location) {
static HttpOptions defaultHttpOptions(boolean vertexAI, @Nullable String location) {
ImmutableMap.Builder<String, String> defaultHeaders = ImmutableMap.builder();
defaultHeaders
.put("Content-Type", "application/json")
Expand All @@ -202,14 +188,14 @@ static HttpOptions defaultHttpOptions(boolean vertexAI, Optional<String> locatio
HttpOptions.Builder defaultHttpOptionsBuilder =
HttpOptions.builder().headers(defaultHeaders.buildOrThrow());

if (vertexAI && location.isPresent()) {
if (vertexAI && location != null) {
defaultHttpOptionsBuilder
.baseUrl(
Ascii.equalsIgnoreCase(location.get(), "global")
Ascii.equalsIgnoreCase(location, "global")
? "https://aiplatform.googleapis.com"
: String.format("https://%s-aiplatform.googleapis.com", location.get()))
: String.format("https://%s-aiplatform.googleapis.com", location))
.apiVersion("v1beta1");
} else if (vertexAI && location.isEmpty()) {
} else if (vertexAI && Strings.isNullOrEmpty(location)) {
throw new IllegalArgumentException("Location must be provided for Vertex AI APIs.");
} else {
defaultHttpOptionsBuilder
Expand Down
25 changes: 12 additions & 13 deletions core/src/main/java/com/google/adk/sessions/HttpApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,34 @@

import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.base.Ascii;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.genai.errors.GenAiIOException;
import com.google.genai.types.HttpOptions;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.jspecify.annotations.Nullable;

/** Base client for the HTTP APIs. */
public class HttpApiClient extends ApiClient {
public static final MediaType MEDIA_TYPE_APPLICATION_JSON =
MediaType.parse("application/json; charset=utf-8");

/** Constructs an ApiClient for Google AI APIs. */
HttpApiClient(Optional<String> apiKey, Optional<HttpOptions> httpOptions) {
HttpApiClient(@Nullable String apiKey, @Nullable HttpOptions httpOptions) {
super(apiKey, httpOptions);
}

/** Constructs an ApiClient for Vertex AI APIs. */
HttpApiClient(
Optional<String> project,
Optional<String> location,
Optional<GoogleCredentials> credentials,
Optional<HttpOptions> httpOptions) {
@Nullable String project,
@Nullable String location,
@Nullable GoogleCredentials credentials,
@Nullable HttpOptions httpOptions) {
super(project, location, credentials, httpOptions);
}

Expand All @@ -54,9 +55,7 @@ public ApiResponse request(String httpMethod, String path, String requestJson) {
boolean queryBaseModel =
Ascii.equalsIgnoreCase(httpMethod, "GET") && path.startsWith("publishers/google/models/");
if (this.vertexAI() && !path.startsWith("projects/") && !queryBaseModel) {
path =
String.format("projects/%s/locations/%s/", this.project.get(), this.location.get())
+ path;
path = String.format("projects/%s/locations/%s/", this.project, this.location) + path;
}
String requestUrl =
String.format(
Expand Down Expand Up @@ -85,11 +84,11 @@ private void setHeaders(Request.Builder requestBuilder) {
requestBuilder.header(header.getKey(), header.getValue());
}

if (apiKey.isPresent()) {
requestBuilder.header("x-goog-api-key", apiKey.get());
if (apiKey != null) {
requestBuilder.header("x-goog-api-key", apiKey);
} else {
GoogleCredentials cred =
credentials.orElseThrow(() -> new IllegalStateException("credentials is required"));
Preconditions.checkState(credentials != null, "credentials is required");
GoogleCredentials cred = credentials;
try {
cred.refreshIfExpired();
} catch (IOException e) {
Expand Down
11 changes: 4 additions & 7 deletions core/src/main/java/com/google/adk/sessions/VertexAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
import okhttp3.ResponseBody;
Expand All @@ -37,17 +36,15 @@ final class VertexAiClient {
}

VertexAiClient() {
this.apiClient =
new HttpApiClient(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
this.apiClient = new HttpApiClient((String) null, null, null, null);
}

VertexAiClient(
String project,
String location,
Optional<GoogleCredentials> credentials,
Optional<HttpOptions> httpOptions) {
this.apiClient =
new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions);
@Nullable GoogleCredentials credentials,
@Nullable HttpOptions httpOptions) {
this.apiClient = new HttpApiClient(project, location, credentials, httpOptions);
}

Maybe<JsonNode> createSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
import org.jspecify.annotations.Nullable;

/** Connects to the managed Vertex AI Session Service. */
// TODO: Use the genai HttpApiClient and ApiResponse methods once they are public.
Expand All @@ -65,8 +65,8 @@ public VertexAiSessionService() {
public VertexAiSessionService(
String project,
String location,
Optional<GoogleCredentials> credentials,
Optional<HttpOptions> httpOptions) {
@Nullable GoogleCredentials credentials,
@Nullable HttpOptions httpOptions) {
this.client = new VertexAiClient(project, location, credentials, httpOptions);
}

Expand Down
Loading