Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions core/src/main/java/com/google/adk/models/Gemini.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
p ->
p.functionCall().isPresent()
|| p.functionResponse().isPresent()
|| p.text().map(t -> !t.isBlank()).orElse(false)))
|| p.text().isPresent()))
.orElse(false));
} else {
logger.debug("Sending generateContent request to model {}", effectiveModelName);
Expand Down Expand Up @@ -272,11 +272,17 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons
if (part.get().thought().orElse(false)) {
accumulatedThoughtText.append(currentTextChunk);
responsesToEmit.add(
thinkingResponseFromText(currentTextChunk).toBuilder().partial(true).build());
thinkingResponseFromText(currentTextChunk).toBuilder()
.usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null))
.partial(true)
.build());
} else {
accumulatedText.append(currentTextChunk);
responsesToEmit.add(
responseFromText(currentTextChunk).toBuilder().partial(true).build());
responseFromText(currentTextChunk).toBuilder()
.usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null))
.partial(true)
.build());
}
} else {
if (accumulatedThoughtText.length() > 0
Expand Down Expand Up @@ -316,11 +322,20 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons
List<LlmResponse> finalResponses = new ArrayList<>();
if (accumulatedThoughtText.length() > 0) {
finalResponses.add(
thinkingResponseFromText(accumulatedThoughtText.toString()));
thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder()
.usageMetadata(
accumulatedText.length() > 0
? null
: finalRawResp.usageMetadata().orElse(null))
.build());
}
if (accumulatedText.length() > 0) {
finalResponses.add(responseFromText(accumulatedText.toString()));
finalResponses.add(
responseFromText(accumulatedText.toString()).toBuilder()
.usageMetadata(finalRawResp.usageMetadata().orElse(null))
.build());
}

return Flowable.fromIterable(finalResponses);
}
return Flowable.empty();
Expand Down
191 changes: 191 additions & 0 deletions core/src/test/java/com/google/adk/models/GeminiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.genai.types.Content;
import com.google.genai.types.FinishReason;
import com.google.genai.types.GenerateContentResponse;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.functions.Predicate;
Expand Down Expand Up @@ -123,6 +124,76 @@ public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmp
isEmptyResponse());
}

@Test
public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetadata() {
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15);
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25);
Flowable<GenerateContentResponse> rawResponses =
Flowable.just(
toResponseWithText("Hello", metadata1), toResponseWithText(" world", metadata2));

Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);

assertLlmResponses(
llmResponses,
isPartialTextResponseWithUsageMetadata("Hello", metadata1),
isPartialTextResponseWithUsageMetadata(" world", metadata2));
}

@Test
public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMetadata() {
GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30);
Flowable<GenerateContentResponse> rawResponses =
Flowable.just(
toResponseWithText("Hello"),
toResponseWithText(" world", FinishReason.Known.STOP, metadata));

Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);

assertLlmResponses(
llmResponses,
isPartialTextResponse("Hello"),
isPartialTextResponseWithUsageMetadata(" world", metadata),
isFinalTextResponseWithUsageMetadata("Hello world", metadata));
}

@Test
public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() {
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15);
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25);
Flowable<GenerateContentResponse> rawResponses =
Flowable.just(
toResponseWithThoughtText("Thinking", metadata1),
toResponseWithThoughtText(" deeply", FinishReason.Known.STOP, metadata2));

Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);

assertLlmResponses(
llmResponses,
isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1),
isPartialThoughtResponseWithUsageMetadata(" deeply", metadata2),
isFinalThoughtResponseWithUsageMetadata("Thinking deeply", metadata2));
}

@Test
public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsageMetadata() {
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10);
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30);
Flowable<GenerateContentResponse> rawResponses =
Flowable.just(
toResponseWithThoughtText("Thinking", metadata1),
toResponseWithText("Answer", FinishReason.Known.STOP, metadata2));

Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);

assertLlmResponses(
llmResponses,
isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1),
isPartialTextResponseWithUsageMetadata("Answer", metadata2),
isFinalThoughtResponseWithNoUsageMetadata("Thinking"),
isFinalTextResponseWithUsageMetadata("Answer", metadata2));
}

// Helper methods for assertions

private void assertLlmResponses(
Expand Down Expand Up @@ -170,6 +241,67 @@ private static Predicate<LlmResponse> isEmptyResponse() {
};
}

private static Predicate<LlmResponse> isPartialTextResponseWithUsageMetadata(
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
return response -> {
assertThat(response.partial()).hasValue(true);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
return true;
};
}

private static Predicate<LlmResponse> isPartialThoughtResponseWithUsageMetadata(
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
return response -> {
assertThat(response.partial()).hasValue(true);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
.isTrue();
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
return true;
};
}

private static Predicate<LlmResponse> isFinalTextResponseWithUsageMetadata(
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
return response -> {
assertThat(response.partial()).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
return true;
};
}

private static Predicate<LlmResponse> isFinalThoughtResponseWithUsageMetadata(
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
return response -> {
assertThat(response.partial()).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
.isTrue();
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
return true;
};
}

private static Predicate<LlmResponse> isFinalThoughtResponseWithNoUsageMetadata(
String expectedText) {
return response -> {
assertThat(response.partial()).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
.isTrue();
assertThat(response.usageMetadata()).isEmpty();
return true;
};
}

// Helper methods to create responses for testing

private GenerateContentResponse toResponseWithText(String text) {
Expand All @@ -191,4 +323,63 @@ private GenerateContentResponse toResponse(Part part) {
private GenerateContentResponse toResponse(Candidate candidate) {
return GenerateContentResponse.builder().candidates(candidate).build();
}

private GenerateContentResponse toResponseWithText(
String text, GenerateContentResponseUsageMetadata usageMetadata) {
return GenerateContentResponse.builder()
.candidates(
Candidate.builder()
.content(Content.builder().parts(Part.fromText(text)).build())
.build())
.usageMetadata(usageMetadata)
.build();
}

private GenerateContentResponse toResponseWithText(
String text,
FinishReason.Known finishReason,
GenerateContentResponseUsageMetadata usageMetadata) {
return GenerateContentResponse.builder()
.candidates(
Candidate.builder()
.content(Content.builder().parts(Part.fromText(text)).build())
.finishReason(new FinishReason(finishReason))
.build())
.usageMetadata(usageMetadata)
.build();
}

private GenerateContentResponse toResponseWithThoughtText(
String text, GenerateContentResponseUsageMetadata usageMetadata) {
Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build();
return GenerateContentResponse.builder()
.candidates(
Candidate.builder().content(Content.builder().parts(thoughtPart).build()).build())
.usageMetadata(usageMetadata)
.build();
}

private GenerateContentResponse toResponseWithThoughtText(
String text,
FinishReason.Known finishReason,
GenerateContentResponseUsageMetadata usageMetadata) {
Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build();
return GenerateContentResponse.builder()
.candidates(
Candidate.builder()
.content(Content.builder().parts(thoughtPart).build())
.finishReason(new FinishReason(finishReason))
.build())
.usageMetadata(usageMetadata)
.build();
}

private static GenerateContentResponseUsageMetadata createUsageMetadata(
int promptTokens, int candidateTokens, int totalTokens) {
return GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(promptTokens)
.candidatesTokenCount(candidateTokens)
.totalTokenCount(totalTokens)
.build();
}
}