@@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.Metrics;
using System.IO;
using System.Linq;
using System.Net.Http;
@@ -469,6 +470,64 @@ public async Task ShouldHandleStreamingThoughtPartsAsync()
Assert.Equal(" 42.", thirdMessage.Content);
}

[Fact]
public async Task ShouldEmitUsageMetricsOnceForStreamingResponseAsync()
{
// Arrange
var streamingResponse = """
data: {"candidates": [{"content": {"parts": [{"text": "One"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

data: {"candidates": [{"content": {"parts": [{"text": " two"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

data: {"candidates": [{"content": {"parts": [{"text": " three"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

data: {"candidates": [{"content": {"parts": [{"text": " four"}], "role": "model"}, "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

data: {"candidates": [{"content": {"parts": [{"text": " five"}], "role": "model"}, "finishReason": "STOP", "index": 0}], "usageMetadata": {"promptTokenCount": 101, "candidatesTokenCount": 202, "totalTokenCount": 303}}

data: {"candidates": [{"content": {"parts": [{"text": ""}], "role": "model"}, "finishReason": "STOP", "index": 0}]}

""";

this._messageHandlerStub.ResponseToReturn.Content = new StringContent(streamingResponse);

// One bucket per token-usage counter; each bucket collects every value reported for that instrument.
var measurements = new Dictionary<string, List<int>>
{
["Microsoft.SemanticKernel.Connectors.Google.tokens.prompt"] = [],
["Microsoft.SemanticKernel.Connectors.Google.tokens.completion"] = [],
["Microsoft.SemanticKernel.Connectors.Google.tokens.total"] = [],
};

using MeterListener listener = new();
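// Enable measurement events only for the connector's token-usage instruments, ignoring unrelated meters.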
listener.InstrumentPublished = (instrument, listener) =>
{
if (measurements.ContainsKey(instrument.Name))
{
listener.EnableMeasurementEvents(instrument);
}
};
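// Record every int measurement reported for the instruments enabled above.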
listener.SetMeasurementEventCallback<int>((instrument, measurement, tags, state) =>
{
if (measurements.TryGetValue(instrument.Name, out var instrumentMeasurements))
{
instrumentMeasurements.Add(measurement);
}
});
listener.Start();

var client = this.CreateChatCompletionClient();
var chatHistory = CreateSampleChatHistory();

// Act
var messages = await client.StreamGenerateChatMessageAsync(chatHistory).ToListAsync();

// Assert
Assert.Equal(6, messages.Count);
Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.prompt"].Count(measurement => measurement == 101));
Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.completion"].Count(measurement => measurement == 202));
Assert.Equal(1, measurements["Microsoft.SemanticKernel.Connectors.Google.tokens.total"].Count(measurement => measurement == 303));
}

private static ChatHistory CreateSampleChatHistory()
{
var chatHistory = new ChatHistory();
@@ -772,13 +772,26 @@ private async IAsyncEnumerable<GeminiChatMessageContent> ProcessChatResponseStre
Stream responseStream,
[EnumeratorCancellation] CancellationToken ct)
{
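// Track the most recent chunk that carried usage metadata so usage is logged only once, after the stream completes.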
List<GeminiChatMessageContent>? lastChatMessageContentsWithUsage = null;

await foreach (var response in this.ParseResponseStreamAsync(responseStream, ct: ct).ConfigureAwait(false))
{
foreach (var messageContent in this.ProcessChatResponse(response))
var chatMessageContents = this.CreateChatMessageContents(response);
if (HasTokenUsage(chatMessageContents))
{
lastChatMessageContentsWithUsage = chatMessageContents;
}

foreach (var messageContent in chatMessageContents)
{
yield return messageContent;
}
}

// Emit token-usage metrics exactly once for the entire streamed response.
if (lastChatMessageContentsWithUsage is { } chatMessageContentsWithUsage)
{
this.LogUsage(chatMessageContentsWithUsage);
}
}

private async IAsyncEnumerable<GeminiResponse> ParseResponseStreamAsync(
@@ -793,13 +806,21 @@ private async IAsyncEnumerable<GeminiResponse> ParseResponseStreamAsync(

private List<GeminiChatMessageContent> ProcessChatResponse(GeminiResponse geminiResponse)
{
ValidateGeminiResponse(geminiResponse);

var chatMessageContents = this.GetChatMessageContentsFromResponse(geminiResponse);
var chatMessageContents = this.CreateChatMessageContents(geminiResponse);
this.LogUsage(chatMessageContents);
return chatMessageContents;
}

private List<GeminiChatMessageContent> CreateChatMessageContents(GeminiResponse geminiResponse)
{
ValidateGeminiResponse(geminiResponse);

return this.GetChatMessageContentsFromResponse(geminiResponse);
}

// A chunk carries usage information when its first message reports a positive total token count.
private static bool HasTokenUsage(List<GeminiChatMessageContent> chatMessageContents)
=> chatMessageContents.FirstOrDefault()?.Metadata is { TotalTokenCount: > 0 };

private static void ValidateGeminiResponse(GeminiResponse geminiResponse)
{
if (geminiResponse.PromptFeedback?.BlockReason is not null)
@@ -811,7 +832,7 @@ private static void ValidateGeminiResponse(GeminiResponse geminiResponse)

private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
{
GeminiMetadata? metadata = chatMessageContents[0].Metadata;
GeminiMetadata? metadata = chatMessageContents.FirstOrDefault()?.Metadata;

if (metadata is null || metadata.TotalTokenCount <= 0)
{