// NOTE(review): this region was recovered from a unified diff whose generic type
// arguments ("<...>") were stripped by text extraction (e.g. "Dictionary>").
// Type arguments below are reconstructed from surrounding usage — confirm against
// the original patch before merging.

/// <summary>
/// Verifies that token-usage metrics are emitted exactly once for a streamed
/// Gemini response, even though every streamed chunk repeats the same
/// <c>usageMetadata</c> block.
/// </summary>
[Fact]
public async Task ShouldEmitUsageMetricsOnceForStreamingResponseAsync()
{
    // Arrange: five chunks each carry identical usage metadata (101/202/303);
    // the final (empty) chunk carries none. Per-chunk logging would record
    // each counter five times; the fix must record each exactly once.
    var streamingResponse = """
        data: {"candidates": [{"content": {"parts": [{"text": "One"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

        data: {"candidates": [{"content": {"parts": [{"text": " two"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

        data: {"candidates": [{"content": {"parts": [{"text": " three"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

        data: {"candidates": [{"content": {"parts": [{"text": " four"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

        data: {"candidates": [{"content": {"parts": [{"text": " five"}], "role": "model"}, "finishReason": "STOP", "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

        data: {"candidates": [{"content": {"parts": [{"text": ""}], "role": "model"}, "finishReason": "STOP", "index": 0}]}

        """;

    this._messageHandlerStub.ResponseToReturn.Content = new StringContent(streamingResponse);

    // NOTE(review): element type reconstructed as int — the connector's token
    // counters are presumably Counter<int>; verify against the client's Meter setup.
    var measurements = new Dictionary<string, List<int>>
    {
        ["Microsoft.SemanticKernel.Connectors.Google.tokens.prompt"] = [],
        ["Microsoft.SemanticKernel.Connectors.Google.tokens.completion"] = [],
        ["Microsoft.SemanticKernel.Connectors.Google.tokens.total"] = [],
    };

    // Subscribe only to the three token counters under test and collect every
    // measurement they publish during the streaming call.
    using MeterListener listener = new();
    listener.InstrumentPublished = (instrument, listener) =>
    {
        if (measurements.ContainsKey(instrument.Name))
        {
            listener.EnableMeasurementEvents(instrument);
        }
    };
    listener.SetMeasurementEventCallback<int>((instrument, measurement, tags, state) =>
    {
        if (measurements.TryGetValue(instrument.Name, out var instrumentMeasurements))
        {
            instrumentMeasurements.Add(measurement);
        }
    });
    listener.Start();

    var client = this.CreateChatCompletionClient();
    var chatHistory = CreateSampleChatHistory();

    // Act
    var messages = await client.StreamGenerateChatMessageAsync(chatHistory).ToListAsync();

    // Assert: all six chunks surfaced, but each counter value was recorded once.
    Assert.Equal(6, messages.Count);
    Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.prompt"].Count(measurement => measurement == 101));
    Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.completion"].Count(measurement => measurement == 202));
    Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.total"].Count(measurement => measurement == 303));
}
// NOTE(review): this region was recovered from a unified diff whose generic type
// arguments were stripped by text extraction ("List?", bare "IAsyncEnumerable",
// truncated "ProcessChatResponseStre"). Element types below are reconstructed
// from the visible GeminiResponse/GeminiMetadata usage — confirm against the
// original patch.

/// <summary>
/// Streams chat message contents parsed from <paramref name="responseStream"/>,
/// then — after the stream completes — logs token usage exactly once, using the
/// last chunk that reported usage metadata.
/// </summary>
/// <param name="responseStream">Raw SSE response stream from the Gemini endpoint.</param>
/// <param name="ct">Cancellation token flowed into stream parsing.</param>
private async IAsyncEnumerable<GeminiChatMessageContent> ProcessChatResponseStreamAsync(
    Stream responseStream,
    [EnumeratorCancellation] CancellationToken ct)
{
    // Gemini repeats usageMetadata on (nearly) every chunk; logging per chunk
    // would multiply the token counters. Track only the most recent chunk that
    // actually carried usage and defer logging until the stream ends.
    List<GeminiChatMessageContent>? lastChatMessageContentsWithUsage = null;

    await foreach (var response in this.ParseResponseStreamAsync(responseStream, ct: ct).ConfigureAwait(false))
    {
        var chatMessageContents = this.CreateChatMessageContents(response);
        if (HasTokenUsage(chatMessageContents))
        {
            lastChatMessageContentsWithUsage = chatMessageContents;
        }

        foreach (var messageContent in chatMessageContents)
        {
            yield return messageContent;
        }
    }

    // Emit usage metrics once for the whole streamed response.
    if (lastChatMessageContentsWithUsage is { } chatMessageContentsWithUsage)
    {
        this.LogUsage(chatMessageContentsWithUsage);
    }
}

/// <summary>
/// Validates and converts a non-streaming Gemini response into chat message
/// contents, logging token usage (non-streaming responses report usage once).
/// </summary>
private List<GeminiChatMessageContent> ProcessChatResponse(GeminiResponse geminiResponse)
{
    var chatMessageContents = this.CreateChatMessageContents(geminiResponse);
    this.LogUsage(chatMessageContents);
    return chatMessageContents;
}

/// <summary>
/// Validates the response (throws on a prompt-feedback block) and materializes
/// its chat message contents. Shared by the streaming and non-streaming paths.
/// </summary>
private List<GeminiChatMessageContent> CreateChatMessageContents(GeminiResponse geminiResponse)
{
    ValidateGeminiResponse(geminiResponse);

    return this.GetChatMessageContentsFromResponse(geminiResponse);
}

/// <summary>
/// Returns true when the first message's metadata reports a positive total
/// token count, i.e. the chunk carried usable usage metadata.
/// </summary>
private static bool HasTokenUsage(List<GeminiChatMessageContent> chatMessageContents)
    => chatMessageContents.FirstOrDefault()?.Metadata is { TotalTokenCount: > 0 };
metadata = chatMessageContents[0].Metadata; + GeminiMetadata? metadata = chatMessageContents.FirstOrDefault()?.Metadata; if (metadata is null || metadata.TotalTokenCount <= 0) {