diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java
index 3983b08a5..ab10a6521 100644
--- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java
+++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java
@@ -22,6 +22,7 @@
 import com.google.adk.models.LlmResponse;
 import com.google.genai.types.Content;
 import com.google.genai.types.FunctionCall;
+import com.google.genai.types.GenerateContentResponseUsageMetadata;
 import com.google.genai.types.Part;
 import java.net.URI;
 import java.util.ArrayList;
@@ -32,6 +33,8 @@
 import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.ToolResponseMessage;
 import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.EmptyUsage;
+import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.ChatOptions;
@@ -318,11 +321,28 @@ public LlmResponse toLlmResponse(ChatResponse chatResponse, boolean isStreaming)
     boolean isPartial = isStreaming && isPartialResponse(assistantMessage);
     boolean isTurnComplete = !isStreaming || isTurnCompleteResponse(chatResponse);
 
-    return LlmResponse.builder()
-        .content(content)
-        .partial(isPartial)
-        .turnComplete(isTurnComplete)
-        .build();
+    LlmResponse.Builder responseBuilder =
+        LlmResponse.builder().content(content).partial(isPartial).turnComplete(isTurnComplete);
+
+    // Attach token-usage metadata when the provider reported it. EmptyUsage marks
+    // responses (e.g. intermediate stream chunks) that carry no real counts.
+    Usage springUsage =
+        chatResponse.getMetadata() != null ? chatResponse.getMetadata().getUsage() : null;
+    if (springUsage != null && !(springUsage instanceof EmptyUsage)) {
+      GenerateContentResponseUsageMetadata adkUsage =
+          GenerateContentResponseUsageMetadata.builder()
+              .promptTokenCount(nullSafeInt(springUsage.getPromptTokens()))
+              .candidatesTokenCount(nullSafeInt(springUsage.getCompletionTokens()))
+              .totalTokenCount(nullSafeInt(springUsage.getTotalTokens()))
+              .build();
+      responseBuilder.usageMetadata(adkUsage);
+    }
+    return responseBuilder.build();
+  }
+
+  /** Unboxes a possibly-null token count, defaulting to 0 when the provider omitted it. */
+  private int nullSafeInt(Integer value) {
+    return value != null ? value : 0;
   }
 
   /** Determines if an assistant message represents a partial response in streaming. */
diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java
index b861a71f2..513b61179 100644
--- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java
+++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java
@@ -16,7 +16,9 @@
 package com.google.adk.models.springai;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.adk.models.LlmRequest;
@@ -33,6 +35,8 @@
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.Prompt;
@@ -237,6 +241,69 @@ void testToLlmResponseFromChatResponseWithToolCalls() {
     assertThat(functionCallPart.functionCall().get().id()).contains("call_123");
   }
 
+  @Test
+  void testUsageMetadataShouldBeEmptyWhenSpringAiMetadataIsNull() {
+    MessageConverter converter = new MessageConverter(new ObjectMapper());
+    AssistantMessage assistantMessage = new AssistantMessage("intermediate chunk");
+    Generation generation = new Generation(assistantMessage);
+
+    ChatResponse chatResponse = new ChatResponse(List.of(generation), null);
+
+    LlmResponse llmResponse = converter.toLlmResponse(chatResponse, true);
+
+    assertTrue(
+        llmResponse.usageMetadata().isEmpty(),
+        "Expected usageMetadata to be empty for intermediate stream chunks lacking metadata");
+  }
+
+  @Test
+  void testUsageMetadataShouldBeEmptyWhenSpringAiUsageIsNull() {
+    MessageConverter converter = new MessageConverter(new ObjectMapper());
+    AssistantMessage assistantMessage = new AssistantMessage("intermediate chunk");
+    Generation generation = new Generation(assistantMessage);
+
+    ChatResponseMetadata metadata = ChatResponseMetadata.builder().id("resp-no-usage").build();
+
+    ChatResponse chatResponse = new ChatResponse(List.of(generation), metadata);
+
+    LlmResponse llmResponse = converter.toLlmResponse(chatResponse, true);
+
+    assertTrue(
+        llmResponse.usageMetadata().isEmpty(),
+        "Expected usageMetadata to be empty when metadata exists but usage is null");
+  }
+
+  @Test
+  void testUsageMetadataShouldDefaultToZeroWhenSpringAiTokensAreNull() {
+    MessageConverter converter = new MessageConverter(new ObjectMapper());
+    AssistantMessage assistantMessage = new AssistantMessage("final chunk");
+    Generation generation = new Generation(assistantMessage);
+
+    // DefaultUsage with null token counts simulates a provider that omitted some counts.
+    DefaultUsage incompleteUsage = new DefaultUsage(null, null, 42);
+    ChatResponseMetadata metadata =
+        ChatResponseMetadata.builder().id("resp-partial-tokens").usage(incompleteUsage).build();
+
+    ChatResponse chatResponse = new ChatResponse(List.of(generation), metadata);
+
+    LlmResponse llmResponse = converter.toLlmResponse(chatResponse, false);
+
+    assertTrue(llmResponse.usageMetadata().isPresent(), "Expected usageMetadata to be present");
+
+    assertEquals(
+        0,
+        llmResponse.usageMetadata().get().promptTokenCount().orElse(-1),
+        "Null prompt tokens should default to 0");
+    assertEquals(
+        0,
+        llmResponse.usageMetadata().get().candidatesTokenCount().orElse(-1),
+        "Null completion tokens should default to 0");
+    assertEquals(
+        42,
+        llmResponse.usageMetadata().get().totalTokenCount().orElse(-1),
+        "Total tokens should be mapped correctly");
+  }
+
   @Test
   void testToolCallIdPreservedInConversion() {
     // Create AssistantMessage with tool call including ID