Skip to content

Commit d62a713

Browse files
authored
fix: do not use sdk marshaling; zero values suck (#70)
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 86ddf9f commit d62a713

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

bridge_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ func TestSimple(t *testing.T) {
605605
// multiple messages in response to a single request.
606606
id, err := tc.getResponseIDFunc(streaming, resp)
607607
require.NoError(t, err, "failed to retrieve response ID")
608-
require.Nil(t, uuid.Validate(id), "id is not a UUID")
608+
require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id)
609609

610610
require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1)
611611
require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID)

intercept_anthropic_messages_streaming.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/aibridge/mcp"
1616
"github.com/google/uuid"
1717
mcplib "github.com/mark3labs/mcp-go/mcp"
18+
"github.com/tidwall/sjson"
1819

1920
"cdr.dev/slog"
2021
)
@@ -390,8 +391,7 @@ newStream:
390391
}
391392

392393
// Overwrite response identifier since proxy obscures injected tool call invocations.
393-
event.Message.ID = i.ID().String()
394-
payload, err := i.marshal(event)
394+
payload, err := i.marshalEvent(event)
395395
if err != nil {
396396
logger.Warn(ctx, "failed to marshal event", slog.Error(err), slog.F("event", event.RawJSON()))
397397
lastErr = fmt.Errorf("marshal event: %w", err)
@@ -474,6 +474,20 @@ newStream:
474474
return interceptionErr
475475
}
476476

477+
func (s *AnthropicMessagesStreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
478+
sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String())
479+
if err != nil {
480+
return nil, fmt.Errorf("marshal event id failed: %w", err)
481+
}
482+
483+
sj, err = sjson.Set(sj, "usage.output_tokens", event.Usage.OutputTokens)
484+
if err != nil {
485+
return nil, fmt.Errorf("marshal event usage failed: %w", err)
486+
}
487+
488+
return s.encodeForStream([]byte(sj), event.Type), nil
489+
}
490+
477491
func (s *AnthropicMessagesStreamingInterception) marshal(payload any) ([]byte, error) {
478492
data, err := json.Marshal(payload)
479493
if err != nil {

0 commit comments

Comments
 (0)