Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 44 additions & 69 deletions a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ public static Optional<TextPart> toTextPart(io.a2a.spec.Part<?> part) {
}

/** Convert an A2A JSON part into a Google GenAI part representation. */
public static Optional<com.google.genai.types.Part> toGenaiPart(io.a2a.spec.Part<?> a2aPart) {
public static com.google.genai.types.Part toGenaiPart(io.a2a.spec.Part<?> a2aPart) {
if (a2aPart == null) {
return Optional.empty();
throw new IllegalArgumentException("A2A part cannot be null");
}

if (a2aPart instanceof TextPart textPart) {
return Optional.of(com.google.genai.types.Part.builder().text(textPart.getText()).build());
return com.google.genai.types.Part.builder().text(textPart.getText()).build();
}

if (a2aPart instanceof FilePart filePart) {
Expand All @@ -95,56 +95,41 @@ public static Optional<com.google.genai.types.Part> toGenaiPart(io.a2a.spec.Part
return convertDataPartToGenAiPart(dataPart);
}

logger.warn("Unsupported A2A part type: {}", a2aPart.getClass());
return Optional.empty();
throw new IllegalArgumentException("Unsupported A2A part type: " + a2aPart.getClass());
}

public static ImmutableList<com.google.genai.types.Part> toGenaiParts(
List<io.a2a.spec.Part<?>> a2aParts) {
return a2aParts.stream()
.map(PartConverter::toGenaiPart)
.flatMap(Optional::stream)
.collect(toImmutableList());
return a2aParts.stream().map(PartConverter::toGenaiPart).collect(toImmutableList());
}

private static Optional<com.google.genai.types.Part> convertFilePartToGenAiPart(
FilePart filePart) {
private static com.google.genai.types.Part convertFilePartToGenAiPart(FilePart filePart) {
FileContent fileContent = filePart.getFile();
if (fileContent instanceof FileWithUri fileWithUri) {
return Optional.of(
com.google.genai.types.Part.builder()
.fileData(
FileData.builder()
.fileUri(fileWithUri.uri())
.mimeType(fileWithUri.mimeType())
.build())
.build());
return com.google.genai.types.Part.builder()
.fileData(
FileData.builder()
.fileUri(fileWithUri.uri())
.mimeType(fileWithUri.mimeType())
.build())
.build();
}

if (fileContent instanceof FileWithBytes fileWithBytes) {
String bytesString = fileWithBytes.bytes();
if (bytesString == null) {
logger.warn("FileWithBytes missing byte content");
return Optional.empty();
}
try {
byte[] decoded = Base64.getDecoder().decode(bytesString);
return Optional.of(
com.google.genai.types.Part.builder()
.inlineData(Blob.builder().data(decoded).mimeType(fileWithBytes.mimeType()).build())
.build());
} catch (IllegalArgumentException e) {
logger.warn("Failed to decode base64 file content", e);
return Optional.empty();
throw new GenAiFieldMissingException("FileWithBytes missing byte content");
}
byte[] decoded = Base64.getDecoder().decode(bytesString);
return com.google.genai.types.Part.builder()
.inlineData(Blob.builder().data(decoded).mimeType(fileWithBytes.mimeType()).build())
.build();
}

logger.warn("Unsupported FilePart content: {}", fileContent.getClass());
return Optional.empty();
throw new IllegalArgumentException("Unsupported FilePart content: " + fileContent.getClass());
}

private static Optional<com.google.genai.types.Part> convertDataPartToGenAiPart(
DataPart dataPart) {
private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart dataPart) {
Map<String, Object> data =
Optional.ofNullable(dataPart.getData()).map(HashMap::new).orElseGet(HashMap::new);
Map<String, Object> metadata =
Expand All @@ -154,67 +139,57 @@ private static Optional<com.google.genai.types.Part> convertDataPartToGenAiPart(

if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY))
|| metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) {
String functionName = String.valueOf(data.getOrDefault(NAME_KEY, null));
String functionId = String.valueOf(data.getOrDefault(ID_KEY, null));
String functionName = String.valueOf(data.getOrDefault(NAME_KEY, ""));
String functionId = String.valueOf(data.getOrDefault(ID_KEY, ""));
Map<String, Object> args = coerceToMap(data.get(ARGS_KEY));
return Optional.of(
com.google.genai.types.Part.builder()
.functionCall(
FunctionCall.builder().name(functionName).id(functionId).args(args).build())
.build());
return com.google.genai.types.Part.builder()
.functionCall(FunctionCall.builder().name(functionName).id(functionId).args(args).build())
.build();
}

if ((data.containsKey(NAME_KEY) && data.containsKey(RESPONSE_KEY))
|| metadataType.equals(A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())) {
String functionName = String.valueOf(data.getOrDefault(NAME_KEY, ""));
String functionId = String.valueOf(data.getOrDefault(ID_KEY, ""));
Map<String, Object> response = coerceToMap(data.get(RESPONSE_KEY));
return Optional.of(
com.google.genai.types.Part.builder()
.functionResponse(
FunctionResponse.builder()
.name(functionName)
.id(functionId)
.response(response)
.build())
.build());
return com.google.genai.types.Part.builder()
.functionResponse(
FunctionResponse.builder()
.name(functionName)
.id(functionId)
.response(response)
.build())
.build();
}

if ((data.containsKey(CODE_KEY) && data.containsKey(LANGUAGE_KEY))
|| metadataType.equals(A2ADataPartMetadataType.EXECUTABLE_CODE.getType())) {
String code = String.valueOf(data.getOrDefault(CODE_KEY, ""));
String language =
String.valueOf(
data.getOrDefault(LANGUAGE_KEY, Language.Known.LANGUAGE_UNSPECIFIED.toString())
.toString());
return Optional.of(
com.google.genai.types.Part.builder()
.executableCode(
ExecutableCode.builder().code(code).language(new Language(language)).build())
.build());
data.getOrDefault(LANGUAGE_KEY, Language.Known.LANGUAGE_UNSPECIFIED.toString()));
return com.google.genai.types.Part.builder()
.executableCode(
ExecutableCode.builder().code(code).language(new Language(language)).build())
.build();
}

if ((data.containsKey(OUTCOME_KEY) && data.containsKey(OUTPUT_KEY))
|| metadataType.equals(A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType())) {
String outcome =
String.valueOf(data.getOrDefault(OUTCOME_KEY, Outcome.Known.OUTCOME_OK).toString());
String output = String.valueOf(data.getOrDefault(OUTPUT_KEY, ""));
return Optional.of(
com.google.genai.types.Part.builder()
.codeExecutionResult(
CodeExecutionResult.builder()
.outcome(new Outcome(outcome))
.output(output)
.build())
.build());
return com.google.genai.types.Part.builder()
.codeExecutionResult(
CodeExecutionResult.builder().outcome(new Outcome(outcome)).output(output).build())
.build();
}

try {
String json = objectMapper.writeValueAsString(data);
return Optional.of(com.google.genai.types.Part.builder().text(json).build());
return com.google.genai.types.Part.builder().text(json).build();
} catch (JsonProcessingException e) {
logger.warn("Failed to serialize DataPart payload", e);
return Optional.empty();
throw new IllegalArgumentException("Failed to serialize DataPart payload", e);
}
}

Expand Down
123 changes: 95 additions & 28 deletions a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package com.google.adk.a2a.converters;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Streams.zip;

import com.google.adk.agents.InvocationContext;
import com.google.adk.events.Event;
Expand All @@ -29,13 +31,15 @@
import io.a2a.client.TaskEvent;
import io.a2a.client.TaskUpdateEvent;
import io.a2a.spec.Artifact;
import io.a2a.spec.DataPart;
import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatusUpdateEvent;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
Expand Down Expand Up @@ -70,6 +74,14 @@ public static Optional<Event> clientEventToEvent(
throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass());
}

private static boolean isPartial(Map<String, Object> metadata) {
if (metadata == null) {
return false;
}
return Objects.equals(
metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true);
}

/**
* Converts a A2A {@link TaskUpdateEvent} to an ADK {@link Event}, if applicable. Returns null if
* the event is not a final update for TaskArtifactUpdateEvent or if the message is empty for
Expand All @@ -85,7 +97,14 @@ private static Optional<Event> handleTaskUpdate(
boolean isAppend = Objects.equals(artifactEvent.isAppend(), true);
boolean isLastChunk = Objects.equals(artifactEvent.isLastChunk(), true);

if (isLastChunk && isPartial(artifactEvent.getMetadata())) {
return Optional.empty();
}

Event eventPart = artifactToEvent(artifactEvent.getArtifact(), context);
if (eventPart.content().flatMap(Content::parts).orElse(ImmutableList.of()).isEmpty()) {
return Optional.empty();
}
eventPart.setPartial(isAppend || !isLastChunk);
// append=true, lastChunk=false: emit as partial, update aggregation
// append=false, lastChunk=false: emit as partial, reset aggregation
Expand Down Expand Up @@ -115,26 +134,21 @@ private static Optional<Event> handleTaskUpdate(
.map(builder -> builder.turnComplete(true))
.map(builder -> builder.partial(false))
.map(Event.Builder::build);
} else {
return messageEvent;
}
return messageEvent;
}
throw new IllegalArgumentException(
"Unsupported TaskUpdateEvent type: " + updateEvent.getClass());
}

/** Converts an artifact to an ADK event. */
public static Event artifactToEvent(Artifact artifact, InvocationContext invocationContext) {
Message message =
new Message.Builder().role(Message.Role.AGENT).parts(artifact.parts()).build();
return messageToEvent(message, invocationContext);
}

/** Converts an A2A message back to ADK events. */
public static Event messageToEvent(Message message, InvocationContext invocationContext) {
return remoteAgentEventBuilder(invocationContext)
.content(fromModelParts(PartConverter.toGenaiParts(message.getParts())))
.build();
Event.Builder eventBuilder = remoteAgentEventBuilder(invocationContext);
ImmutableList<Part> genaiParts = PartConverter.toGenaiParts(artifact.parts());
eventBuilder
.content(fromModelParts(genaiParts))
.longRunningToolIds(getLongRunningToolIds(artifact.parts(), genaiParts));
return eventBuilder.build();
}

/** Converts an A2A message for a failed task to ADK event filling in the error message. */
Expand All @@ -147,6 +161,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo
return builder.build();
}

/** Converts an A2A message back to ADK events. */
public static Event messageToEvent(Message message, InvocationContext invocationContext) {
return remoteAgentEventBuilder(invocationContext)
.content(fromModelParts(PartConverter.toGenaiParts(message.getParts())))
.build();
}

/**
* Converts an A2A message back to ADK events. For streaming task in pending state it sets the
* thought field to true, to mark them as thought updates.
Expand All @@ -168,25 +189,71 @@ public static Event messageToEvent(
* If none of these are present, an empty event is returned.
*/
public static Event taskToEvent(Task task, InvocationContext invocationContext) {
Message taskMessage = null;

if (!task.getArtifacts().isEmpty()) {
taskMessage =
new Message.Builder()
.messageId("")
.role(Message.Role.AGENT)
.parts(Iterables.getLast(task.getArtifacts()).parts())
.build();
} else if (task.getStatus().message() != null) {
taskMessage = task.getStatus().message();
} else if (!task.getHistory().isEmpty()) {
taskMessage = Iterables.getLast(task.getHistory());
ImmutableList.Builder<Part> genaiParts = ImmutableList.builder();
ImmutableSet.Builder<String> longRunningToolIds = ImmutableSet.builder();

for (Artifact artifact : task.getArtifacts()) {
ImmutableList<Part> converted = PartConverter.toGenaiParts(artifact.parts());
longRunningToolIds.addAll(getLongRunningToolIds(artifact.parts(), converted));
genaiParts.addAll(converted);
}

Event.Builder eventBuilder = remoteAgentEventBuilder(invocationContext);

if (task.getStatus().message() != null) {
ImmutableList<Part> msgParts =
PartConverter.toGenaiParts(task.getStatus().message().getParts());
longRunningToolIds.addAll(
getLongRunningToolIds(task.getStatus().message().getParts(), msgParts));
if (task.getStatus().state() == TaskState.FAILED
&& msgParts.size() == 1
&& msgParts.get(0).text().isPresent()) {
eventBuilder.errorMessage(msgParts.get(0).text().get());
} else {
genaiParts.addAll(msgParts);
}
}

if (taskMessage != null) {
return messageToEvent(taskMessage, invocationContext);
ImmutableList<Part> finalParts = genaiParts.build();
boolean isFinal =
task.getStatus().state().isFinal() || task.getStatus().state() == TaskState.INPUT_REQUIRED;

if (finalParts.isEmpty() && !isFinal) {
return emptyEvent(invocationContext);
}
return emptyEvent(invocationContext);
if (!finalParts.isEmpty()) {
eventBuilder.content(fromModelParts(finalParts));
}
if (task.getStatus().state() == TaskState.INPUT_REQUIRED) {
eventBuilder.longRunningToolIds(longRunningToolIds.build());
}
eventBuilder.turnComplete(isFinal);
return eventBuilder.build();
}

private static ImmutableSet<String> getLongRunningToolIds(
List<io.a2a.spec.Part<?>> parts, List<Part> convertedParts) {
return zip(
parts.stream(),
convertedParts.stream(),
(part, convertedPart) -> {
if (!(part instanceof DataPart dataPart)) {
return Optional.<String>empty();
}
Object isLongRunning =
dataPart
.getMetadata()
.get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY);
if (!Objects.equals(isLongRunning, true)) {
return Optional.<String>empty();
}
if (convertedPart.functionCall().isEmpty()) {
return Optional.<String>empty();
}
return convertedPart.functionCall().get().id();
})
.flatMap(Optional::stream)
.collect(toImmutableSet());
}

private static Event emptyEvent(InvocationContext invocationContext) {
Expand Down
Loading
Loading