diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 8f30f5946..e13b7aba0 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -735,7 +735,8 @@ private Flowable buildPostprocessingEvents( Event modelResponseEvent = buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - if (modelResponseEvent.functionCalls().isEmpty()) { + if (modelResponseEvent.functionCalls().isEmpty() + || modelResponseEvent.partial().orElse(false)) { return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 876f3a206..773713b24 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -16,6 +16,7 @@ package com.google.adk.flows.llmflows; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import com.fasterxml.jackson.core.JsonProcessingException; @@ -57,13 +58,6 @@ public Single processRequest( } LlmAgent llmAgent = (LlmAgent) context.agent(); - String modelName; - try { - modelName = llmAgent.resolvedModel().modelName().orElse(""); - } catch (IllegalStateException e) { - modelName = ""; - } - ImmutableList sessionEvents; synchronized (context.session().events()) { sessionEvents = ImmutableList.copyOf(context.session().events()); @@ -75,17 +69,13 @@ public Single processRequest( request.toBuilder() .contents( getCurrentTurnContents( - context.branch().orElse(null), - sessionEvents, - context.agent().name(), - modelName)) + context.branch().orElse(null), sessionEvents, context.agent().name())) .build(), ImmutableList.of())); } ImmutableList contents = - getContents( - context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); + getContents(context.branch().orElse(null), sessionEvents, context.agent().name()); return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -94,19 +84,19 @@ public Single processRequest( /** Gets contents for the current turn only (no conversation history). */ private ImmutableList getCurrentTurnContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { // Find the latest event that starts the current turn and process from there. for (int i = events.size() - 1; i >= 0; i--) { Event event = events.get(i); if (event.author().equals("user") || isOtherAgentReply(agentName, event)) { - return getContents(currentBranch, events.subList(i, events.size()), agentName, modelName); + return getContents(currentBranch, events.subList(i, events.size()), agentName); } } return ImmutableList.of(); } private ImmutableList getContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { List filteredEvents = new ArrayList<>(); boolean hasCompactEvent = false; @@ -148,7 +138,7 @@ private ImmutableList getContents( } List resultEvents = rearrangeEventsForLatestFunctionResponse(filteredEvents); - resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents, modelName); + resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents); return resultEvents.stream() .map(Event::content) @@ -564,8 +554,7 @@ private static List rearrangeEventsForLatestFunctionResponse(List return resultEvents; } - private static List rearrangeEventsForAsyncFunctionResponsesInHistory( - List events, String modelName) { + private static List rearrangeEventsForAsyncFunctionResponsesInHistory(List events) { Map functionCallIdToResponseEventIndex = new HashMap<>(); for (int i = 0; i < events.size(); i++) { final int index = i; @@ -592,11 +581,6 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( List resultEvents = new ArrayList<>(); // Keep track of response events already added to avoid duplicates when merging Set processedResponseIndices = new HashSet<>(); - List responseEventsBuffer = new ArrayList<>(); - - // Gemini 3 requires function calls to be grouped first and only then function responses: - // FC1 FC2 FR1 FR2 - boolean shouldBufferResponseEvents = modelName.contains("gemini-3"); for (int i = 0; i < events.size(); i++) { Event event = events.get(i); @@ -641,47 +625,21 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int index : sortedIndices) { if (processedResponseIndices.add(index)) { // Add index and check if it was newly added - responseEventsBuffer.add(events.get(index)); responseEventsToAdd.add(events.get(index)); } } - if (!shouldBufferResponseEvents) { - if (responseEventsToAdd.size() == 1) { - resultEvents.add(responseEventsToAdd.get(0)); - } else if (responseEventsToAdd.size() > 1) { - resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); - } + if (responseEventsToAdd.size() == 1) { + resultEvents.add(responseEventsToAdd.get(0)); + } else if (responseEventsToAdd.size() > 1) { + resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); } } } else { - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } resultEvents.add(event); } } - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } - return resultEvents; } @@ -702,9 +660,8 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( * appended to the part list of the initial function response event. */ private static Event mergeFunctionResponseEvents(List functionResponseEvents) { - if (functionResponseEvents.isEmpty()) { - throw new IllegalArgumentException("At least one functionResponse event is required."); - } + checkArgument( + !functionResponseEvents.isEmpty(), "At least one functionResponse event is required."); if (functionResponseEvents.size() == 1) { return functionResponseEvents.get(0); } @@ -719,10 +676,9 @@ private static Event mergeFunctionResponseEvents(List functionResponseEve .parts() .orElseThrow(() -> new IllegalArgumentException("Base event content must have parts.")); - if (baseParts.isEmpty()) { - throw new IllegalArgumentException( - "There should be at least one functionResponse part in the base event."); - } + checkArgument( + !baseParts.isEmpty(), + "There should be at least one functionResponse part in the base event."); List partsInMergedEvent = new ArrayList<>(baseParts); Map partIndicesInMergedEvent = new HashMap<>(); diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 6f145e1de..d76ca75f0 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -19,6 +19,7 @@ import static com.google.common.base.StandardSystemProperty.JAVA_VERSION; import com.google.adk.Version; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.Client; @@ -26,17 +27,22 @@ import com.google.genai.types.Candidate; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionCall; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.HttpOptions; import com.google.genai.types.LiveConnectConfig; import com.google.genai.types.Part; +import com.google.genai.types.PartialArg; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; -import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -223,24 +229,9 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre effectiveModelName, llmRequest.contents(), config); return Flowable.defer( - () -> - processRawResponses( - Flowable.fromFuture(streamFuture).flatMapIterable(iterable -> iterable))) - .filter( - llmResponse -> - llmResponse - .content() - .flatMap(Content::parts) - .map( - parts -> - !parts.isEmpty() - && parts.stream() - .anyMatch( - p -> - p.functionCall().isPresent() - || p.functionResponse().isPresent() - || p.text().isPresent())) - .orElse(false)); + () -> + processRawResponses( + Flowable.fromFuture(streamFuture).flatMapIterable(iterable -> iterable))); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); return Flowable.fromFuture( @@ -253,109 +244,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre } static Flowable processRawResponses(Flowable rawResponses) { - final StringBuilder accumulatedText = new StringBuilder(); - final StringBuilder accumulatedThoughtText = new StringBuilder(); - // Array to bypass final local variable reassignment in lambda. - final GenerateContentResponse[] lastRawResponseHolder = {null}; - return rawResponses - .concatMap( - rawResponse -> { - lastRawResponseHolder[0] = rawResponse; - logger.trace("Raw streaming response: {}", rawResponse); - - List responsesToEmit = new ArrayList<>(); - LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); - Optional part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse); - String currentTextChunk = part.flatMap(Part::text).orElse(""); - - if (!currentTextChunk.isBlank()) { - if (part.get().thought().orElse(false)) { - accumulatedThoughtText.append(currentTextChunk); - responsesToEmit.add( - thinkingResponseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } else { - accumulatedText.append(currentTextChunk); - responsesToEmit.add( - responseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } - } else { - if (accumulatedThoughtText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedThoughtResponse = - thinkingResponseFromText(accumulatedThoughtText.toString()); - responsesToEmit.add(aggregatedThoughtResponse); - accumulatedThoughtText.setLength(0); - } - if (accumulatedText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString()); - responsesToEmit.add(aggregatedTextResponse); - accumulatedText.setLength(0); - } - responsesToEmit.add(currentProcessedLlmResponse); - } - logger.debug("Responses to emit: {}", responsesToEmit); - return Flowable.fromIterable(responsesToEmit); - }) - .concatWith( - Flowable.defer( - () -> { - GenerateContentResponse finalRawResp = lastRawResponseHolder[0]; - if (finalRawResp == null) { - return Flowable.empty(); - } - boolean isStop = - finalRawResp - .candidates() - .flatMap(candidates -> candidates.stream().findFirst()) - .flatMap(Candidate::finishReason) - .map(finishReason -> finishReason.knownEnum() == FinishReason.Known.STOP) - .orElse(false); - - if (isStop) { - List finalResponses = new ArrayList<>(); - if (accumulatedThoughtText.length() > 0) { - finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() - .usageMetadata( - accumulatedText.length() > 0 - ? null - : finalRawResp.usageMetadata().orElse(null)) - .build()); - } - if (accumulatedText.length() > 0) { - finalResponses.add( - responseFromText(accumulatedText.toString()).toBuilder() - .usageMetadata(finalRawResp.usageMetadata().orElse(null)) - .build()); - } - - return Flowable.fromIterable(finalResponses); - } - return Flowable.empty(); - })); - } - - private static LlmResponse responseFromText(String accumulatedText) { - return LlmResponse.builder() - .content(Content.builder().role("model").parts(Part.fromText(accumulatedText)).build()) - .build(); - } - - private static LlmResponse thinkingResponseFromText(String accumulatedThoughtText) { - return LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(accumulatedThoughtText).toBuilder().thought(true).build()) - .build()) - .build(); + return Flowable.defer(() -> new StreamingResponseAggregator().process(rawResponses)); } @Override @@ -372,4 +261,342 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } + + private static final class StreamingResponseAggregator { + private final List accumulatedSequence = new ArrayList<>(); + private final StringBuilder currentTextBuffer = new StringBuilder(); + // Always reassigned in accumulateParts() before it is read; the initializer is never observed. + private boolean currentTextIsThought = false; + private byte[] currentThoughtSignature = null; + private GenerateContentResponse lastRawResponse = null; + + // Streaming function-call accumulation state. When the model streams a function call across + // multiple chunks (via partialArgs/willContinue), its arguments are accumulated here and a + // single complete function-call part is flushed to accumulatedSequence once it completes. + private String currentFcName = null; + private Map currentFcArgs = new LinkedHashMap<>(); + private String currentFcId = null; + + /** + * Processes a stream of raw responses, emitting partial and aggregated {@link LlmResponse}s. + */ + private Flowable process(Flowable rawResponses) { + return rawResponses + .concatMap(this::processRawResponse) + .concatWith(Flowable.defer(this::processFinalResponse)); + } + + /** + * Processes a single raw streaming chunk, accumulating parts and emitting intermediate + * responses. + */ + private Flowable processRawResponse(GenerateContentResponse rawResponse) { + lastRawResponse = rawResponse; + logger.trace("Raw streaming response: {}", rawResponse); + + LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); + List parts = + currentProcessedLlmResponse.content().flatMap(Content::parts).orElse(ImmutableList.of()); + + // Assign an ID to every function-call part up front, mirroring ADK Python's + // StreamingResponseAggregator: the same ID is reused in the partial and final responses so + // consumers can correlate them. + List partsWithIds = ensureFunctionCallIds(parts); + + if (accumulateParts(partsWithIds)) { + // partsWithIds is non-empty here, so the chunk's content (and its role) is present. Rebuild + // the partial content from the parts-with-IDs so its FC ID matches the final event. + Content.Builder rebuilt = Content.builder().parts(partsWithIds); + currentProcessedLlmResponse.content().flatMap(Content::role).ifPresent(rebuilt::role); + return Flowable.just( + currentProcessedLlmResponse.toBuilder().content(rebuilt.build()).partial(true).build()); + } + + // If the chunk has no text or function calls (e.g. metadata-only or empty), we suppress it + // during streaming so it doesn't emit an empty partial response. + // Exception: If this is a standalone empty chunk in an otherwise completely empty stream + // (and not a STOP chunk), we emit it directly as a non-partial empty response. + if (!isStop(currentProcessedLlmResponse) + && accumulatedSequence.isEmpty() + && currentTextBuffer.isEmpty()) { + return Flowable.just(currentProcessedLlmResponse.toBuilder().partial(null).build()); + } + + return Flowable.empty(); + } + + /** + * Returns a list of parts where every function-call part has a non-empty ID. If a part's + * function call already has an ID, the original part is preserved; otherwise a new part with a + * client-generated ID is substituted. Non-FC parts are passed through unchanged. + */ + private static List ensureFunctionCallIds(List parts) { + List result = new ArrayList<>(parts.size()); + for (Part part : parts) { + if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + if (fc.id().map(String::isEmpty).orElse(true)) { + FunctionCall withId = fc.toBuilder().id(generateClientFunctionCallId()).build(); + result.add(part.toBuilder().functionCall(withId).build()); + continue; + } + } + result.add(part); + } + return result; + } + + /** + * Generates a unique client-side function-call ID. Format matches {@code + * com.google.adk.flows.llmflows.Functions#generateClientFunctionCallId()} so downstream code + * that already sees IDs with the {@code "adk-"} prefix continues to work. + */ + private static String generateClientFunctionCallId() { + return "adk-" + UUID.randomUUID(); + } + + /** + * Accumulates text and function calls from incoming parts. Function-call parts passed to this + * method are expected to already have IDs (see {@link #ensureFunctionCallIds}). + * + * @return true if any text or function call was present, false otherwise. + */ + private boolean accumulateParts(List parts) { + boolean hasTextOrFc = false; + for (Part part : parts) { + part.thoughtSignature().ifPresent(sig -> currentThoughtSignature = sig); + String text = part.text().orElse(""); + if (!text.isEmpty()) { + hasTextOrFc = true; + boolean isThought = part.thought().orElse(false); + // Immediately flush the active text buffer to preserve the exact interleaved blocks of + // text/thoughts. + if (!currentTextBuffer.isEmpty() && isThought != currentTextIsThought) { + flushTextBufferToSequence(); + } + if (currentTextBuffer.isEmpty()) { + currentTextIsThought = isThought; + } + currentTextBuffer.append(text); + } + if (part.functionCall().isPresent()) { + hasTextOrFc = true; + processFunctionCallPart(part); + } + } + return hasTextOrFc; + } + + /** + * Processes a function-call part, mirroring ADK Python's {@code _process_function_call_part}. A + * function call whose arguments are streamed across chunks (it carries {@code partialArgs} or + * {@code willContinue=true}) is accumulated and flushed as a single complete part once it + * finishes; a complete (non-streaming) function call is appended directly. + */ + private void processFunctionCallPart(Part part) { + FunctionCall fc = part.functionCall().get(); + boolean streaming = + fc.partialArgs().map(args -> !args.isEmpty()).orElse(false) + || fc.willContinue().orElse(false); + if (streaming) { + // Capture the thought signature from the first chunk that carries one. + if (part.thoughtSignature().isPresent() && currentThoughtSignature == null) { + currentThoughtSignature = part.thoughtSignature().get(); + } + processStreamingFunctionCall(fc); + } else if (fc.name().filter(name -> !name.isEmpty()).isPresent()) { + // Complete function call. Skip empty calls, which are only streaming end markers. The part + // already has an ID assigned by ensureFunctionCallIds. + flushTextBufferToSequence(); + accumulatedSequence.add(part); + } + } + + /** + * Accumulates one chunk of a streamed function call, mirroring ADK Python's {@code + * _process_streaming_function_call}: merges the function name/ID and each {@code partialArg} + * (by JSONPath) into {@link #currentFcArgs}, then flushes the completed call once {@code + * willContinue} is no longer set. + */ + private void processStreamingFunctionCall(FunctionCall fc) { + fc.name().filter(name -> !name.isEmpty()).ifPresent(name -> currentFcName = name); + // Use the first ID seen (the model's, if provided, otherwise a generated one) for the whole + // call so the partial and final events correlate. + if (currentFcId == null) { + currentFcId = + fc.id().filter(id -> !id.isEmpty()).orElseGet(() -> generateClientFunctionCallId()); + } + for (PartialArg partialArg : fc.partialArgs().orElse(ImmutableList.of())) { + String jsonPath = partialArg.jsonPath().orElse(""); + if (jsonPath.isEmpty()) { + continue; + } + applyPartialArg(partialArg, jsonPath); + } + if (!fc.willContinue().orElse(false)) { + flushTextBufferToSequence(); + flushFunctionCallToSequence(); + } + } + + /** + * Applies a single {@link PartialArg} to {@link #currentFcArgs} at {@code jsonPath}, mirroring + * ADK Python's {@code _get_value_from_partial_arg}: string chunks are appended to any existing + * string at the path, while number/bool/null values overwrite. + */ + private void applyPartialArg(PartialArg partialArg, String jsonPath) { + if (partialArg.stringValue().isPresent()) { + Object existing = getValueByJsonPath(jsonPath); + String chunk = partialArg.stringValue().get(); + setValueByJsonPath(jsonPath, existing instanceof String s ? s + chunk : chunk); + } else if (partialArg.numberValue().isPresent()) { + setValueByJsonPath(jsonPath, partialArg.numberValue().get()); + } else if (partialArg.boolValue().isPresent()) { + setValueByJsonPath(jsonPath, partialArg.boolValue().get()); + } else if (partialArg.nullValue().isPresent()) { + setValueByJsonPath(jsonPath, null); + } + } + + /** + * Returns the value currently stored at {@code jsonPath} in {@link #currentFcArgs}, or null. + */ + private @Nullable Object getValueByJsonPath(String jsonPath) { + Object current = currentFcArgs; + for (String key : splitJsonPath(jsonPath)) { + if (current instanceof Map map && map.containsKey(key)) { + current = map.get(key); + } else { + return null; + } + } + return current; + } + + /** + * Sets {@code value} at {@code jsonPath} in {@link #currentFcArgs}, creating maps as needed. + */ + @SuppressWarnings("unchecked") + private void setValueByJsonPath(String jsonPath, Object value) { + String[] keys = splitJsonPath(jsonPath); + Map current = currentFcArgs; + for (int i = 0; i < keys.length - 1; i++) { + Object next = current.get(keys[i]); + if (!(next instanceof Map)) { + next = new LinkedHashMap<>(); + current.put(keys[i], next); + } + current = (Map) next; + } + current.put(keys[keys.length - 1], value); + } + + /** Splits a JSONPath such as {@code "$.location.city"} into its component keys. */ + private static String[] splitJsonPath(String jsonPath) { + String path = jsonPath.startsWith("$.") ? jsonPath.substring(2) : jsonPath; + return path.split("\\."); + } + + /** + * Flushes the accumulated streamed function call (if any) to {@link #accumulatedSequence} as a + * single complete part, mirroring ADK Python's {@code _flush_function_call_to_sequence}. + */ + private void flushFunctionCallToSequence() { + if (currentFcName == null) { + return; + } + FunctionCall.Builder fcBuilder = + FunctionCall.builder().name(currentFcName).args(new LinkedHashMap<>(currentFcArgs)); + if (currentFcId != null) { + fcBuilder.id(currentFcId); + } + Part.Builder partBuilder = Part.builder().functionCall(fcBuilder.build()); + if (currentThoughtSignature != null) { + partBuilder.thoughtSignature(currentThoughtSignature); + } + accumulatedSequence.add(partBuilder.build()); + currentFcName = null; + currentFcArgs = new LinkedHashMap<>(); + currentFcId = null; + currentThoughtSignature = null; + } + + /** Flushes any accumulated text or thought content in the buffer as a new {@link Part}. */ + private void flushTextBufferToSequence() { + if (!currentTextBuffer.isEmpty()) { + Part.Builder partBuilder = + Part.builder().text(currentTextBuffer.toString()).thought(currentTextIsThought); + if (currentThoughtSignature != null) { + partBuilder.thoughtSignature(currentThoughtSignature); + currentThoughtSignature = null; + } + accumulatedSequence.add(partBuilder.build()); + currentTextBuffer.setLength(0); + currentTextIsThought = false; + } + } + + /** + * Emits the final aggregated, non-partial response with all accumulated parts (thoughts, text, + * function calls). Mirrors ADK Python's {@code StreamingResponseAggregator.close()}: emitted + * even without a finish reason so accumulated content is never dropped; a non-STOP finish + * reason is surfaced as an error. + */ + private Flowable processFinalResponse() { + if (lastRawResponse == null) { + return Flowable.empty(); + } + LlmResponse currentResponse = LlmResponse.create(lastRawResponse); + + flushTextBufferToSequence(); + // Flush any in-progress streamed function call whose stream ended before completing. + flushFunctionCallToSequence(); + + // Nothing accumulated and no finish reason: any empty/metadata chunk already streamed, skip. + boolean hasFinishReason = currentResponse.finishReason().isPresent(); + if (accumulatedSequence.isEmpty() && !hasFinishReason) { + return Flowable.empty(); + } + + LlmResponse.Builder finalResponseBuilder = currentResponse.toBuilder().partial(null); + if (hasFinishReason && !isStop(currentResponse)) { + finalResponseBuilder.errorCode(currentResponse.finishReason().get()); + lastRawResponse + .candidates() + .filter(candidates -> !candidates.isEmpty()) + .map(candidates -> candidates.get(0)) + .flatMap(Candidate::finishMessage) + .ifPresent(finalResponseBuilder::errorMessage); + } + + if (accumulatedSequence.isEmpty()) { + return Flowable.just(finalResponseBuilder.build()); + } + + // If the final chunk carries a thoughtSignature (e.g. from a preceding function call or + // thought), attach it to the last accumulated part in the sequence. + GeminiUtil.getPart0FromLlmResponse(currentResponse) + .flatMap(Part::thoughtSignature) + .ifPresent( + signature -> { + int targetIndex = accumulatedSequence.size() - 1; + Part targetPart = accumulatedSequence.get(targetIndex); + accumulatedSequence.set( + targetIndex, targetPart.toBuilder().thoughtSignature(signature).build()); + }); + + return Flowable.just( + finalResponseBuilder + .content(Content.builder().role("model").parts(accumulatedSequence).build()) + .build()); + } + + /** Checks whether the response finish reason indicates the stream has finished with STOP. */ + private static boolean isStop(LlmResponse response) { + return response + .finishReason() + .map(reason -> reason.knownEnum() == FinishReason.Known.STOP) + .orElse(false); + } + } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index eb5e4d1f2..7cacc525c 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -16,6 +16,8 @@ package com.google.adk.runner; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.ContextCacheConfig; @@ -353,9 +355,7 @@ private Single appendNewMessageToSession( InvocationContext invocationContext, boolean saveInputBlobsAsArtifacts, @Nullable Map stateDelta) { - if (newMessage.parts().isEmpty()) { - throw new IllegalArgumentException("No parts in the new_message."); - } + checkArgument(newMessage.parts().isPresent(), "No parts in the new_message."); Completable saveArtifactsFlow = Completable.complete(); if (this.artifactService != null && saveInputBlobsAsArtifacts) { @@ -609,29 +609,37 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - // Release (or fail) BaseLlmFlow's wait for this step; the Runner stays the - // sole appendEvent caller (see PersistBarrier). - .doOnSuccess( - unusedEvent -> - PersistBarrier.markPersisted( - contextWithUpdatedSession, agentEvent.id())) - .doOnError( - error -> - PersistBarrier.markFailed( - contextWithUpdatedSession, agentEvent.id(), error)) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + agentEvent -> { + // Mirror ADK Python (runners.py): partial events are streamed to the caller but + // never persisted, so managed session services (e.g. VertexAiSessionService) do + // not store a duplicate of the function call/text that the final aggregated event + // already carries. Nothing to persist, so resolve the barrier immediately. + Single persistStep = + agentEvent.partial().orElse(false) + ? Single.just(agentEvent) + : this.sessionService.appendEvent(updatedSession, agentEvent); + return persistStep + // Release (or fail) BaseLlmFlow's wait for this step; the Runner stays the + // sole appendEvent caller (see PersistBarrier). + .doOnSuccess( + unusedEvent -> + PersistBarrier.markPersisted( + contextWithUpdatedSession, agentEvent.id())) + .doOnError( + error -> + PersistBarrier.markFailed( + contextWithUpdatedSession, agentEvent.id(), error)) + .flatMap( + registeredEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index a58a206d9..8dd4f40c0 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -44,6 +44,7 @@ import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; @@ -221,6 +222,103 @@ public void run_withLongRunningFunctionCall_returnsCorrectEventsWithLongRunningT assertThat(events.get(2).content()).hasValue(secondContent); } + @Test + public void run_withPartialFunctionCall_doesNotExecuteTool() { + Content partialContent = + Content.fromParts(Part.fromFunctionCall("my_function", ImmutableMap.of("arg1", "value1"))); + LlmResponse partialResponse = + LlmResponse.builder().content(partialContent).partial(true).build(); + TestLlm testLlm = createTestLlm(partialResponse); + ImmutableMap testResponse = + ImmutableMap.of("response", "response for my_function"); + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(new TestTool("my_function", testResponse))) + .build()); + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + /* requestProcessors= */ ImmutableList.of(), + /* responseProcessors= */ ImmutableList.of(), + /* maxSteps= */ Optional.of(1)); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).partial()).hasValue(true); + assertThat(events.get(0).functionCalls()).hasSize(1); + } + + // End-to-end: when the Gemini aggregator emits a partial event and a final aggregated event for + // the same function call, both must share the same function-call ID so consumers can correlate + // them. Mirrors ADK Python's progressive SSE contract. Simulates the post-aggregator stream (FC + // ID pre-populated, same Part reused across both events). + @Test + public void run_streamingFunctionCallWithPrePopulatedId_partialAndFinalShareFunctionCallId() { + // The aggregator (Gemini.processRawResponses) pre-populates the function call ID. Both the + // partial event and the final aggregated event reference the same Part with the same ID. + Part fcPartWithId = + Part.builder() + .functionCall( + FunctionCall.builder() + .id("adk-fixed-id-for-test") + .name("my_function") + .args(ImmutableMap.of("arg1", "value1")) + .build()) + .build(); + Content fcContent = Content.builder().role("model").parts(fcPartWithId).build(); + LlmResponse partialResponse = LlmResponse.builder().content(fcContent).partial(true).build(); + LlmResponse aggregatedResponse = + LlmResponse.builder().content(fcContent).partial(false).build(); + Content secondContent = + Content.fromParts(Part.fromText("LLM response after function response")); + TestLlm testLlm = + createTestLlm( + // First LLM call: SSE-style stream with partial + aggregated FC events. + Flowable.just(partialResponse, aggregatedResponse), + // Second LLM call: final text response after the tool executes. + Flowable.just(createLlmResponse(secondContent))); + ImmutableMap testResponse = + ImmutableMap.of("response", "response for my_function"); + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(new TestTool("my_function", testResponse))) + .build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + Event partialFcEvent = null; + Event aggregatedFcEvent = null; + int totalFunctionResponses = 0; + for (Event e : events) { + if (!e.functionCalls().isEmpty()) { + if (e.partial().orElse(false)) { + partialFcEvent = e; + } else { + aggregatedFcEvent = e; + } + } + totalFunctionResponses += e.functionResponses().size(); + } + + // Tool executes exactly once (only the non-partial event triggers execution). + assertThat(totalFunctionResponses).isEqualTo(1); + + // Both events carry the function call (this matches ADK Python's progressive SSE behavior). + assertThat(partialFcEvent).isNotNull(); + assertThat(aggregatedFcEvent).isNotNull(); + assertThat(partialFcEvent.functionCalls()).hasSize(1); + assertThat(aggregatedFcEvent.functionCalls()).hasSize(1); + + // The FC IDs must match so consumers can correlate/dedupe. + String partialId = partialFcEvent.functionCalls().get(0).id().orElseThrow(); + String aggregatedId = aggregatedFcEvent.functionCalls().get(0).id().orElseThrow(); + assertThat(partialId).isEqualTo(aggregatedId); + assertThat(partialId).isEqualTo("adk-fixed-id-for-test"); + } + @Test public void run_withRequestProcessor_doesNotModifyRequest() { Content content = Content.fromParts(Part.fromText("LLM response")); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 1e6267dde..fc2b6e1e3 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -550,7 +550,7 @@ public void processRequest_sequentialFCFR_returnsOriginalList() { } @Test - public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() { + public void rearrangeHistory_gemini3SequentialCalls_preservesInterleavedOrder() { Event u1 = createUserEvent("u1", "Query"); Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1"); Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1"); @@ -561,16 +561,66 @@ public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() { List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); - assertThat(result).hasSize(4); - assertThat(result.get(0)).isEqualTo(u1.content().get()); - assertThat(result.get(1)).isEqualTo(fc1.content().get()); - assertThat(result.get(2)).isEqualTo(fc2.content().get()); - Content mergedContent = result.get(3); + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_sequentialCalls_preservesInterleavedOrder() { + Event u1 = createUserEvent("u1", "Query"); + Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event fc2 = createFunctionCallEvent("fc2", "tool2", "call2"); + Event fr2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + + ImmutableList inputEvents = ImmutableList.of(u1, fc1, fr1, fc2, fr2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEvents_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2); + + List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); + + assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(2); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) + .hasValue("tool1"); + assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) + .hasValue("tool2"); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEventsInHistory_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + Event u2 = createUserEvent("u2", "Second Query"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2, u2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(4); // u1, fc1, merged_fr, u2 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(2); assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) .hasValue("tool1"); assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) .hasValue("tool2"); + assertThat(result.get(3)).isEqualTo(inputEvents.get(4).content().get()); } @Test diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index c230f5f68..62044d362 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -16,14 +16,19 @@ package com.google.adk.models; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.genai.types.Candidate; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionCall; import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import com.google.genai.types.PartialArg; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; import io.reactivex.rxjava3.subscribers.TestSubscriber; @@ -42,8 +47,12 @@ public void processRawResponses_withTextChunks_emitsPartialResponses() { Flowable llmResponses = Gemini.processRawResponses(rawResponses); + // No finish reason: the accumulated text is still emitted as a final aggregated response. assertLlmResponses( - llmResponses, isPartialTextResponse("Hello"), isPartialTextResponse(" world")); + llmResponses, + isPartialTextResponse("Hello"), + isPartialTextResponse(" world"), + isFinalTextResponse("Hello world")); } @Test @@ -59,8 +68,304 @@ public void processRawResponses_withTextChunks_emitsPartialResponses() { assertLlmResponses( llmResponses, isPartialTextResponse("Thinking..."), - isFinalTextResponse("Thinking..."), - isFunctionCallResponse()); + isPartialFunctionCallResponse("test_function"), + isFinalTextAndFunctionCallResponseWithNoUsageMetadata("Thinking...", "test_function")); + } + + @Test + public void processRawResponses_chunkWithBothTextAndFunctionCall_emitsPartialWithBoth() { + GenerateContentResponse chunkWithBoth = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.fromText("Here is the call:"), + Part.fromFunctionCall("my_tool", ImmutableMap.of())) + .build()) + .build()) + .build(); + + Flowable rawResponses = Flowable.just(chunkWithBoth); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextAndFunctionCallResponse("Here is the call:", "my_tool"), + isFinalTextAndFunctionCallResponseWithNoUsageMetadata("Here is the call:", "my_tool")); + } + + @Test + public void processRawResponses_streamingFunctionCallsAndStop_emitsPartialsThenFinalAggregated() { + Part fc1 = Part.fromFunctionCall("tool1", ImmutableMap.of("arg1", "val1")); + Part fc2 = Part.fromFunctionCall("tool2", ImmutableMap.of("arg2", "val2")); + GenerateContentResponse fc2WithStop = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fc2).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .build(); + Flowable rawResponses = Flowable.just(toResponse(fc1), fc2WithStop); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("tool1"), + isPartialFunctionCallResponse("tool2"), + isFinalAggregatedFunctionCallResponse("tool1", "tool2")); + } + + // Mirrors ADK Python's test_streaming_fc_generates_consistent_id_across_chunks: a function call + // arriving without an ID gets one client-side ID, reused in both the partial and final events so + // consumers can correlate them (and distinct calls get distinct IDs). + @Test + public void + processRawResponses_streamingFunctionCallsAndStop_partialAndFinalShareFunctionCallId() { + Part fc1 = Part.fromFunctionCall("tool1", ImmutableMap.of("arg1", "val1")); + Part fc2 = Part.fromFunctionCall("tool2", ImmutableMap.of("arg2", "val2")); + GenerateContentResponse fc2WithStop = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fc2).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .build(); + Flowable rawResponses = Flowable.just(toResponse(fc1), fc2WithStop); + + ImmutableList responses = + ImmutableList.copyOf(Gemini.processRawResponses(rawResponses).blockingIterable()); + + // 3 responses: partial(tool1), partial(tool2), final(tool1+tool2). + assertThat(responses).hasSize(3); + + LlmResponse partial1 = responses.get(0); + LlmResponse partial2 = responses.get(1); + LlmResponse finalAgg = responses.get(2); + + String partial1Id = functionCallId(partial1, 0); + String partial2Id = functionCallId(partial2, 0); + String final1Id = functionCallId(finalAgg, 0); + String final2Id = functionCallId(finalAgg, 1); + + // Tool1's ID matches between its partial event and its position in the final aggregated event. + assertThat(partial1Id).isEqualTo(final1Id); + // Tool2's ID matches between its partial event and its position in the final aggregated event. + assertThat(partial2Id).isEqualTo(final2Id); + // The two distinct calls have distinct IDs. + assertThat(partial1Id).isNotEqualTo(partial2Id); + } + + // Mirrors ADK Python's test_non_streaming_fc_generates_id_when_empty: a function call without an + // ID gets a client-side "adk-"-prefixed ID (the prefix lets downstream code strip client IDs + // before replaying to the model), shared by the partial and final events. + @Test + public void processRawResponses_functionCallWithoutId_generatesAdkPrefixedId() { + GenerateContentResponse fcWithStop = + toResponse( + Candidate.builder() + .content( + Content.builder() + .parts(Part.fromFunctionCall("my_tool", ImmutableMap.of("x", "1"))) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()); + + ImmutableList responses = + ImmutableList.copyOf( + Gemini.processRawResponses(Flowable.just(fcWithStop)).blockingIterable()); + + // partial(my_tool) + final(my_tool). + assertThat(responses).hasSize(2); + String partialId = functionCallId(responses.get(0), 0); + String finalId = functionCallId(responses.get(1), 0); + assertThat(partialId).startsWith("adk-"); + assertThat(finalId).startsWith("adk-"); + assertThat(partialId).isEqualTo(finalId); + // A complete (non-streaming) call keeps its arguments verbatim in the final event. + FunctionCall finalCall = + Iterables.getLast(responses).content().get().parts().get().get(0).functionCall().get(); + assertThat(finalCall.args().get()).containsExactly("x", "1"); + } + + // Mirrors ADK Python's streaming_utils test_non_streaming_fc_preserves_llm_assigned_id: when the + // model itself supplies a function-call ID, the aggregator must preserve it (rather than + // overwriting it with a generated "adk-" ID) in both the partial and final events. + @Test + public void processRawResponses_functionCallWithModelProvidedId_preservesId() { + Part fcWithId = + Part.builder() + .functionCall( + FunctionCall.builder() + .id("model-assigned-id") + .name("my_tool") + .args(ImmutableMap.of("x", "1")) + .build()) + .build(); + GenerateContentResponse fcWithStop = + toResponse( + Candidate.builder() + .content(Content.builder().parts(fcWithId).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()); + + ImmutableList responses = + ImmutableList.copyOf( + Gemini.processRawResponses(Flowable.just(fcWithStop)).blockingIterable()); + + // partial(my_tool) + final(my_tool), both keeping the model-supplied ID. + assertThat(responses).hasSize(2); + assertThat(functionCallId(responses.get(0), 0)).isEqualTo("model-assigned-id"); + assertThat(functionCallId(responses.get(1), 0)).isEqualTo("model-assigned-id"); + } + + // Mirrors ADK Python's streaming_utils streamed-function-call handling: when the model streams a + // single function call across chunks via partialArgs/willContinue, the arguments are accumulated + // (string chunks concatenated by JSONPath) and emitted as ONE complete call in the final + // aggregated response, rather than one (incomplete) call per chunk. + @Test + public void processRawResponses_streamingFunctionCallArgs_mergesIntoSingleFinalCall() { + GenerateContentResponse chunk1 = + toResponse( + functionCallPart(FunctionCall.builder().name("getWeather").willContinue(true).build())); + GenerateContentResponse chunk2 = + toResponse( + functionCallPart( + FunctionCall.builder() + .partialArgs(PartialArg.builder().jsonPath("$.city").stringValue("Kra").build()) + .willContinue(true) + .build())); + GenerateContentResponse chunk3 = + toResponse( + Candidate.builder() + .content( + Content.builder() + .parts( + functionCallPart( + FunctionCall.builder() + .partialArgs( + PartialArg.builder() + .jsonPath("$.city") + .stringValue("kow") + .build()) + .willContinue(false) + .build())) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()); + + ImmutableList responses = + ImmutableList.copyOf( + Gemini.processRawResponses(Flowable.just(chunk1, chunk2, chunk3)).blockingIterable()); + + // The final aggregated response carries exactly one complete getWeather(city="Krakow") call. + LlmResponse finalResponse = Iterables.getLast(responses); + assertThat(finalResponse.partial().orElse(false)).isFalse(); + assertThat(finalResponse.content().get().parts().get()).hasSize(1); + FunctionCall finalCall = + finalResponse.content().get().parts().get().get(0).functionCall().get(); + assertThat(finalCall.name()).hasValue("getWeather"); + assertThat(finalCall.args().get()).containsExactly("city", "Krakow"); + // The call's ID (generated on the first chunk) is reused on the final event. + assertThat(finalCall.id()).hasValue(functionCallId(responses.get(0), 0)); + } + + // Streamed function-call arguments may target nested JSONPaths and non-string values; the + // aggregator must build the nested structure, mirroring ADK Python's _set_value_by_json_path. + @Test + public void processRawResponses_streamingFunctionCallArgs_buildsNestedArgs() { + GenerateContentResponse chunk1 = + toResponse( + functionCallPart( + FunctionCall.builder() + .name("book") + .partialArgs( + PartialArg.builder() + .jsonPath("$.location.city") + .stringValue("Paris") + .build()) + .willContinue(true) + .build())); + GenerateContentResponse chunk2 = + toResponse( + Candidate.builder() + .content( + Content.builder() + .parts( + functionCallPart( + FunctionCall.builder() + .partialArgs( + PartialArg.builder() + .jsonPath("$.guests") + .numberValue(2.0) + .build()) + .willContinue(false) + .build())) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()); + + ImmutableList responses = + ImmutableList.copyOf( + Gemini.processRawResponses(Flowable.just(chunk1, chunk2)).blockingIterable()); + + LlmResponse finalResponse = Iterables.getLast(responses); + FunctionCall finalCall = + finalResponse.content().get().parts().get().get(0).functionCall().get(); + assertThat(finalCall.name()).hasValue("book"); + assertThat(finalCall.args().get()) + .containsExactly("location", ImmutableMap.of("city", "Paris"), "guests", 2.0); + } + + // Two streamed function calls back-to-back must not bleed arguments into each other: a completed + // call's accumulated-args state is reset before the next one starts. + @Test + public void processRawResponses_twoStreamingFunctionCalls_keepArgsSeparate() { + GenerateContentResponse call1 = + toResponse( + functionCallPart( + FunctionCall.builder() + .name("first") + .partialArgs(PartialArg.builder().jsonPath("$.a").stringValue("1").build()) + .willContinue(false) + .build())); + GenerateContentResponse call2 = + toResponse( + Candidate.builder() + .content( + Content.builder() + .parts( + functionCallPart( + FunctionCall.builder() + .name("second") + .partialArgs( + PartialArg.builder() + .jsonPath("$.b") + .stringValue("2") + .build()) + .willContinue(false) + .build())) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()); + + ImmutableList responses = + ImmutableList.copyOf( + Gemini.processRawResponses(Flowable.just(call1, call2)).blockingIterable()); + + LlmResponse finalResponse = Iterables.getLast(responses); + assertThat(finalResponse.content().get().parts().get()).hasSize(2); + FunctionCall first = finalResponse.content().get().parts().get().get(0).functionCall().get(); + FunctionCall second = finalResponse.content().get().parts().get().get(1).functionCall().get(); + assertThat(first.name()).hasValue("first"); + assertThat(first.args().get()).containsExactly("a", "1"); + assertThat(second.name()).hasValue("second"); + assertThat(second.args().get()).containsExactly("b", "2"); } @Test @@ -98,7 +403,7 @@ public void processRawResponses_singleEmptyResponse_emitsOneEmptyResponse() { } @Test - public void processRawResponses_finishReasonNotStop_doesNotEmitFinalAccumulatedText() { + public void processRawResponses_finishReasonNotStop_emitsFinalWithErrorCode() { Flowable rawResponses = Flowable.just( toResponseWithText("Hello"), @@ -106,22 +411,50 @@ public void processRawResponses_finishReasonNotStop_doesNotEmitFinalAccumulatedT Flowable llmResponses = Gemini.processRawResponses(rawResponses); + // Mirrors ADK Python: a non-STOP finish still yields the aggregated final response, with the + // finish reason surfaced as an error code. assertLlmResponses( - llmResponses, isPartialTextResponse("Hello"), isPartialTextResponse(" world")); + llmResponses, + isPartialTextResponse("Hello"), + isPartialTextResponse(" world"), + isFinalTextResponseWithErrorCode("Hello world", FinishReason.Known.MAX_TOKENS)); } @Test - public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmpty() { + public void + processRawResponses_finishReasonNotStopWithMessage_finalResponseIncludesErrorMessage() { + GenerateContentResponse truncatedResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(" world")).build()) + .finishReason(new FinishReason(FinishReason.Known.MAX_TOKENS)) + .finishMessage("Output truncated due to token limit.") + .build()) + .build(); Flowable rawResponses = - Flowable.just(toResponseWithText("Thinking..."), GenerateContentResponse.builder().build()); + Flowable.just(toResponseWithText("Hello"), truncatedResponse); Flowable llmResponses = Gemini.processRawResponses(rawResponses); + // A non-STOP finish surfaces the candidate's finishMessage as the response errorMessage. assertLlmResponses( llmResponses, - isPartialTextResponse("Thinking..."), - isFinalTextResponse("Thinking..."), - isEmptyResponse()); + isPartialTextResponse("Hello"), + isPartialTextResponse(" world"), + isFinalTextResponseWithErrorCodeAndMessage( + "Hello world", FinishReason.Known.MAX_TOKENS, "Output truncated due to token limit.")); + } + + @Test + public void processRawResponses_textThenEmpty_emitsPartialTextThenFullText() { + Flowable rawResponses = + Flowable.just(toResponseWithText("Thinking..."), GenerateContentResponse.builder().build()); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, isPartialTextResponse("Thinking..."), isFinalTextResponse("Thinking...")); } @Test @@ -137,7 +470,8 @@ public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetad assertLlmResponses( llmResponses, isPartialTextResponseWithUsageMetadata("Hello", metadata1), - isPartialTextResponseWithUsageMetadata(" world", metadata2)); + isPartialTextResponseWithUsageMetadata(" world", metadata2), + isFinalTextResponseWithUsageMetadata("Hello world", metadata2)); } @Test @@ -157,6 +491,27 @@ public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMeta isFinalTextResponseWithUsageMetadata("Hello world", metadata)); } + @Test + public void + processRawResponses_textThenEmptyStopWithUsageMetadata_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder().finishReason(new FinishReason(FinishReason.Known.STOP)).build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Hello"), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isFinalTextResponseWithUsageMetadata("Hello", metadata)); + } + @Test public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); @@ -190,12 +545,277 @@ public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsag llmResponses, isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), isPartialTextResponseWithUsageMetadata("Answer", metadata2), - isFinalThoughtResponseWithNoUsageMetadata("Thinking"), - isFinalTextResponseWithUsageMetadata("Answer", metadata2)); + isFinalThoughtAndTextResponseWithUsageMetadata("Thinking", "Answer", metadata2)); } - // Helper methods for assertions + @Test + public void + processRawResponses_interleavedThoughtAndTextWithStop_separatelyAggregatesThoughtAndText() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata3 = createUsageMetadata(10, 15, 25); + GenerateContentResponseUsageMetadata metadata4 = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking 1", metadata1), + toResponseWithText("Answer 1", metadata2), + toResponseWithThoughtText(" Thinking 2", metadata3), + toResponseWithText(" Answer 2", FinishReason.Known.STOP, metadata4)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking 1", metadata1), + isPartialTextResponseWithUsageMetadata("Answer 1", metadata2), + isPartialThoughtResponseWithUsageMetadata(" Thinking 2", metadata3), + isPartialTextResponseWithUsageMetadata(" Answer 2", metadata4), + isFinalInterleavedThoughtAndTextResponseWithUsageMetadata( + "Thinking 1", "Answer 1", " Thinking 2", " Answer 2", metadata4)); + } + + @Test + public void + processRawResponses_textAndFunctionCallWithStop_onlyFinalFunctionCallIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Part fcPart = Part.fromFunctionCall("my_tool", ImmutableMap.of()); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fcPart).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Answer", metadata1), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Answer", metadata1), + isPartialFunctionCallResponse("my_tool"), + isFinalTextAndFunctionCallResponseWithUsageMetadata("Answer", metadata2, "my_tool")); + } + + @Test + public void + processRawResponses_thoughtThenEmptyWithSignatureAndStop_flushesThoughtWithSignature() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = toResponseWithThoughtText("Thinking", metadata1); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isFinalThoughtResponseWithUsageMetadataAndSignature("Thinking", metadata2, "sig")); + } + + @Test + public void + processRawResponses_thoughtWithSignatureThenTextAndStop_flushesThoughtWithSignature() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .text("Thinking") + .thought(true) + .thoughtSignature("sig".getBytes(UTF_8)) + .build()) + .build()) + .build()) + .usageMetadata(metadata1) + .build(); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts(Part.builder().text("Hello").thought(false).build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialTextResponseWithUsageMetadata("Hello", metadata2), + isFinalThoughtAndTextResponseWithUsageMetadataAndSignature( + "Thinking", "Hello", metadata2, "sig")); + } + + @Test + public void + processRawResponses_thoughtThenFunctionCallWithSignatureAndStop_attachesSignatureToFunctionCall() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = toResponseWithThoughtText("Thinking", metadata1); + GenerateContentResponse chunk2 = + toResponse(Part.fromFunctionCall("my_tool", ImmutableMap.of())); + GenerateContentResponse chunk3 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2, chunk3); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialFunctionCallResponse("my_tool"), + isFinalThoughtAndFunctionCallResponseWithUsageMetadataAndSignature( + "Thinking", metadata2, "sig", "my_tool")); + } + + @Test + public void processRawResponses_emptyPartsThenSignature_doesNotThrowException() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + GenerateContentResponse chunk1 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(ImmutableList.of()).build()) + .build()) + .build(); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + assertLlmResponses( + llmResponses, + isEmptyResponse(), + isFinalThoughtResponseWithUsageMetadataAndSignature("", metadata, "sig")); + } + + @Test + public void functionCallThenEmptyTextWithStop_emitsPartialThenFinalAggregatedFunctionCall() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponse("test_function")); + } + + @Test + public void functionCallThenEmptyTextWithUsageMetadata_emitsFinalAggregatedWithUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponseWithUsageMetadata(metadata, "test_function")); + } + + @Test + public void functionCallThenEmptyText_doesNotEmitExtraEmptyResponse() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("")); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + // The trailing empty-text chunk adds no empty response; the function call is still aggregated + // into a final response even without a finish reason. + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponse("test_function")); + } + + @Test + public void textThenFunctionCallThenEmptyTextWithStop_emitsTextThenFunctionCalls() { + Flowable rawResponses = + Flowable.just( + toResponseWithText("Thinking..."), + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Thinking..."), + isPartialFunctionCallResponse("test_function"), + isFinalTextAndFunctionCallResponseWithNoUsageMetadata("Thinking...", "test_function")); + } + + // Helper methods for assertions private void assertLlmResponses( Flowable llmResponses, Predicate... predicates) { TestSubscriber testSubscriber = llmResponses.test(); @@ -207,6 +827,17 @@ private void assertLlmResponses( testSubscriber.assertNoErrors(); } + /** Returns the function-call ID of the part at {@code partIndex} in the response's content. */ + private static String functionCallId(LlmResponse response, int partIndex) { + return response + .content() + .flatMap(Content::parts) + .map(parts -> parts.get(partIndex)) + .flatMap(Part::functionCall) + .flatMap(FunctionCall::id) + .orElseThrow(); + } + private static Predicate isPartialTextResponse(String expectedText) { return response -> { assertThat(response.partial()).hasValue(true); @@ -218,23 +849,88 @@ private static Predicate isPartialTextResponse(String expectedText) private static Predicate isFinalTextResponse(String expectedText) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + return true; + }; + } + + private static Predicate isFinalTextResponseWithErrorCode( + String expectedText, FinishReason.Known expectedErrorCode) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); + assertThat(response.errorCode().map(FinishReason::knownEnum)).hasValue(expectedErrorCode); return true; }; } - private static Predicate isFunctionCallResponse() { + private static Predicate isFinalTextResponseWithErrorCodeAndMessage( + String expectedText, FinishReason.Known expectedErrorCode, String expectedErrorMessage) { return response -> { - assertThat(response.content().get().parts().get().get(0).functionCall()).isNotNull(); + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.errorCode().map(FinishReason::knownEnum)).hasValue(expectedErrorCode); + assertThat(response.errorMessage()).hasValue(expectedErrorMessage); + return true; + }; + } + + private static Predicate isPartialFunctionCallResponse(String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(1); + assertThat(response.content().get().parts().get().get(0).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isPartialTextAndFunctionCallResponse( + String expectedText, String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponse( + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponseWithUsageMetadata( + GenerateContentResponseUsageMetadata expectedMetadata, String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); return true; }; } private static Predicate isEmptyResponse() { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEmpty(); return true; @@ -268,7 +964,7 @@ private static Predicate isPartialThoughtResponseWithUsageMetadata( private static Predicate isFinalTextResponseWithUsageMetadata( String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(response.usageMetadata()).hasValue(expectedMetadata); @@ -279,7 +975,7 @@ private static Predicate isFinalTextResponseWithUsageMetadata( private static Predicate isFinalThoughtResponseWithUsageMetadata( String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) @@ -289,21 +985,143 @@ private static Predicate isFinalThoughtResponseWithUsageMetadata( }; } - private static Predicate isFinalThoughtResponseWithNoUsageMetadata( - String expectedText) { + private static Predicate isFinalThoughtResponseWithUsageMetadataAndSignature( + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) .isTrue(); + assertThat( + GeminiUtil.getPart0FromLlmResponse(response) + .flatMap(Part::thoughtSignature) + .orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(UTF_8)); + + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtAndTextResponseWithUsageMetadata( + String expectedThought, + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtAndTextResponseWithUsageMetadataAndSignature( + String expectedThought, + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat( + response.content().get().parts().get().get(0).thoughtSignature().orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(UTF_8)); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + // The signature belongs only to the thought part; it must not leak onto the following text + // part (the aggregator resets the buffered signature after each flush). + assertThat(response.content().get().parts().get().get(1).thoughtSignature()).isEmpty(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalInterleavedThoughtAndTextResponseWithUsageMetadata( + String expectedThought1, + String expectedText1, + String expectedThought2, + String expectedText2, + GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(4); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought1); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText1); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + assertThat(response.content().get().parts().get().get(2).text()).hasValue(expectedThought2); + assertThat(response.content().get().parts().get().get(2).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(3).text()).hasValue(expectedText2); + assertThat(response.content().get().parts().get().get(3).thought()).hasValue(false); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextAndFunctionCallResponseWithUsageMetadata( + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i + 1).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextAndFunctionCallResponseWithNoUsageMetadata( + String expectedText, String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i + 1).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } assertThat(response.usageMetadata()).isEmpty(); return true; }; } - // Helper methods to create responses for testing + private static Predicate + isFinalThoughtAndFunctionCallResponseWithUsageMetadataAndSignature( + String expectedThought, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature, + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + for (int i = 0; i < expectedToolNames.length; i++) { + Part part = response.content().get().parts().get().get(i + 1); + assertThat(part.functionCall().get().name()).hasValue(expectedToolNames[i]); + assertThat(part.thoughtSignature().orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(UTF_8)); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { return toResponse(Part.fromText(text)); } @@ -316,14 +1134,6 @@ private GenerateContentResponse toResponseWithText(String text, FinishReason.Kno .build()); } - private GenerateContentResponse toResponse(Part part) { - return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); - } - - private GenerateContentResponse toResponse(Candidate candidate) { - return GenerateContentResponse.builder().candidates(candidate).build(); - } - private GenerateContentResponse toResponseWithText( String text, GenerateContentResponseUsageMetadata usageMetadata) { return GenerateContentResponse.builder() @@ -349,6 +1159,18 @@ private GenerateContentResponse toResponseWithText( .build(); } + private static Part functionCallPart(FunctionCall functionCall) { + return Part.builder().functionCall(functionCall).build(); + } + + private GenerateContentResponse toResponse(Part part) { + return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); + } + + private GenerateContentResponse toResponse(Candidate candidate) { + return GenerateContentResponse.builder().candidates(candidate).build(); + } + private GenerateContentResponse toResponseWithThoughtText( String text, GenerateContentResponseUsageMetadata usageMetadata) { Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 5b1e7b7f0..ae1e0ee74 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -90,6 +90,7 @@ import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -98,6 +99,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import org.jspecify.annotations.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -2599,4 +2601,82 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } + + @Test + public void runAsync_partialEvent_streamedButNotPassedToSessionService() { + // The model streams a partial event followed by the final aggregated event in one turn. + LlmResponse partialResponse = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("partial")).build()) + .partial(true) + .build(); + LlmResponse finalResponse = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("final")).build()) + .build(); + TestLlm testLlm = new TestLlm(() -> Flowable.just(partialResponse, finalResponse)); + LlmAgent agent = createTestAgent(testLlm); + RecordingSessionService sessionService = new RecordingSessionService(); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .sessionService(sessionService) + .build(); + Session session = sessionService.createSession("test", "user").blockingGet(); + + List events = + runner.runAsync("user", session.id(), createContent("hi")).toList().blockingGet(); + + // The partial event is still streamed to the caller. + assertThat(events.stream().anyMatch(event -> event.partial().orElse(false))).isTrue(); + // Mirroring ADK Python's Runner, partial events are never handed to the session service, so + // managed services (e.g. VertexAiSessionService) cannot persist duplicates. + assertThat(sessionService.appendedEvents.stream().anyMatch(e -> e.partial().orElse(false))) + .isFalse(); + } + + /** A session service that records every event passed to {@code appendEvent} for assertions. */ + private static final class RecordingSessionService implements BaseSessionService { + private final InMemorySessionService delegate = new InMemorySessionService(); + final List appendedEvents = Collections.synchronizedList(new ArrayList<>()); + + @Override + public Single appendEvent(Session session, Event event) { + appendedEvents.add(event); + return delegate.appendEvent(session, event); + } + + // BaseSessionService's only abstract createSession overload is deprecated, so implementing and + // delegating to it is unavoidable. + @SuppressWarnings("deprecation") + @Override + public Single createSession( + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + return delegate.createSession(appName, userId, state, sessionId); + } + + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + return delegate.getSession(appName, userId, sessionId, config); + } + + @Override + public Single listSessions(String appName, String userId) { + return delegate.listSessions(appName, userId); + } + + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + return delegate.deleteSession(appName, userId, sessionId); + } + + @Override + public Single listEvents(String appName, String userId, String sessionId) { + return delegate.listEvents(appName, userId, sessionId); + } + } }