diff --git a/.github/workflows/pr-commit-check.yml b/.github/workflows/pr-commit-check.yml index ec6644311..1e31e42f3 100644 --- a/.github/workflows/pr-commit-check.yml +++ b/.github/workflows/pr-commit-check.yml @@ -21,7 +21,7 @@ jobs: # Step 1: Check out the code # This action checks out your repository under $GITHUB_WORKSPACE, so your workflow can access it. - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # We need to fetch all commits to accurately count them. # '0' means fetch all history for all branches and tags. diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index d9035a579..65e66f8fd 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} uses: actions/setup-java@v4 diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 021786162..ccb662b7c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -181,6 +181,25 @@ public RemoteA2AAgent build() { } } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { + return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); + } + + private Message prepareMessage(InvocationContext invocationContext) { + Event userCall = EventConverter.findUserFunctionCall(invocationContext.session().events()); + if (userCall != null) { + ImmutableList> parts = + EventConverter.contentToParts(userCall.content(), userCall.partial().orElse(false)); + return newA2AMessage(Message.Role.USER, parts) + .taskId(EventConverter.taskId(userCall)) + .contextId(EventConverter.contextId(userCall)) + .build(); + } + return newA2AMessage( + Message.Role.USER, EventConverter.messagePartsFromContext(invocationContext)) + .build(); + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { // Construct A2A Message from the last ADK event @@ -191,14 +210,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.empty(); } - Optional a2aMessageOpt = EventConverter.convertEventsToA2AMessage(invocationContext); - - if (a2aMessageOpt.isEmpty()) { - logger.warn("Failed to convert event to A2A message."); - return Flowable.empty(); - } - - Message originalMessage = a2aMessageOpt.get(); + Message originalMessage = prepareMessage(invocationContext); String requestJson = serializeMessageToJson(originalMessage); return Flowable.create( diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index d823e3817..71573070e 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -18,66 +18,107 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; -import io.a2a.spec.Message; +import com.google.genai.types.FunctionResponse; import io.a2a.spec.Part; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.UUID; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.jspecify.annotations.Nullable; /** Converter for ADK Events to A2A Messages. */ public final class EventConverter { - private static final Logger logger = LoggerFactory.getLogger(EventConverter.class); + public static final String ADK_TASK_ID_KEY = "adk_task_id"; + public static final String ADK_CONTEXT_ID_KEY = "adk_context_id"; private EventConverter() {} /** - * Converts an ADK InvocationContext to an A2A Message. + * Returns the task ID from the event. * - *

It combines all the events in the session, plus the user content, converted into A2A Parts, - * into a single A2A Message. + *

Task ID is stored in the event's custom metadata with the key {@link #ADK_TASK_ID_KEY}. * - *

If the context has no events, or no suitable content to build the message, an empty optional - * is returned. - * - * @param context The ADK InvocationContext to convert. - * @return The converted A2A Message. + * @param event The event to get the task ID from. + * @return The task ID, or an empty string if not found. */ - public static Optional convertEventsToA2AMessage(InvocationContext context) { - if (context.session().events().isEmpty()) { - logger.warn("No events in session, cannot convert to A2A message."); - return Optional.empty(); - } - - ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + public static String taskId(Event event) { + return metadataValue(event, ADK_TASK_ID_KEY); + } - context - .session() - .events() - .forEach( - event -> - partsBuilder.addAll( - contentToParts(event.content(), event.partial().orElse(false)))); - partsBuilder.addAll(contentToParts(context.userContent(), false)); + /** + * Returns the context ID from the event. + * + *

Context ID is stored in the event's custom metadata with the key {@link + * #ADK_CONTEXT_ID_KEY}. + * + * @param event The event to get the context ID from. + * @return The context ID, or an empty string if not found. + */ + public static String contextId(Event event) { + return metadataValue(event, ADK_CONTEXT_ID_KEY); + } - ImmutableList> parts = partsBuilder.build(); + /** + * Returns the last user function call event from the list of events. + * + * @param events The list of events to find the user function call event from. + * @return The user function call event, or null if not found. + */ + public static @Nullable Event findUserFunctionCall(List events) { + Event candidate = Iterables.getLast(events); + if (!candidate.author().equals("user")) { + return null; + } + FunctionResponse functionResponse = findUserFunctionResponse(candidate); + if (functionResponse == null || functionResponse.id().isEmpty()) { + return null; + } + for (int i = events.size() - 2; i >= 0; i--) { + Event event = events.get(i); + if (isUserFunctionCall(event, functionResponse.id().get())) { + return event; + } + } + return null; + } - if (parts.isEmpty()) { - logger.warn("No suitable content found to build A2A request message."); - return Optional.empty(); + private static @Nullable FunctionResponse findUserFunctionResponse(Event candidate) { + if (candidate.content().isEmpty() || candidate.content().get().parts().isEmpty()) { + return null; } + return candidate.content().get().parts().get().stream() + .filter(part -> part.functionResponse().isPresent()) + .findFirst() + .map(part -> part.functionResponse().get()) + .orElse(null); + } - return Optional.of( - new Message.Builder() - .messageId(UUID.randomUUID().toString()) - .parts(parts) - .role(Message.Role.USER) - .build()); + private static boolean isUserFunctionCall(Event event, String functionResponseId) { + if (event.content().isEmpty()) { + return false; + } + return event.content().get().parts().get().stream() + .anyMatch( + part -> + part.functionCall().isPresent() + && part.functionCall() + .get() + .id() + .map(id -> id.equals(functionResponseId)) + .orElse(false)); } + /** + * Converts a GenAI Content object to a list of A2A Parts. + * + * @param content The GenAI Content object to convert. + * @param isPartial Whether the content is partial. + * @return A list of A2A Parts. + */ public static ImmutableList> contentToParts( Optional content, boolean isPartial) { return content.flatMap(Content::parts).stream() @@ -85,4 +126,80 @@ public static ImmutableList> contentToParts( .map(part -> PartConverter.fromGenaiPart(part, isPartial)) .collect(toImmutableList()); } + + /** + * Returns the parts from the context events that should be sent to the agent. + * + *

All session events from the previous remote agent response (or the beginning of the session + * in case of the first agent invocation) are included into the A2A message. Events from other + * agents are presented as user messages and rephased as if a user was telling what happened in + * the session up to the point. + * + * @param context The invocation context to get the parts from. + * @return A list of A2A Parts. + */ + public static ImmutableList> messagePartsFromContext(InvocationContext context) { + if (context.session().events().isEmpty()) { + return ImmutableList.of(); + } + List events = context.session().events(); + int lastResponseIndex = -1; + String contextId = ""; + for (int i = events.size() - 1; i >= 0; i--) { + Event event = events.get(i); + if (event.author().equals(context.agent().name())) { + lastResponseIndex = i; + contextId = contextId(event); + break; + } + } + ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + for (int i = lastResponseIndex + 1; i < events.size(); i++) { + Event event = events.get(i); + if (!event.author().equals("user") && !event.author().equals(context.agent().name())) { + event = presentAsUserMessage(event, contextId); + } + contentToParts(event.content(), event.partial().orElse(false)).forEach(partsBuilder::add); + } + return partsBuilder.build(); + } + + private static Event presentAsUserMessage(Event event, String contextId) { + Event.Builder userEvent = + new Event.Builder().id(UUID.randomUUID().toString()).invocationId(contextId).author("user"); + ImmutableList parts = + event.content().flatMap(Content::parts).stream() + .flatMap(Collection::stream) + // convert only non-thought parts to user message parts, skip thought parts as they are + // not meant to be shown to the user + .filter(part -> !part.thought().orElse(false)) + .map(part -> PartConverter.remoteCallAsUserPart(event.author(), part)) + .collect(toImmutableList()); + if (parts.isEmpty()) { + return userEvent.build(); + } + com.google.genai.types.Part forContext = + com.google.genai.types.Part.builder().text("For context:").build(); + return userEvent + .content( + Content.builder() + .parts( + ImmutableList.builder() + .add(forContext) + .addAll(parts) + .build()) + .build()) + .build(); + } + + private static String metadataValue(Event event, String key) { + if (event.customMetadata().isEmpty()) { + return ""; + } + return event.customMetadata().get().stream() + .filter(m -> m.key().map(k -> k.equals(key)).orElse(false)) + .findFirst() + .flatMap(m -> m.stringValue()) + .orElse(""); + } } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 96ef66bc8..36af6cc8b 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -384,6 +384,50 @@ private static FilePart filePartToA2A(Part part, ImmutableMap.BuilderEvents are rephrased as if a user was telling what happened in the session up to the point. + * E.g. + * + *

{@code
+   * For context:
+   * User said: Now help me with Z
+   * Agent A said: Agent B can help you with it!
+   * Agent B said: Agent C might know better.*
+   * }
+ * + * @param author The author of the part. + * @param part The part to convert. + * @return The converted part. + */ + public static Part remoteCallAsUserPart(String author, Part part) { + if (part.text().isPresent()) { + String partText = String.format("[%s] said: %s", author, part.text().get()); + return Part.builder().text(partText).build(); + } else if (part.functionCall().isPresent()) { + FunctionCall functionCall = part.functionCall().get(); + String partText = + String.format( + "[%s] called tool %s with parameters: %s", + author, + functionCall.name().orElse(""), + functionCall.args().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else if (part.functionResponse().isPresent()) { + FunctionResponse functionResponse = part.functionResponse().get(); + String partText = + String.format( + "[%s] %s tool returned result: %s", + author, + functionResponse.name().orElse(""), + functionResponse.response().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else { + return part; + } + } + @SuppressWarnings("unchecked") // safe conversion from objectMapper.readValue private static Map coerceToMap(Object value) { if (value == null) { diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index e75da64ba..b1ffa248a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -412,10 +412,11 @@ public void runAsync_constructsRequestWithHistory() { .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any()); Message message = messageCaptor.getValue(); assertThat(message.getRole()).isEqualTo(Message.Role.USER); - assertThat(message.getParts()).hasSize(3); + assertThat(message.getParts()).hasSize(4); assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello"); - assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi"); - assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?"); + assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("[model] said: hi"); + assertThat(((TextPart) message.getParts().get(3)).getText()).isEqualTo("how are you?"); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java index 8d460c457..207019199 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java @@ -4,23 +4,17 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; -import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; -import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import io.a2a.spec.DataPart; -import io.a2a.spec.Message; import io.a2a.spec.TextPart; import io.reactivex.rxjava3.core.Flowable; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,101 +24,180 @@ public final class EventConverterTest { @Test - public void convertEventsToA2AMessage_preservesFunctionCallAndResponseParts() { - // Arrange session events: user text, function call, function response. - Part userTextPart = Part.builder().text("Roll a die").build(); - Event userEvent = + public void testTaskId() { + Event e = + Event.builder() + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_TASK_ID_KEY) + .stringValue("task-123") + .build())) + .build(); + assertThat(EventConverter.taskId(e)).isEqualTo("task-123"); + } + + @Test + public void testTaskId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.taskId(e)).isEmpty(); + } + + @Test + public void testContextId() { + Event e = + Event.builder() + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_CONTEXT_ID_KEY) + .stringValue("context-456") + .build())) + .build(); + assertThat(EventConverter.contextId(e)).isEqualTo("context-456"); + } + + @Test + public void testContextId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.contextId(e)).isEmpty(); + } + + @Test + public void testFindUserFunctionCall_success() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("fc-id").build(); + Event userEventWithCall = Event.builder() - .id("event-user") .author("user") - .content(Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) + .build()) .build(); - Part functionCallPart = - Part.builder() - .functionCall( - FunctionCall.builder() - .name("roll_die") - .id("adk-call-1") - .args(ImmutableMap.of("sides", 6)) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) .build()) .build(); - Event callEvent = + + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isEqualTo(userEventWithCall); + } + + @Test + public void testFindUserFunctionCall_noMatchingCall() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("other-id").build(); + Event userEventWithCall = Event.builder() - .id("event-call") - .author("root_agent") + .author("user") .content( Content.builder() - .role("assistant") - .parts(ImmutableList.of(functionCallPart)) + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - Part functionResponsePart = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .name("roll_die") - .id("adk-call-1") - .response(ImmutableMap.of("result", 3)) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) + .build()) + .build(); + + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); + } + + @Test + public void testFindUserFunctionCall_lastEventNotUser() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("fc-id").build(); + Event userEventWithCall = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - Event responseEvent = + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + // Last event is not a user event, so should return null. + Event agentEventWithResponse = Event.builder() - .id("event-response") - .author("roll_agent") + .author("agent") .content( Content.builder() - .role("tool") - .parts(ImmutableList.of(functionResponsePart)) + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) .build()) .build(); - List events = new ArrayList<>(ImmutableList.of(userEvent, callEvent, responseEvent)); - Session session = - Session.builder("session-1").appName("demo").userId("user").events(events).build(); + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, agentEventWithResponse); - InvocationContext context = + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); + } + + @Test + public void testContentToParts() { + Part textPart = Part.builder().text("hello").build(); + Content content = Content.builder().parts(ImmutableList.of(textPart)).build(); + ImmutableList> list = + EventConverter.contentToParts(Optional.of(content), false); + assertThat(list).hasSize(1); + assertThat(((TextPart) list.get(0)).getText()).isEqualTo("hello"); + } + + @Test + public void testMessagePartsFromContext() { + Session session = + Session.builder("session1") + .events( + ImmutableList.of( + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hello").build())) + .build()) + .build(), + Event.builder() + .author("test_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hi").build())) + .build()) + .build(), + Event.builder() + .author("other_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hey").build())) + .build()) + .build())) + .build(); + BaseAgent agent = new TestAgent(); + InvocationContext ctx = InvocationContext.builder() - .sessionService(new InMemorySessionService()) - .artifactService(new InMemoryArtifactService()) - .pluginManager(new PluginManager()) - .invocationId("invocation-1") - .agent(new TestAgent()) .session(session) - .userContent( - Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) - .endInvocation(false) + .sessionService(new InMemorySessionService()) + .agent(agent) .build(); + ImmutableList> parts = EventConverter.messagePartsFromContext(ctx); - // Act - Optional maybeMessage = EventConverter.convertEventsToA2AMessage(context); - - // Assert - assertThat(maybeMessage).isPresent(); - Message message = maybeMessage.get(); - assertThat(message.getParts()).hasSize(4); - assertThat(message.getParts().get(0)).isInstanceOf(TextPart.class); - assertThat(message.getParts().get(1)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(2)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(3)).isInstanceOf(TextPart.class); - - DataPart callDataPart = (DataPart) message.getParts().get(1); - assertThat(callDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_CALL.getType()); - assertThat(callDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(callDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(callDataPart.getData()).containsEntry("args", ImmutableMap.of("sides", 6)); - - DataPart responseDataPart = (DataPart) message.getParts().get(2); - assertThat(responseDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); - assertThat(responseDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(responseDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(responseDataPart.getData()).containsEntry("response", ImmutableMap.of("result", 3)); - - TextPart lastTextPart = (TextPart) message.getParts().get(3); - assertThat(lastTextPart.getText()).isEqualTo("Roll a die"); + assertThat(parts).hasSize(2); + assertThat(((TextPart) parts.get(0)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) parts.get(1)).getText()).isEqualTo("[other_agent] said: hey"); } private static final class TestAgent extends BaseAgent { diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index a29783769..da5b0d794 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -19,12 +19,12 @@ import com.google.adk.artifacts.ListArtifactsResponse; import com.google.adk.events.EventActions; import com.google.adk.sessions.State; +import com.google.common.base.Preconditions; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; /** The context of various callbacks for an agent invocation. */ public class CallbackContext extends ReadonlyContext { @@ -94,22 +94,19 @@ public Single> listArtifacts() { /** Loads the latest version of an artifact from the service. */ public Maybe loadArtifact(String filename) { - return loadArtifact(filename, Optional.empty()); + checkArtifactServiceInitialized(); + return invocationContext + .artifactService() + .loadArtifact( + invocationContext.appName(), + invocationContext.userId(), + invocationContext.session().id(), + filename); } /** Loads a specific version of an artifact from the service. */ public Maybe loadArtifact(String filename, int version) { - return loadArtifact(filename, Optional.of(version)); - } - - /** - * @deprecated Use {@link #loadArtifact(String)} or {@link #loadArtifact(String, int)} instead. - */ - @Deprecated - public Maybe loadArtifact(String filename, Optional version) { - if (invocationContext.artifactService() == null) { - throw new IllegalStateException("Artifact service is not initialized."); - } + checkArtifactServiceInitialized(); return invocationContext .artifactService() .loadArtifact( @@ -120,6 +117,11 @@ public Maybe loadArtifact(String filename, Optional version) { version); } + private void checkArtifactServiceInitialized() { + Preconditions.checkState( + invocationContext.artifactService() != null, "Artifact service is not initialized."); + } + /** * Saves an artifact and records it as a delta for the current session. * diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index bbed217f4..d326d8154 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -757,14 +757,6 @@ public Single> canonicalGlobalInstruction(ReadonlyCon throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass()); } - /** - * @deprecated Use {@link #canonicalTools(ReadonlyContext)} instead. - */ - @Deprecated - public Flowable canonicalTools(Optional context) { - return canonicalTools(context.orElse(null)); - } - /** * Constructs the list of tools for this agent based on the {@link #tools} field. * diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 79066b213..ab5f6567a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -190,20 +190,23 @@ private Flowable callLlm( context, llmRequestBuilder, eventForCallbackUsage, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) .doOnError( error -> { Span span = Span.current(); span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); }) - .compose(Tracing.trace("call_llm").setParent(spanContext)) + .compose( + Tracing.trace("call_llm") + .setParent(spanContext) + .onSuccess( + (span, llmResp) -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp))) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) diff --git a/core/src/main/java/com/google/adk/sessions/SessionUtils.java b/core/src/main/java/com/google/adk/sessions/SessionUtils.java index 7a795be4c..1aeca98c9 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionUtils.java +++ b/core/src/main/java/com/google/adk/sessions/SessionUtils.java @@ -24,6 +24,7 @@ import java.util.Base64; import java.util.List; import java.util.Optional; +import org.jspecify.annotations.Nullable; /** Utility functions for session service. */ public final class SessionUtils { @@ -53,7 +54,7 @@ public static Content encodeContent(Content content) { encodedParts.add(part); } } - return toContent(encodedParts, content.role()); + return toContent(encodedParts, content.role().orElse(null)); } /** Decodes Base64-encoded inline blobs in content. */ @@ -79,13 +80,15 @@ public static Content decodeContent(Content content) { decodedParts.add(part); } } - return toContent(decodedParts, content.role()); + return toContent(decodedParts, content.role().orElse(null)); } /** Builds content from parts and optional role. */ - private static Content toContent(List parts, Optional role) { + private static Content toContent(List parts, @Nullable String role) { Content.Builder contentBuilder = Content.builder().parts(parts); - role.ifPresent(contentBuilder::role); + if (role != null) { + contentBuilder.role(role); + } return contentBuilder.build(); } } diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 7f338fdcf..35bf3cc96 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -54,6 +54,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; @@ -292,58 +293,49 @@ private static Map buildLlmRequestForTrace(LlmRequest llmRequest * @param llmResponse The LLM response object. */ public static void traceCallLlm( + Span span, InvocationContext invocationContext, String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - traceWithSpan( - "traceCallLlm", - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> - span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); - }); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -455,6 +447,7 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); + private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -474,6 +467,16 @@ public TracerProvider setParent(Context parentContext) { return this; } + /** + * Registers a callback to be executed with the span and the result item when the stream emits a + * success value. + */ + @CanIgnoreReturnValue + public TracerProvider onSuccess(BiConsumer consumer) { + this.onSuccessConsumer = consumer; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -504,7 +507,11 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -513,7 +520,11 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -522,7 +533,11 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } diff --git a/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java b/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java index c5ae8af37..399079af5 100644 --- a/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java +++ b/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java @@ -169,7 +169,7 @@ private Completable loadAndAppendIndividualArtifact( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext, String artifactName) { return toolContext - .loadArtifact(artifactName, Optional.empty()) + .loadArtifact(artifactName) .flatMapCompletable( actualArtifact -> Completable.fromAction( diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 9439fe718..f809193cf 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -503,7 +503,7 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); } finally { span.end(); } diff --git a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java index 89014175d..5ed7a1f40 100644 --- a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java +++ b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java @@ -3,7 +3,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -163,12 +162,11 @@ public void processLlmRequest_artifactsInContext_withLoadArtifactsFunctionCall_l Part loadedArtifactPart = Part.fromText("This is the content of doc1.txt"); ToolContext spiedToolContext = spy(ToolContext.builder(mockInvocationContext).build()); when(spiedToolContext.listArtifacts()).thenReturn(Single.just(availableArtifacts)); - when(spiedToolContext.loadArtifact(eq("doc1.txt"), eq(Optional.empty()))) - .thenReturn(Maybe.just(loadedArtifactPart)); + when(spiedToolContext.loadArtifact("doc1.txt")).thenReturn(Maybe.just(loadedArtifactPart)); loadArtifactsTool.processLlmRequest(llmRequestBuilder, spiedToolContext).blockingAwait(); - verify(spiedToolContext).loadArtifact(eq("doc1.txt"), eq(Optional.empty())); + verify(spiedToolContext).loadArtifact("doc1.txt"); LlmRequest finalRequest = llmRequestBuilder.build(); List finalContents = finalRequest.contents();