diff --git a/bridge_integration_test.go b/bridge_integration_test.go index f8ef57e..7fb7885 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -605,7 +605,7 @@ func TestSimple(t *testing.T) { // multiple messages in response to a single request. id, err := tc.getResponseIDFunc(streaming, resp) require.NoError(t, err, "failed to retrieve response ID") - require.Nil(t, uuid.Validate(id), "id is not a UUID") + require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1) require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index ef8aabd..5f2393a 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -15,6 +15,7 @@ import ( "github.com/coder/aibridge/mcp" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/tidwall/sjson" "cdr.dev/slog" ) @@ -390,8 +391,7 @@ newStream: } // Overwrite response identifier since proxy obscures injected tool call invocations. - event.Message.ID = i.ID().String() - payload, err := i.marshal(event) + payload, err := i.marshalEvent(event) if err != nil { logger.Warn(ctx, "failed to marshal event", slog.Error(err), slog.F("event", event.RawJSON())) lastErr = fmt.Errorf("marshal event: %w", err) @@ -474,6 +474,20 @@ newStream: return interceptionErr } +func (s *AnthropicMessagesStreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) { + sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String()) + if err != nil { + return nil, fmt.Errorf("marshal event id failed: %w", err) + } + + sj, err = sjson.Set(sj, "usage.output_tokens", event.Usage.OutputTokens) + if err != nil { + return nil, fmt.Errorf("marshal event usage failed: %w", err) + } + + return s.encodeForStream([]byte(sj), event.Type), nil +} + func (s *AnthropicMessagesStreamingInterception) marshal(payload any) ([]byte, error) { data, err := json.Marshal(payload) if err != nil {