Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void testAPIKeySecurityScheme() {
AuthTestCase authTestCase = new AuthTestCase(
"http://agent.com/rpc",
"session-id",
APIKeySecurityScheme.API_KEY,
APIKeySecurityScheme.TYPE,
"secret-api-key",
new APIKeySecurityScheme(APIKeySecurityScheme.Location.HEADER, "x-api-key", "API Key authentication"),
"x-api-key",
Expand All @@ -91,7 +91,7 @@ public void testOAuth2SecurityScheme() {
AuthTestCase authTestCase = new AuthTestCase(
"http://agent.com/rpc",
"session-id",
OAuth2SecurityScheme.OAUTH2,
OAuth2SecurityScheme.TYPE,
"secret-oauth-access-token",
new OAuth2SecurityScheme(OAuthFlows.builder().build(), "OAuth2 authentication", null),
"Authorization",
Expand All @@ -105,7 +105,7 @@ public void testOidcSecurityScheme() {
AuthTestCase authTestCase = new AuthTestCase(
"http://agent.com/rpc",
"session-id",
OpenIdConnectSecurityScheme.OPENID_CONNECT,
OpenIdConnectSecurityScheme.TYPE,
"secret-oidc-id-token",
new OpenIdConnectSecurityScheme("http://provider.com/.well-known/openid-configuration", "OIDC authentication"),
"Authorization",
Expand Down
232 changes: 192 additions & 40 deletions jsonrpc-common/src/main/java/io/a2a/jsonrpc/common/json/JsonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
import static io.a2a.spec.A2AErrorCodes.TASK_NOT_CANCELABLE_ERROR_CODE;
import static io.a2a.spec.A2AErrorCodes.TASK_NOT_FOUND_ERROR_CODE;
import static io.a2a.spec.A2AErrorCodes.UNSUPPORTED_OPERATION_ERROR_CODE;
import static io.a2a.spec.DataPart.DATA;
import static io.a2a.spec.FilePart.FILE;
import static io.a2a.spec.TextPart.TEXT;
import static java.lang.String.format;

import java.io.StringReader;
import java.lang.reflect.Type;
import java.time.OffsetDateTime;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.util.Map;
import java.util.Set;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
Expand All @@ -29,21 +34,28 @@
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;
import com.google.gson.stream.JsonWriter;

import io.a2a.spec.A2AError;
import io.a2a.spec.APIKeySecurityScheme;
import io.a2a.spec.ContentTypeNotSupportedError;
import io.a2a.spec.DataPart;
import io.a2a.spec.FileContent;
import io.a2a.spec.FilePart;
import io.a2a.spec.FileWithBytes;
import io.a2a.spec.FileWithUri;
import io.a2a.spec.HTTPAuthSecurityScheme;
import io.a2a.spec.InvalidAgentResponseError;
import io.a2a.spec.InvalidParamsError;
import io.a2a.spec.InvalidRequestError;
import io.a2a.spec.JSONParseError;
import io.a2a.spec.Message;
import io.a2a.spec.MethodNotFoundError;
import io.a2a.spec.MutualTLSSecurityScheme;
import io.a2a.spec.OAuth2SecurityScheme;
import io.a2a.spec.OpenIdConnectSecurityScheme;
import io.a2a.spec.Part;
import io.a2a.spec.PushNotificationNotSupportedError;
import io.a2a.spec.SecurityScheme;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskArtifactUpdateEvent;
Expand All @@ -53,6 +65,7 @@
import io.a2a.spec.TaskStatusUpdateEvent;
import io.a2a.spec.TextPart;
import io.a2a.spec.UnsupportedOperationError;

import org.jspecify.annotations.Nullable;

/**
Expand Down Expand Up @@ -83,6 +96,7 @@ private static GsonBuilder createBaseGsonBuilder() {
public static final Gson OBJECT_MAPPER = createBaseGsonBuilder()
.registerTypeHierarchyAdapter(Part.class, new PartTypeAdapter())
.registerTypeHierarchyAdapter(StreamingEventKind.class, new StreamingEventKindTypeAdapter())
.registerTypeHierarchyAdapter(SecurityScheme.class, new SecuritySchemeTypeAdapter())
.create();

/**
Expand Down Expand Up @@ -530,6 +544,8 @@ public void write(JsonWriter out, Message.Role value) throws java.io.IOException
*/
static class PartTypeAdapter extends TypeAdapter<Part<?>> {

private static final Set<String> VALID_KEYS = Set.of(TEXT, FILE, DATA);

// Create separate Gson instance without the Part adapter to avoid recursion
private final Gson delegateGson = createBaseGsonBuilder().create();

Expand All @@ -539,21 +555,20 @@ public void write(JsonWriter out, Part<?> value) throws java.io.IOException {
out.nullValue();
return;
}

// Write wrapper object with member name as discriminator
out.beginObject();

if (value instanceof TextPart textPart) {
// TextPart: { "text": "value" } - direct string value
out.name("text");
out.name(TEXT);
out.value(textPart.text());
} else if (value instanceof FilePart filePart) {
// FilePart: { "file": {...} }
out.name("file");
out.name(FILE);
delegateGson.toJson(filePart.file(), FileContent.class, out);
} else if (value instanceof DataPart dataPart) {
// DataPart: { "data": {...} }
out.name("data");
out.name(DATA);
delegateGson.toJson(dataPart.data(), Map.class, out);
} else {
throw new JsonSyntaxException("Unknown Part subclass: " + value.getClass().getName());
Expand All @@ -579,23 +594,27 @@ Part<?> read(JsonReader in) throws java.io.IOException {
com.google.gson.JsonObject jsonObject = jsonElement.getAsJsonObject();

// Check for member name discriminators (v1.0 protocol)
if (jsonObject.has("text")) {
// TextPart: { "text": "value" } - direct string value
return new TextPart(jsonObject.get("text").getAsString());
} else if (jsonObject.has("file")) {
// FilePart: { "file": {...} }
return new FilePart(delegateGson.fromJson(jsonObject.get("file"), FileContent.class));
} else if (jsonObject.has("data")) {
// DataPart: { "data": {...} }
@SuppressWarnings("unchecked")
Map<String, Object> dataMap = delegateGson.fromJson(
jsonObject.get("data"),
new TypeToken<Map<String, Object>>(){}.getType()
);
return new DataPart(dataMap);
} else {
throw new JsonSyntaxException("Part must have one of: text, file, data (found: " + jsonObject.keySet() + ")");
}
Set<String> keys = jsonObject.keySet();
if (keys.size() != 1) {
throw new JsonSyntaxException(format("Part object must have exactly one key, which must be one of: %s (found: %s)", VALID_KEYS, keys));
}

String discriminator = keys.iterator().next();

return switch (discriminator) {
case TEXT -> new TextPart(jsonObject.get(TEXT).getAsString());
case FILE -> new FilePart(delegateGson.fromJson(jsonObject.get(FILE), FileContent.class));
case DATA -> {
@SuppressWarnings("unchecked")
Map<String, Object> dataMap = delegateGson.fromJson(
jsonObject.get(DATA),
new TypeToken<Map<String, Object>>(){}.getType()
);
yield new DataPart(dataMap);
}
default ->
throw new JsonSyntaxException(format("Part must have one of: %s (found: %s)", VALID_KEYS, discriminator));
};
}
}

Expand Down Expand Up @@ -627,20 +646,10 @@ public void write(JsonWriter out, StreamingEventKind value) throws java.io.IOExc
out.nullValue();
return;
}

// Write wrapper object with member name as discriminator
out.beginObject();

Type type = switch (value.kind()) {
case Task.STREAMING_EVENT_ID -> Task.class;
case Message.STREAMING_EVENT_ID -> Message.class;
case TaskStatusUpdateEvent.STREAMING_EVENT_ID -> TaskStatusUpdateEvent.class;
case TaskArtifactUpdateEvent.STREAMING_EVENT_ID -> TaskArtifactUpdateEvent.class;
default -> throw new JsonSyntaxException("Unknown StreamingEventKind implementation: " + value.getClass().getName());
};

out.name(value.kind());
delegateGson.toJson(value, type, out);
delegateGson.toJson(value, value.getClass(), out);
out.endObject();
}

Expand Down Expand Up @@ -714,7 +723,9 @@ StreamingEventKind read(JsonReader in) throws java.io.IOException {
static class FileContentTypeAdapter extends TypeAdapter<FileContent> {

// Create separate Gson instance without the FileContent adapter to avoid recursion
private final Gson delegateGson = new Gson();
private final Gson delegateGson = new GsonBuilder()
.registerTypeAdapter(OffsetDateTime.class, new OffsetDateTimeTypeAdapter())
.create();

@Override
public void write(JsonWriter out, FileContent value) throws java.io.IOException {
Expand All @@ -723,13 +734,7 @@ public void write(JsonWriter out, FileContent value) throws java.io.IOException
return;
}
// Delegate to Gson's default serialization for the concrete type
if (value instanceof FileWithBytes fileWithBytes) {
delegateGson.toJson(fileWithBytes, FileWithBytes.class, out);
} else if (value instanceof FileWithUri fileWithUri) {
delegateGson.toJson(fileWithUri, FileWithUri.class, out);
} else {
throw new JsonSyntaxException("Unknown FileContent implementation: " + value.getClass().getName());
}
delegateGson.toJson(value, value.getClass(), out);
}

@Override
Expand Down Expand Up @@ -759,4 +764,151 @@ FileContent read(JsonReader in) throws java.io.IOException {
}
}

/**
* Gson TypeAdapter for serializing and deserializing {@link APIKeySecurityScheme.Location} enum.
* <p>
* This adapter ensures that Location enum values are serialized using their
* wire format string representation (e.g., "header") rather than
* the Java enum constant name (e.g., "HEADER").
* <p>
* For serialization, it uses {@link APIKeySecurityScheme.Location#asString()} to get the wire format.
* For deserialization, it uses {@link APIKeySecurityScheme.Location#fromString(String)} to parse the
* wire format back to the enum constant.
*
* @see APIKeySecurityScheme.Location
*/
static class APIKeyLocationTypeAdapter extends TypeAdapter<APIKeySecurityScheme.Location> {

@Override
public void write(JsonWriter out, APIKeySecurityScheme.Location value) throws java.io.IOException {
if (value == null) {
out.nullValue();
return;
}
out.value(value.asString());
}

@Override
public APIKeySecurityScheme.@Nullable Location read(JsonReader in) throws java.io.IOException {
if (in.peek() == JsonToken.NULL) {
in.nextNull();
return null;
}
String locationString = in.nextString();
try {
return APIKeySecurityScheme.Location.fromString(locationString);
} catch (IllegalArgumentException e) {
throw new JsonSyntaxException("Invalid APIKeySecurityScheme.Location: " + locationString, e);
}
}
}

/**
* Gson TypeAdapter for serializing and deserializing {@link SecurityScheme} and its implementations.
* <p>
* This adapter handles polymorphic deserialization for the sealed SecurityScheme interface,
* which permits five implementations:
* <ul>
* <li>{@link APIKeySecurityScheme} - API key authentication</li>
* <li>{@link HTTPAuthSecurityScheme} - HTTP authentication (basic or bearer)</li>
* <li>{@link OAuth2SecurityScheme} - OAuth 2.0 flows</li>
* <li>{@link OpenIdConnectSecurityScheme} - OpenID Connect discovery</li>
* <li>{@link MutualTLSSecurityScheme} - Client certificate authentication</li>
* </ul>
* <p>
* The adapter uses a wrapper object with the security scheme type as the discriminator field.
* Each SecurityScheme is serialized as a JSON object with a single field whose name identifies
* the security scheme type.
* <p>
* Serialization format examples:
* <pre>{@code
* // HTTPAuthSecurityScheme
* {
* "httpAuthSecurityScheme": {
* "scheme": "bearer",
* "bearerFormat": "JWT",
* "description": "..."
* }
* }
*
* // APIKeySecurityScheme
* {
* "apiKeySecurityScheme": {
* "location": "header",
* "name": "X-API-Key",
* "description": "..."
* }
* }
* }</pre>
*
* @see SecurityScheme
* @see APIKeySecurityScheme
* @see HTTPAuthSecurityScheme
* @see OAuth2SecurityScheme
* @see OpenIdConnectSecurityScheme
* @see MutualTLSSecurityScheme
*/
static class SecuritySchemeTypeAdapter extends TypeAdapter<SecurityScheme> {

private static final Set<String> VALID_KEYS = Set.of(APIKeySecurityScheme.TYPE,
HTTPAuthSecurityScheme.TYPE,
OAuth2SecurityScheme.TYPE,
OpenIdConnectSecurityScheme.TYPE,
MutualTLSSecurityScheme.TYPE);

// Create separate Gson instance without the SecurityScheme adapter to avoid recursion
// Register custom adapter for APIKeySecurityScheme.Location enum
private final Gson delegateGson = createBaseGsonBuilder()
.registerTypeAdapter(APIKeySecurityScheme.Location.class, new APIKeyLocationTypeAdapter())
.create();

@Override
public void write(JsonWriter out, SecurityScheme value) throws java.io.IOException {
if (value == null) {
out.nullValue();
return;
}

// Write wrapper object with member name as discriminator
out.beginObject();
out.name(value.type());
delegateGson.toJson(value, value.getClass(), out);
out.endObject();
}

@Override
public @Nullable
SecurityScheme read(JsonReader in) throws java.io.IOException {
if (in.peek() == JsonToken.NULL) {
in.nextNull();
return null;
}

// Read the JSON as a tree to inspect the member name discriminator
com.google.gson.JsonElement jsonElement = com.google.gson.JsonParser.parseReader(in);
if (!jsonElement.isJsonObject()) {
throw new JsonSyntaxException("SecurityScheme must be a JSON object");
}

com.google.gson.JsonObject jsonObject = jsonElement.getAsJsonObject();

// Check for member name discriminators
Set<String> keys = jsonObject.keySet();
if (keys.size() != 1) {
throw new JsonSyntaxException(format("A SecurityScheme object must have exactly one key, which must be one of: %s (found: %s)", VALID_KEYS, keys));
}

String discriminator = keys.iterator().next();
com.google.gson.JsonElement nestedObject = jsonObject.get(discriminator);

return switch (discriminator) {
case APIKeySecurityScheme.TYPE -> delegateGson.fromJson(nestedObject, APIKeySecurityScheme.class);
case HTTPAuthSecurityScheme.TYPE -> delegateGson.fromJson(nestedObject, HTTPAuthSecurityScheme.class);
case OAuth2SecurityScheme.TYPE -> delegateGson.fromJson(nestedObject, OAuth2SecurityScheme.class);
case OpenIdConnectSecurityScheme.TYPE -> delegateGson.fromJson(nestedObject, OpenIdConnectSecurityScheme.class);
case MutualTLSSecurityScheme.TYPE -> delegateGson.fromJson(nestedObject, MutualTLSSecurityScheme.class);
default -> throw new JsonSyntaxException(format("Unknown SecurityScheme type. Must be one of: %s (found: %s)", VALID_KEYS, discriminator));
};
}
}
}
Loading