@@ -15,6 +15,7 @@ import (
1515 "github.com/coder/aibridge/mcp"
1616 "github.com/google/uuid"
1717 mcplib "github.com/mark3labs/mcp-go/mcp"
18+ "github.com/tidwall/sjson"
1819
1920 "cdr.dev/slog"
2021)
@@ -390,8 +391,7 @@ newStream:
390391 }
391392
392393 // Overwrite response identifier since proxy obscures injected tool call invocations.
393- event .Message .ID = i .ID ().String ()
394- payload , err := i .marshal (event )
394+ payload , err := i .marshalEvent (event )
395395 if err != nil {
396396 logger .Warn (ctx , "failed to marshal event" , slog .Error (err ), slog .F ("event" , event .RawJSON ()))
397397 lastErr = fmt .Errorf ("marshal event: %w" , err )
@@ -474,6 +474,20 @@ newStream:
474474 return interceptionErr
475475}
476476
477+ func (s * AnthropicMessagesStreamingInterception ) marshalEvent (event anthropic.MessageStreamEventUnion ) ([]byte , error ) {
478+ sj , err := sjson .Set (event .RawJSON (), "message.id" , s .ID ().String ())
479+ if err != nil {
480+ return nil , fmt .Errorf ("marshal event id failed: %w" , err )
481+ }
482+
483+ sj , err = sjson .Set (sj , "usage.output_tokens" , event .Usage .OutputTokens )
484+ if err != nil {
485+ return nil , fmt .Errorf ("marshal event usage failed: %w" , err )
486+ }
487+
488+ return s .encodeForStream ([]byte (sj ), event .Type ), nil
489+ }
490+
477491func (s * AnthropicMessagesStreamingInterception ) marshal (payload any ) ([]byte , error ) {
478492 data , err := json .Marshal (payload )
479493 if err != nil {
0 commit comments