diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 74cf78b98..6f145e1de 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -239,7 +239,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre p -> p.functionCall().isPresent() || p.functionResponse().isPresent() - || p.text().map(t -> !t.isBlank()).orElse(false))) + || p.text().isPresent())) .orElse(false)); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); @@ -272,11 +272,17 @@ static Flowable processRawResponses(Flowable 0 @@ -316,11 +322,20 @@ static Flowable processRawResponses(Flowable finalResponses = new ArrayList<>(); if (accumulatedThoughtText.length() > 0) { finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString())); + thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() + .usageMetadata( + accumulatedText.length() > 0 + ? null + : finalRawResp.usageMetadata().orElse(null)) + .build()); } if (accumulatedText.length() > 0) { - finalResponses.add(responseFromText(accumulatedText.toString())); + finalResponses.add( + responseFromText(accumulatedText.toString()).toBuilder() + .usageMetadata(finalRawResp.usageMetadata().orElse(null)) + .build()); } + return Flowable.fromIterable(finalResponses); } return Flowable.empty(); diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index 07dd675e5..c230f5f68 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -22,6 +22,7 @@ import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; @@ -123,6 +124,76 @@ public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmp isEmptyResponse()); } + @Test + public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello", metadata1), toResponseWithText(" world", metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Hello", metadata1), + isPartialTextResponseWithUsageMetadata(" world", metadata2)); + } + + @Test + public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello"), + toResponseWithText(" world", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isPartialTextResponseWithUsageMetadata(" world", metadata), + isFinalTextResponseWithUsageMetadata("Hello world", metadata)); + } + + @Test + public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithThoughtText(" deeply", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialThoughtResponseWithUsageMetadata(" deeply", metadata2), + isFinalThoughtResponseWithUsageMetadata("Thinking deeply", metadata2)); + } + + @Test + public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithText("Answer", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialTextResponseWithUsageMetadata("Answer", metadata2), + isFinalThoughtResponseWithNoUsageMetadata("Thinking"), + isFinalTextResponseWithUsageMetadata("Answer", metadata2)); + } + // Helper methods for assertions private void assertLlmResponses( @@ -170,6 +241,67 @@ private static Predicate isEmptyResponse() { }; } + private static Predicate isPartialTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isPartialThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithNoUsageMetadata( + String expectedText) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).isEmpty(); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { @@ -191,4 +323,63 @@ private GenerateContentResponse toResponse(Part part) { private GenerateContentResponse toResponse(Candidate candidate) { return GenerateContentResponse.builder().candidates(candidate).build(); } + + private GenerateContentResponse toResponseWithText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return GenerateContentResponse.builder() + .candidates( + Candidate.builder().content(Content.builder().parts(thoughtPart).build()).build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(thoughtPart).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private static GenerateContentResponseUsageMetadata createUsageMetadata( + int promptTokens, int candidateTokens, int totalTokens) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(candidateTokens) + .totalTokenCount(totalTokens) + .build(); + } }