From 5f8e36ac0685a84d22c6ba7b273dabff12710921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 19 Nov 2025 17:31:20 +0000 Subject: [PATCH 01/16] feat: add interception tracing --- aibtrace/aibtrace.go | 44 ++ bridge.go | 7 +- bridge_integration_test.go | 87 ++-- go.mod | 14 +- go.sum | 23 +- intercept_anthropic_messages_base.go | 15 + intercept_anthropic_messages_blocking.go | 31 +- intercept_anthropic_messages_streaming.go | 29 +- intercept_openai_chat_base.go | 23 +- intercept_openai_chat_blocking.go | 28 +- intercept_openai_chat_streaming.go | 28 +- interception.go | 48 +- mcp/tool.go | 16 +- metrics_integration_test.go | 25 +- passthrough.go | 13 +- provider.go | 4 +- provider_anthropic.go | 21 +- provider_openai.go | 20 +- recorder.go | 65 ++- trace_integration_test.go | 514 ++++++++++++++++++++++ 20 files changed, 914 insertions(+), 141 deletions(-) create mode 100644 aibtrace/aibtrace.go create mode 100644 trace_integration_test.go diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go new file mode 100644 index 0000000..6d2cde0 --- /dev/null +++ b/aibtrace/aibtrace.go @@ -0,0 +1,44 @@ +package aibtrace + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +type traceInterceptionAttrsContextKey struct{} + +const ( + // trace attribute key constants + InterceptionID = "interception_id" + UserID = "user_id" + Provider = "provider" + Model = "model" + Streaming = "streaming" + IsBedrock = "aws_bedrock" + MCPToolName = "mcp_tool_name" + PassthroughURL = "passthrough_url" + PassthroughMethod = "passthrough_method" +) + +func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { + return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs) +} + +func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue { + attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue) + if !ok { + return nil + } + + return attrs +} + +func EndSpanErr(span trace.Span, err *error) { + if err != nil && *err != nil { + span.SetStatus(codes.Error, (*err).Error()) + } + span.End() +} diff --git a/bridge.go b/bridge.go index 4f5428d..0f723ea 100644 --- a/bridge.go +++ b/bridge.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" + "go.opentelemetry.io/otel/trace" "github.com/hashicorp/go-multierror" ) @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. -func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) { +func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics)) + mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. // // We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be // configured, so we should just reverse-proxy known-safe routes. - ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics) + ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer) for _, path := range provider.PassthroughRoutes() { prefix := fmt.Sprintf("/%s", provider.Name()) route := fmt.Sprintf("%s%s", prefix, path) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index f8ef57e..2ea570a 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -20,23 +20,22 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" - "go.uber.org/goleak" - "golang.org/x/tools/txtar" - "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/openai/openai-go/v2" + oaissestream "github.com/openai/openai-go/v2/packages/ssestream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - - "github.com/openai/openai-go/v2" - oaissestream "github.com/openai/openai-go/v2/packages/ssestream" - - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + "go.uber.org/goleak" + "golang.org/x/tools/txtar" ) var ( @@ -65,6 +64,8 @@ var ( oaiMidStreamErr []byte //go:embed fixtures/openai/non_stream_error.txtar oaiNonStreamErr []byte + + defaultTracer = otel.Tracer("github.com/coder/aibridge") ) const ( @@ -90,8 +91,9 @@ func TestAnthropicMessages(t *testing.T) { t.Parallel() cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int + streaming bool + expectedInputTokens int + expectedOutputTokens int }{ { streaming: true, @@ -133,7 +135,8 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -214,7 +217,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + }, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -312,7 +315,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -399,7 +402,8 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -466,7 +470,8 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -504,7 +509,8 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -618,17 +624,8 @@ func TestSimple(t *testing.T) { } func setJSON(in []byte, key string, val bool) ([]byte, error) { - var body map[string]any - err := json.Unmarshal(in, &body) - if err != nil { - return nil, err - } - body[key] = val - out, err := json.Marshal(body) - if err != nil { - return nil, err - } - return out, nil + out, err := sjson.Set(string(in), key, val) + return []byte(out), err } func TestFallthrough(t *testing.T) { @@ -645,7 +642,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -656,7 +653,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -762,7 +759,8 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -843,7 +841,8 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -1029,7 +1028,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1046,7 +1046,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1134,7 +1135,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1152,7 +1154,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1238,7 +1241,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: antSimple, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, }, { @@ -1246,7 +1250,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: oaiSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) }, }, } @@ -1352,7 +1357,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1365,7 +1371,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/go.mod b/go.mod index 47fd45d..cfe3a4a 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,13 @@ require ( github.com/openai/openai-go/v2 v2.7.0 ) +require ( + github.com/google/go-cmp v0.7.0 + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/sdk v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 +) + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect @@ -46,6 +53,8 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/lipgloss v0.7.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect @@ -61,14 +70,13 @@ require ( github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.opentelemetry.io/otel v1.33.0 // indirect - go.opentelemetry.io/otel/trace v1.33.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.34.0 // indirect diff --git a/go.sum b/go.sum index d0b79c8..385345d 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -130,14 +131,16 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= -go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= -go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= -go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= -go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE= -go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= -go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= -go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 5049e54..0ee12ea 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -14,8 +14,11 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -27,6 +30,7 @@ type AnthropicMessagesInterceptionBase struct { cfg AnthropicConfig bedrockCfg *AWSBedrockConfig + tracer trace.Tracer logger slog.Logger recorder Recorder @@ -59,6 +63,17 @@ func (i *AnthropicMessagesInterceptionBase) Model() string { return string(i.req.Model) } +func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(aibtrace.Provider, ProviderAnthropic), + attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.Model, s.Model()), + attribute.String(aibtrace.UserID, actorFromContext(ctx).id), + attribute.Bool(aibtrace.Streaming, streaming), + attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil), + } +} + func (i *AnthropicMessagesInterceptionBase) injectTools() { if i.req == nil || i.mcpProxy == nil { return diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index a1f71e6..0773a13 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "encoding/json" "fmt" "net/http" @@ -10,7 +11,10 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" // TODO: abstract this away so callers need no knowledge of underlying lib. + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "cdr.dev/slog" @@ -22,29 +26,35 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } -func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { - s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { + i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +} + +func (i *AnthropicMessagesBlockingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, false) } func (s *AnthropicMessagesBlockingInterception) Streaming() bool { return false } -func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) i.injectTools() @@ -77,7 +87,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = svc.New(ctx, messages) + resp, err = i.traceNewMessage(ctx, svc, messages) // traces client.Messages.New(ctx, msgParams) call if err != nil { if isConnError(err) { // Can't write a response, just error out. @@ -166,7 +176,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr continue } - res, err := tool.Call(ctx, tc.Input) + res, err := tool.Call(ctx, i.tracer, tc.Input) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -285,3 +295,10 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return nil } + +func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + + return svc.New(ctx, msgParams) +} diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index ef8aabd..139f7c2 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -11,10 +11,14 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -25,12 +29,13 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } @@ -42,6 +47,10 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { return true } +func (s *AnthropicMessagesStreamingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, true) +} + // ProcessRequest handles a request to /v1/messages. // This API has a state-machine behind it, which is described in https://docs.claude.com/en/docs/build-with-claude/streaming#event-types. // @@ -61,13 +70,16 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. -func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. @@ -129,7 +141,7 @@ newStream: pendingToolCalls := make(map[string]string) - for stream.Next() { + for i.traceStreamNext(ctx, stream) { // traces stream.Next() call event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -269,7 +281,7 @@ newStream: continue } - res, err := tool.Call(streamCtx, input) + res, err := tool.Call(streamCtx, i.tracer, input) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -508,3 +520,10 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, buf.WriteString("\n\n") return buf.Bytes() } + +func (s *AnthropicMessagesStreamingInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) bool { + _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return stream.Next() +} diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 20db323..f81b245 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -6,21 +6,26 @@ import ( "net/http" "strings" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" "github.com/openai/openai-go/v2/shared" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) type OpenAIChatInterceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper + id uuid.UUID + req *ChatCompletionNewParamsWrapper + baseURL string + key string - baseURL, key string - logger slog.Logger + tracer trace.Tracer + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier @@ -42,6 +47,16 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder i.mcpProxy = mcpProxy } +func (s *OpenAIChatInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(aibtrace.Provider, ProviderOpenAI), + attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.Model, s.Model()), + attribute.String(aibtrace.UserID, actorFromContext(ctx).id), + attribute.Bool(aibtrace.Streaming, streaming), + } +} + func (i *OpenAIChatInterceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 757c933..f4d1db3 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -2,16 +2,20 @@ package aibridge import ( "bytes" + "context" "encoding/json" "fmt" "net/http" "strings" "time" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -22,12 +26,13 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -39,12 +44,18 @@ func (s *OpenAIBlockingChatInterception) Streaming() bool { return false } -func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (s *OpenAIBlockingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, false) +} + +func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) @@ -65,7 +76,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...) + completion, err = i.traceChatCompletionsNew(ctx, svc, opts) // traces svc.New call if err != nil { break } @@ -145,7 +156,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r ) _ = json.NewEncoder(&buf).Encode(tc.Function.Arguments) _ = json.NewDecoder(&buf).Decode(&args) - res, err := tool.Call(ctx, args) + res, err := tool.Call(ctx, i.tracer, args) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -227,3 +238,10 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return nil } + +func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, client openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + + return client.New(ctx, i.req.ChatCompletionNewParams, opts...) +} diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index ccabb35..f1ec165 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -10,11 +10,14 @@ import ( "strings" "time" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/ssestream" "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -25,12 +28,13 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -42,6 +46,10 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { return true } +func (s *OpenAIStreamingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, true) +} + // ProcessRequest handles a request to /v1/chat/completions. // See https://platform.openai.com/docs/api-reference/chat-streaming/streaming. // @@ -54,18 +62,21 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. -func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + defer aibtrace.EndSpanErr(span, &outErr) + // Include token usage. i.req.StreamOptions.IncludeUsage = openai.Bool(true) i.injectTools() // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. @@ -109,7 +120,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, var toolCall *openai.FinishedChatCompletionToolCall - for stream.Next() { + for i.traceStreamNext(ctx, stream) { // traces stream.Next() call chunk := stream.Current() canRelay := processor.process(chunk) @@ -230,7 +241,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, id := toolCall.ID args := i.unmarshalArgs(toolCall.Arguments) - toolRes, toolErr := tool.Call(streamCtx, args) + toolRes, toolErr := tool.Call(streamCtx, i.tracer, args) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -336,6 +347,13 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte return buf.Bytes() } +func (i *OpenAIStreamingChatInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[openai.ChatCompletionChunk]) bool { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return stream.Next() +} + type openAIStreamProcessor struct { ctx context.Context logger slog.Logger diff --git a/interception.go b/interception.go index 8210c41..2f8ba28 100644 --- a/interception.go +++ b/interception.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "errors" "fmt" "net/http" @@ -8,8 +9,12 @@ import ( "time" "cdr.dev/slog" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) // Interceptor describes a (potentially) stateful interaction with an AI provider. @@ -25,6 +30,9 @@ type Interceptor interface { ProcessRequest(w http.ResponseWriter, r *http.Request) error // Specifies whether an interceptor handles streaming or not. Streaming() bool + + // TraceAttributes returns tacing attributes for this [Inteceptor] + TraceAttributes(context.Context) []attribute.KeyValue } var UnknownRoute = errors.New("unknown route") @@ -34,11 +42,15 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics) http.HandlerFunc { +func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - interceptor, err := p.CreateInterceptor(w, r) + ctx, span := tracer.Start(r.Context(), "Intercept") + defer span.End() + + interceptor, err := p.CreateInterceptor(tracer, w, r.WithContext(ctx)) if err != nil { - logger.Warn(r.Context(), "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) + logger.Warn(ctx, "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) http.Error(w, fmt.Sprintf("failed to create %q interceptor", r.URL.Path), http.StatusInternalServerError) return } @@ -50,13 +62,18 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, }() } - actor := actorFromContext(r.Context()) + actor := actorFromContext(ctx) if actor == nil { - logger.Warn(r.Context(), "no actor found in context") + logger.Warn(ctx, "no actor found in context") http.Error(w, "no actor found", http.StatusBadRequest) return } + traceAttrs := interceptor.TraceAttributes(ctx) + span.SetAttributes(traceAttrs...) + ctx = aibtrace.WithTraceInterceptionAttributesInContext(ctx, traceAttrs) + r = r.WithContext(ctx) + // Record usage in the background to not block request flow. asyncRecorder := NewAsyncRecorder(logger, recorder, recordingTimeout) asyncRecorder.WithMetrics(metrics) @@ -65,14 +82,15 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, asyncRecorder.WithInitiatorID(actor.id) interceptor.Setup(logger, asyncRecorder, mcpProxy) - if err := recorder.RecordInterception(r.Context(), &InterceptionRecord{ + if err := recorder.RecordInterception(ctx, &InterceptionRecord{ ID: interceptor.ID().String(), Metadata: actor.metadata, InitiatorID: actor.id, Provider: p.Name(), Model: interceptor.Model(), }); err != nil { - logger.Warn(r.Context(), "failed to record interception", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) + logger.Warn(ctx, "failed to record interception", slog.Error(err)) http.Error(w, "failed to record interception", http.StatusInternalServerError) return } @@ -86,27 +104,27 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, slog.F("streaming", interceptor.Streaming()), ) - log.Debug(r.Context(), "interception started") + log.Debug(ctx, "interception started") if metrics != nil { metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) + defer func() { + metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + }() } if err := interceptor.ProcessRequest(w, r); err != nil { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusFailed, route, r.Method, actor.id).Add(1) } - log.Warn(r.Context(), "interception failed", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) + log.Warn(ctx, "interception failed", slog.Error(err)) } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } - log.Debug(r.Context(), "interception ended") - } - asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()}) - - if metrics != nil { - metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + log.Debug(ctx, "interception ended") } + asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/mcp/tool.go b/mcp/tool.go index 2c01535..23dffc4 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -7,7 +7,10 @@ import ( "strings" "cdr.dev/slog" + "github.com/coder/aibridge/aibtrace" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) const ( @@ -34,14 +37,21 @@ type Tool struct { Required []string } -func (t *Tool) Call(ctx context.Context, input any) (*mcp.CallToolResult, error) { +func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp.CallToolResult, outErr error) { if t == nil { - return nil, errors.New("nil tool!") + return nil, errors.New("nil tool") } if t.Client == nil { - return nil, errors.New("nil client!") + return nil, errors.New("nil client") } + spanAttrs := append( + aibtrace.TraceInterceptionAttributesFromContext(ctx), + attribute.String(aibtrace.MCPToolName, t.Name), + ) + ctx, span := tracer.Start(ctx, "Intercept.RecordInterception.ToolCall", trace.WithAttributes(spanAttrs...)) + defer aibtrace.EndSpanErr(span, &outErr) + return t.Client.CallTool(ctx, mcp.CallToolRequest{ Params: mcp.CallToolParams{ Name: t.Name, diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 3696de2..1b17c6f 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" "golang.org/x/tools/txtar" ) @@ -48,7 +49,7 @@ func TestMetrics_Interception(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -89,7 +90,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv := newTestSrv(t, ctx, provider, metrics) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) // Make request in background. doneCh := make(chan struct{}) @@ -141,7 +142,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) - srv := newTestSrv(t, t.Context(), provider, metrics) + srv, _ := newTestSrv(t, t.Context(), provider, metrics, defaultTracer) req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) require.NoError(t, err) @@ -170,7 +171,7 @@ func TestMetrics_PromptCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -198,7 +199,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -240,7 +241,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { mcpMgr := mcp.NewServerProxyManager(tools) require.NoError(t, mcpMgr.Init(ctx)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, defaultTracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -272,13 +273,17 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { require.Equal(t, 1.0, count) } -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics) *httptest.Server { +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics, tracer trace.Tracer) (*httptest.Server, *mockRecorderClient) { t.Helper() - recorder := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + mockRecorder := &mockRecorderClient{} + clientFn := func() (aibridge.Recorder, error) { + return mockRecorder, nil + } + wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcp.NewServerProxyManager(nil), metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil), metrics, tracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -288,5 +293,5 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m srv.Start() t.Cleanup(srv.Close) - return srv + return srv, mockRecorder } diff --git a/passthrough.go b/passthrough.go index 6788672..2c9f45b 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,19 +8,28 @@ import ( "time" "cdr.dev/slog" + "github.com/coder/aibridge/aibtrace" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [Provider]. -func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics) http.HandlerFunc { +func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if metrics != nil { metrics.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1) } + ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( + attribute.String(aibtrace.PassthroughURL, r.URL.Path), + attribute.String(aibtrace.PassthroughMethod, r.Method), + )) + defer span.End() + upURL, err := url.Parse(provider.BaseURL()) if err != nil { - logger.Warn(r.Context(), "failed to parse provider base URL", slog.Error(err)) + logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) http.Error(w, "request error", http.StatusBadGateway) return } diff --git a/provider.go b/provider.go index 3787af6..78bd771 100644 --- a/provider.go +++ b/provider.go @@ -2,6 +2,8 @@ package aibridge import ( "net/http" + + "go.opentelemetry.io/otel/trace" ) // Provider describes an AI provider client's behaviour. @@ -14,7 +16,7 @@ type Provider interface { // CreateInterceptor starts a new [Interceptor] which is responsible for intercepting requests, // communicating with the upstream provider and formulating a response to be sent to the requesting client. - CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) + CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (Interceptor, error) // BridgedRoutes returns a slice of [http.ServeMux]-compatible routes which will have special handling. // See https://pkg.go.dev/net/http#hdr-Patterns-ServeMux. diff --git a/provider_anthropic.go b/provider_anthropic.go index 7e9c99f..009c30f 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -11,7 +11,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared" "github.com/anthropics/anthropic-sdk-go/shared/constant" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &AnthropicProvider{} @@ -58,14 +61,16 @@ func (p *AnthropicProvider) PassthroughRoutes() []string { } } -func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *AnthropicProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { + id := uuid.New() + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer aibtrace.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeMessages: var req MessageNewParamsWrapper @@ -73,13 +78,17 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req return nil, fmt.Errorf("failed to unmarshal request: %w", err) } + var interceptor Interceptor if req.Stream { - return NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg), nil + interceptor = NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) + } else { + interceptor = NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) } - - return NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg), nil + span.SetAttributes(interceptor.TraceAttributes(r.Context())...) + return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/provider_openai.go b/provider_openai.go index 0fc31a6..7cc154a 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,7 +7,10 @@ import ( "net/http" "os" + aibtrace "github.com/coder/aibridge/aibtrace" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &OpenAIProvider{} @@ -58,14 +61,17 @@ func (p *OpenAIProvider) PassthroughRoutes() []string { } } -func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *OpenAIProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { + id := uuid.New() + + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer aibtrace.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeChatCompletions: var req ChatCompletionNewParamsWrapper @@ -73,13 +79,17 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques return nil, fmt.Errorf("unmarshal request body: %w", err) } + var interceptor Interceptor if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key, tracer) } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key, tracer) } + span.SetAttributes(interceptor.TraceAttributes(r.Context())...) + return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/recorder.go b/recorder.go index edd68f7..375ef23 100644 --- a/recorder.go +++ b/recorder.go @@ -7,18 +7,28 @@ import ( "time" "cdr.dev/slog" + + aibtrace "github.com/coder/aibridge/aibtrace" + "go.opentelemetry.io/otel/trace" ) -var _ Recorder = &RecorderWrapper{} +var ( + _ Recorder = &RecorderWrapper{} + _ Recorder = &AsyncRecorder{} +) // RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method. // It also sets the start/creation time of each record. type RecorderWrapper struct { logger slog.Logger + tracer trace.Tracer clientFn func() (Recorder, error) } -func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) error { +func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -33,7 +43,10 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } -func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { +func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -48,7 +61,10 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte return err } -func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { +func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -63,7 +79,10 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag return err } -func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { +func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -78,7 +97,10 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR return err } -func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { +func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + defer aibtrace.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -93,12 +115,14 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec return err } -func NewRecorder(logger slog.Logger, clientFn func() (Recorder, error)) *RecorderWrapper { - return &RecorderWrapper{logger: logger, clientFn: clientFn} +func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper { + return &RecorderWrapper{ + logger: logger, + tracer: tracer, + clientFn: clientFn, + } } -var _ Recorder = &AsyncRecorder{} - // AsyncRecorder calls [Recorder] methods asynchronously and logs any errors which may occur. type AsyncRecorder struct { logger slog.Logger @@ -141,7 +165,7 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordInterceptionEnded(timedCtx, req) @@ -153,11 +177,11 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error { +func (a *AsyncRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordPromptUsage(timedCtx, req) @@ -173,11 +197,11 @@ func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRec return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecord) error { +func (a *AsyncRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordTokenUsage(timedCtx, req) @@ -197,11 +221,11 @@ func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecor return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) error { +func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := a.timedContext(ctx) defer cancel() err := a.wrapped.RecordToolUsage(timedCtx, req) @@ -228,3 +252,10 @@ func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) func (a *AsyncRecorder) Wait() { a.wg.Wait() } + +// returns detrached context with tracing information copied from provided context +func (a *AsyncRecorder) timedContext(ctx context.Context) (context.Context, context.CancelFunc) { + timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx = aibtrace.WithTraceInterceptionAttributesInContext(timedCtx, aibtrace.TraceInterceptionAttributesFromContext(ctx)) + return timedCtx, cancel +} diff --git a/trace_integration_test.go b/trace_integration_test.go new file mode 100644 index 0000000..16b1744 --- /dev/null +++ b/trace_integration_test.go @@ -0,0 +1,514 @@ +package aibridge_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/aibtrace" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "golang.org/x/tools/txtar" +) + +// expect 'count' amount of traces named 'name' with status 'status' +type expectTrace struct { + name string + count int + status codes.Code +} + +func TestTraceAnthropic(t *testing.T) { + expectNonStreaming := []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + + cases := []struct { + name string + streaming bool + bedrock bool + expectTraceCounts []expectTrace + }{ + { + name: "trace_anthr_non_streaming", + expectTraceCounts: expectNonStreaming, + }, + { + name: "trace_bedrock_non_streaming", + bedrock: true, + expectTraceCounts: expectNonStreaming, + }, + { + name: "trace_anthr_streaming", + streaming: true, + expectTraceCounts: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 2, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 9, codes.Unset}, + }, + }, + { + name: "trace_bedrock_streaming", + streaming: true, + bedrock: true, + expectTraceCounts: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + } + + arc := txtar.Parse(antSingleBuiltinTool) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := files[fixtureRequest] + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = &aibridge.AWSBedrockConfig{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + EndpointOverride: mockAPI.URL, + } + } + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), + attribute.String(aibtrace.Model, model), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + } + + verifyCommonTraceAttrs(t, sr, tc.expectTraceCounts, attrs) + }) + } +} + +func TestTraceAnthropicErr(t *testing.T) { + cases := []struct { + name string + streaming bool + expect []expectTrace + }{ + { + name: "trace_anthr_non_streaming_err", + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + }, + }, + { + name: "trace_anthr_streaming_err", + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 3, codes.Unset}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(antMidStreamErr) + } else { + arc = txtar.Parse(antNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else { + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.Bool(aibtrace.IsBedrock, false), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceOpenAI(t *testing.T) { + cases := []struct { + name string + fixture []byte + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming", + fixture: oaiSimple, + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 242, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming", + fixture: oaiSimple, + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + arc := txtar.Parse(tc.fixture) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceOpenAIErr(t *testing.T) { + cases := []struct { + name string + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming_err", + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 5, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming_err", + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + }, + }, + } + + for _, tc := range cases { + t.Run(t.Name(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(oaiMidStreamErr) + } else { + arc = txtar.Parse(oaiNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else { + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), + attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(aibtrace.UserID, userID), + attribute.Bool(aibtrace.Streaming, tc.streaming), + } + + verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + }) + } +} + +func TestTracePassthrough(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiFallthrough) + files := filesMap(arc) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureResponse]) + })) + t.Cleanup(upstream.Close) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + srv, _ := newTestSrv(t, t.Context(), provider, nil, tracer) + + req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + srv.Close() + + spans := sr.Ended() + require.Equal(t, len(spans), 1) + assert.Equal(t, spans[0].Name(), "Passthrough") + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.PassthroughURL, "/v1/models"), + attribute.String(aibtrace.PassthroughMethod, "GET"), + } + if attrDiff := cmp.Diff(spans[0].Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { + t.Errorf("unexpectet attrs diff: %s", attrDiff) + } +} + +func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) bool { + return a.Key < b.Key +} + +func verifyCommonTraceAttrs(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []expectTrace, attrs []attribute.KeyValue) { + spans := spanRecorder.Ended() + + totalCount := 0 + for _, e := range expect { + totalCount += e.count + } + assert.Equal(t, totalCount, len(spans)) + + for _, e := range expect { + found := 0 + for _, s := range spans { + if s.Name() != e.name || s.Status().Code != e.status { + continue + } + found++ + if attrDiff := cmp.Diff(s.Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { + t.Errorf("unexpectet attrs for span named: %v, diff: %s", e.name, attrDiff) + } + assert.Equalf(t, e.status, s.Status().Code, "unexpected status for trace naned: %v got: %v want: %v", e.name, s.Status().Code, e.status) + } + if found != e.count { + t.Errorf("found unexpected number of spans named: %v with status %v, got: %v want: %v", e.name, e.status, found, e.count) + } + } +} From f6086193ceff07f50c22356ca16122cf84e3fd7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Tue, 25 Nov 2025 15:48:11 +0000 Subject: [PATCH 02/16] EndSpanErr span nil check + typo fix --- aibtrace/aibtrace.go | 4 ++++ interception.go | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index 6d2cde0..9e45c0c 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -37,6 +37,10 @@ func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.Key } func EndSpanErr(span trace.Span, err *error) { + if span == nil { + return + } + if err != nil && *err != nil { span.SetStatus(codes.Error, (*err).Error()) } diff --git a/interception.go b/interception.go index 2f8ba28..7af954d 100644 --- a/interception.go +++ b/interception.go @@ -30,8 +30,7 @@ type Interceptor interface { ProcessRequest(w http.ResponseWriter, r *http.Request) error // Specifies whether an interceptor handles streaming or not. Streaming() bool - - // TraceAttributes returns tacing attributes for this [Inteceptor] + // TraceAttributes returns tracing attributes for this [Interceptor] TraceAttributes(context.Context) []attribute.KeyValue } From 5fe4137b18009ebaa965f3a82ee940f5d38d415c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 26 Nov 2025 17:51:16 +0000 Subject: [PATCH 03/16] fix context used by async recorder --- recorder.go | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/recorder.go b/recorder.go index 375ef23..5bbcec0 100644 --- a/recorder.go +++ b/recorder.go @@ -165,7 +165,7 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := a.timedContext(ctx) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordInterceptionEnded(timedCtx, req) @@ -181,7 +181,7 @@ func (a *AsyncRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageR a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := a.timedContext(ctx) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordPromptUsage(timedCtx, req) @@ -201,7 +201,7 @@ func (a *AsyncRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRec a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := a.timedContext(ctx) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordTokenUsage(timedCtx, req) @@ -225,7 +225,7 @@ func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecor a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := a.timedContext(ctx) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordToolUsage(timedCtx, req) @@ -252,10 +252,3 @@ func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecor func (a *AsyncRecorder) Wait() { a.wg.Wait() } - -// returns detrached context with tracing information copied from provided context -func (a *AsyncRecorder) timedContext(ctx context.Context) (context.Context, context.CancelFunc) { - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) - timedCtx = aibtrace.WithTraceInterceptionAttributesInContext(timedCtx, aibtrace.TraceInterceptionAttributesFromContext(ctx)) - return timedCtx, cancel -} From 597f3c0e625a1f57d588d26270640438d064872b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 26 Nov 2025 19:19:12 +0000 Subject: [PATCH 04/16] ServerProxyManager init tracing --- aibtrace/aibtrace.go | 1 + bridge_integration_test.go | 28 ++++++------- mcp/mcp_test.go | 3 +- mcp/server_proxy_manager.go | 21 ++++++++-- metrics_integration_test.go | 4 +- trace_integration_test.go | 84 +++++++++++++++++++++++++++---------- 6 files changed, 97 insertions(+), 44 deletions(-) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index 9e45c0c..8c1828a 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -18,6 +18,7 @@ const ( Model = "model" Streaming = "streaming" IsBedrock = "aws_bedrock" + MCPProxyName = "mcp_proxy_name" MCPToolName = "mcp_tool_name" PassthroughURL = "passthrough_url" PassthroughMethod = "passthrough_method" diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 2ea570a..da39a57 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -136,7 +136,7 @@ func TestAnthropicMessages(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -217,7 +217,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + }, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -315,7 +315,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -403,7 +403,7 @@ func TestOpenAIChatCompletions(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -471,7 +471,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -510,7 +510,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -642,7 +642,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -653,7 +653,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -980,7 +980,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu tools := setupMCPServerProxiesForTest(t) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools) + mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1096,7 +1096,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1198,7 +1198,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1267,7 +1267,7 @@ func TestStableRequestEncoding(t *testing.T) { tools := setupMCPServerProxiesForTest(t) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools) + mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) @@ -1358,7 +1358,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1372,7 +1372,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4d6e2d3..c72bb79 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -13,6 +13,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "go.opentelemetry.io/otel" "go.uber.org/goleak" "github.com/coder/aibridge/mcp" @@ -324,7 +325,7 @@ func TestToolInjectionOrder(t *testing.T) { mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{ "coder": proxy, "shmoder": proxy2, - }) + }, otel.GetTracerProvider().Tracer("test")) require.NoError(t, mgr.Init(ctx)) // Then: the tools from both servers should be collectively sorted stably. diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index 732f1a0..e11cd5e 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -7,8 +7,11 @@ import ( "strings" "sync" + "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/utils" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) var _ ServerProxier = &ServerProxyManager{} @@ -18,14 +21,18 @@ var _ ServerProxier = &ServerProxyManager{} // for the purpose of injection into bridged requests and invocation. type ServerProxyManager struct { proxiers map[string]ServerProxier + tracer trace.Tracer // Protects access to the tools map. toolsMu sync.RWMutex tools map[string]*Tool } -func NewServerProxyManager(proxiers map[string]ServerProxier) *ServerProxyManager { - return &ServerProxyManager{proxiers: proxiers} +func NewServerProxyManager(proxiers map[string]ServerProxier, tracer trace.Tracer) *ServerProxyManager { + return &ServerProxyManager{ + proxiers: proxiers, + tracer: tracer, + } } func (s *ServerProxyManager) addTools(tools []*Tool) { @@ -42,10 +49,16 @@ func (s *ServerProxyManager) addTools(tools []*Tool) { } // Init concurrently initializes all of its [ServerProxier]s. -func (s *ServerProxyManager) Init(ctx context.Context) error { +func (s *ServerProxyManager) Init(ctx context.Context) (outErr error) { + ctx, span := s.tracer.Start(ctx, "ServerProxyManager.Init") + defer aibtrace.EndSpanErr(span, &outErr) + cg := utils.NewConcurrentGroup() - for _, proxy := range s.proxiers { + for name, proxy := range s.proxiers { cg.Go(func() error { + ctx, span := s.tracer.Start(ctx, "ServerProxyManager.Init.Proxy", trace.WithAttributes(attribute.String(aibtrace.MCPProxyName, name))) + defer aibtrace.EndSpanErr(span, &outErr) + return proxy.Init(ctx) }) } diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 1b17c6f..95bdd08 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -238,7 +238,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // Setup mocked MCP server & tools. tools := setupMCPServerProxiesForTest(t) - mcpMgr := mcp.NewServerProxyManager(tools) + mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) require.NoError(t, mcpMgr.Init(ctx)) bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, defaultTracer, logger) @@ -283,7 +283,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m } wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil), metrics, tracer, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, defaultTracer), metrics, tracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) diff --git a/trace_integration_test.go b/trace_integration_test.go index 16b1744..2b0a159 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -10,6 +10,7 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/mcp" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" @@ -47,15 +48,18 @@ func TestTraceAnthropic(t *testing.T) { streaming bool bedrock bool expectTraceCounts []expectTrace + expectSpanCount int }{ { name: "trace_anthr_non_streaming", expectTraceCounts: expectNonStreaming, + expectSpanCount: 9, }, { name: "trace_bedrock_non_streaming", bedrock: true, expectTraceCounts: expectNonStreaming, + expectSpanCount: 9, }, { name: "trace_anthr_streaming", @@ -71,6 +75,7 @@ func TestTraceAnthropic(t *testing.T) { {"Intercept.RecordToolUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 9, codes.Unset}, }, + expectSpanCount: 18, }, { name: "trace_bedrock_streaming", @@ -85,6 +90,7 @@ func TestTraceAnthropic(t *testing.T) { {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, + expectSpanCount: 7, }, } @@ -151,16 +157,18 @@ func TestTraceAnthropic(t *testing.T) { attribute.Bool(aibtrace.IsBedrock, tc.bedrock), } - verifyCommonTraceAttrs(t, sr, tc.expectTraceCounts, attrs) + require.Len(t, sr.Ended(), tc.expectSpanCount) + verifyTraces(t, sr, tc.expectTraceCounts, attrs) }) } } func TestTraceAnthropicErr(t *testing.T) { cases := []struct { - name string - streaming bool - expect []expectTrace + name string + streaming bool + expect []expectTrace + expectSpanCount int }{ { name: "trace_anthr_non_streaming_err", @@ -172,6 +180,7 @@ func TestTraceAnthropicErr(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, }, + expectSpanCount: 6, }, { name: "trace_anthr_streaming_err", @@ -186,6 +195,7 @@ func TestTraceAnthropicErr(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 3, codes.Unset}, }, + expectSpanCount: 10, }, } @@ -249,17 +259,19 @@ func TestTraceAnthropicErr(t *testing.T) { attribute.Bool(aibtrace.IsBedrock, false), } - verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + require.Len(t, sr.Ended(), tc.expectSpanCount) + verifyTraces(t, sr, tc.expect, attrs) }) } } func TestTraceOpenAI(t *testing.T) { cases := []struct { - name string - fixture []byte - streaming bool - expect []expectTrace + name string + fixture []byte + streaming bool + expect []expectTrace + expectSpanCount int }{ { name: "trace_openai_streaming", @@ -275,6 +287,7 @@ func TestTraceOpenAI(t *testing.T) { {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 242, codes.Unset}, }, + expectSpanCount: 249, }, { name: "trace_openai_non_streaming", @@ -290,6 +303,7 @@ func TestTraceOpenAI(t *testing.T) { {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, + expectSpanCount: 8, }, } @@ -339,16 +353,18 @@ func TestTraceOpenAI(t *testing.T) { attribute.Bool(aibtrace.Streaming, tc.streaming), } - verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + require.Len(t, sr.Ended(), tc.expectSpanCount) + verifyTraces(t, sr, tc.expect, attrs) }) } } func TestTraceOpenAIErr(t *testing.T) { cases := []struct { - name string - streaming bool - expect []expectTrace + name string + streaming bool + expect []expectTrace + expectSpanCount int }{ { name: "trace_openai_streaming_err", @@ -362,6 +378,7 @@ func TestTraceOpenAIErr(t *testing.T) { {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 5, codes.Unset}, }, + expectSpanCount: 11, }, { name: "trace_openai_non_streaming_err", @@ -374,6 +391,7 @@ func TestTraceOpenAIErr(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, }, + expectSpanCount: 6, }, } @@ -435,7 +453,8 @@ func TestTraceOpenAIErr(t *testing.T) { attribute.Bool(aibtrace.Streaming, tc.streaming), } - verifyCommonTraceAttrs(t, sr, tc.expect, attrs) + require.Len(t, sr.Ended(), tc.expectSpanCount) + verifyTraces(t, sr, tc.expect, attrs) }) } } @@ -471,7 +490,8 @@ func TestTracePassthrough(t *testing.T) { srv.Close() spans := sr.Ended() - require.Equal(t, len(spans), 1) + require.Len(t, spans, 1) + assert.Equal(t, spans[0].Name(), "Passthrough") attrs := []attribute.KeyValue{ attribute.String(aibtrace.PassthroughURL, "/v1/models"), @@ -482,19 +502,37 @@ func TestTracePassthrough(t *testing.T) { } } +func TestNewServerProxyManagerTraces(t *testing.T) { + ctx := t.Context() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + tools := setupMCPServerProxiesForTest(t) + mcpMgr := mcp.NewServerProxyManager(tools, tracer) + err := mcpMgr.Init(ctx) + require.NoError(t, err) + + require.Len(t, sr.Ended(), 2) + verifyTraces(t, sr, []expectTrace{{"ServerProxyManager.Init", 1, codes.Unset}}, []attribute.KeyValue{}) + + for name := range tools { + tc := []expectTrace{{"ServerProxyManager.Init.Proxy", 1, codes.Unset}} + a := []attribute.KeyValue{attribute.String(aibtrace.MCPProxyName, name)} + verifyTraces(t, sr, tc, a) + } +} + func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) bool { return a.Key < b.Key } -func verifyCommonTraceAttrs(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []expectTrace, attrs []attribute.KeyValue) { +// checks counts of traces with given name, status and attributes +func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []expectTrace, attrs []attribute.KeyValue) { spans := spanRecorder.Ended() - totalCount := 0 - for _, e := range expect { - totalCount += e.count - } - assert.Equal(t, totalCount, len(spans)) - for _, e := range expect { found := 0 for _, s := range spans { @@ -502,7 +540,7 @@ func verifyCommonTraceAttrs(t *testing.T, spanRecorder *tracetest.SpanRecorder, continue } found++ - if attrDiff := cmp.Diff(s.Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { + if attrDiff := cmp.Diff(s.Attributes(), attrs, cmpopts.EquateEmpty(), cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { t.Errorf("unexpectet attrs for span named: %v, diff: %s", e.name, attrDiff) } assert.Equalf(t, e.status, s.Status().Code, "unexpected status for trace naned: %v got: %v want: %v", e.name, s.Status().Code, e.status) From d5b06145ea922d96c3f42e463a843146d02818c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 27 Nov 2025 21:45:52 +0000 Subject: [PATCH 05/16] review 1: streaming upstream fix, tool attrs, request path --- aibtrace/aibtrace.go | 43 ++++++-- bridge_integration_test.go | 55 +++++----- intercept_anthropic_messages_base.go | 7 +- intercept_anthropic_messages_blocking.go | 11 +- intercept_anthropic_messages_streaming.go | 17 +-- intercept_openai_chat_base.go | 7 +- intercept_openai_chat_blocking.go | 15 +-- intercept_openai_chat_streaming.go | 17 +-- interception.go | 7 +- mcp/mcp_test.go | 5 +- mcp/proxy_streamable_http.go | 26 ++++- mcp/server_proxy_manager.go | 6 +- mcp/tool.go | 21 +++- metrics_integration_test.go | 18 +-- provider_anthropic.go | 2 +- provider_openai.go | 2 +- recorder.go | 10 +- trace_integration_test.go | 127 +++++++++++++--------- 18 files changed, 239 insertions(+), 157 deletions(-) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index 8c1828a..d43a61c 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -9,26 +9,36 @@ import ( ) type traceInterceptionAttrsContextKey struct{} +type traceRequestBridgeAttrsContextKey struct{} const ( // trace attribute key constants - InterceptionID = "interception_id" - UserID = "user_id" - Provider = "provider" - Model = "model" - Streaming = "streaming" - IsBedrock = "aws_bedrock" - MCPProxyName = "mcp_proxy_name" - MCPToolName = "mcp_tool_name" + RequestPath = "request_path" + + InterceptionID = "interception_id" + UserID = "user_id" + Provider = "provider" + Model = "model" + Streaming = "streaming" + IsBedrock = "aws_bedrock" + PassthroughURL = "passthrough_url" PassthroughMethod = "passthrough_method" + + MCPInput = "mcp_input" + MCPProxyName = "mcp_proxy_name" + MCPToolName = "mcp_tool_name" + MCPServerName = "mcp_server_name" + MCPServerURL = "mcp_server_url" + + APIKeyID = "api_key_id" ) -func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { +func WithInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs) } -func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue { +func InterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue { attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue) if !ok { return nil @@ -37,6 +47,19 @@ func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.Key return attrs } +func WithRequestBridgeAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { + return context.WithValue(ctx, traceRequestBridgeAttrsContextKey{}, traceAttrs) +} + +func RequestBridgeAttributesFromContext(ctx context.Context) []attribute.KeyValue { + attrs, ok := ctx.Value(traceRequestBridgeAttrsContextKey{}).([]attribute.KeyValue) + if !ok { + return nil + } + + return attrs +} + func EndSpanErr(span trace.Span, err *error) { if span == nil { return diff --git a/bridge_integration_test.go b/bridge_integration_test.go index da39a57..aa3c7c6 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -34,6 +34,7 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" "golang.org/x/tools/txtar" ) @@ -65,7 +66,7 @@ var ( //go:embed fixtures/openai/non_stream_error.txtar oaiNonStreamErr []byte - defaultTracer = otel.Tracer("github.com/coder/aibridge") + testTracer = otel.Tracer("forTesting") ) const ( @@ -136,7 +137,7 @@ func TestAnthropicMessages(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -217,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -315,7 +316,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -403,7 +404,7 @@ func TestOpenAIChatCompletions(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -471,7 +472,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -510,7 +511,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -642,7 +643,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -653,7 +654,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) require.NoError(t, err) return provider, bridge }, @@ -724,7 +725,7 @@ func TestFallthrough(t *testing.T) { } // setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier { +func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) map[string]mcp.ServerProxier { t.Helper() // Setup Coder MCP integration @@ -732,7 +733,7 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier { t.Cleanup(mcpSrv.Close) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "coder", mcpSrv.URL, nil, nil, nil) require.NoError(t, err) // Initialize MCP client, fetch tools, and inject into bridge @@ -760,7 +761,7 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -842,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) } // Build the requirements & make the assertions which are common to all providers. @@ -977,10 +978,10 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} // Setup MCP tools. - tools := setupMCPServerProxiesForTest(t) + tools := setupMCPServerProxiesForTest(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) + mcpMgr := mcp.NewServerProxyManager(tools, testTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1029,7 +1030,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1047,7 +1048,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1096,7 +1097,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1136,7 +1137,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1155,7 +1156,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1198,7 +1199,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1242,7 +1243,7 @@ func TestStableRequestEncoding(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, }, { @@ -1251,7 +1252,7 @@ func TestStableRequestEncoding(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) }, }, } @@ -1264,10 +1265,10 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - tools := setupMCPServerProxiesForTest(t) + tools := setupMCPServerProxiesForTest(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) + mcpMgr := mcp.NewServerProxyManager(tools, testTracer) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) @@ -1358,7 +1359,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1372,7 +1373,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 0ee12ea..071be96 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -63,12 +63,13 @@ func (i *AnthropicMessagesInterceptionBase) Model() string { return string(i.req.Model) } -func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { +func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ - attribute.String(aibtrace.Provider, ProviderAnthropic), + attribute.String(aibtrace.RequestPath, r.URL.Path), attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id), + attribute.String(aibtrace.Provider, ProviderAnthropic), attribute.String(aibtrace.Model, s.Model()), - attribute.String(aibtrace.UserID, actorFromContext(ctx).id), attribute.Bool(aibtrace.Streaming, streaming), attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil), } diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 0773a13..c0207b5 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -40,8 +40,8 @@ func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, record i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) } -func (i *AnthropicMessagesBlockingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { - return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, false) +func (i *AnthropicMessagesBlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, false) } func (s *AnthropicMessagesBlockingInterception) Streaming() bool { @@ -53,7 +53,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) defer aibtrace.EndSpanErr(span, &outErr) i.injectTools() @@ -87,7 +87,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = i.traceNewMessage(ctx, svc, messages) // traces client.Messages.New(ctx, msgParams) call + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) + resp, err = i.traceNewMessage(ctx, svc, messages) // traces svc.New(ctx, msgParams) call if err != nil { if isConnError(err) { // Can't write a response, just error out. @@ -297,7 +298,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr } func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { - ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) return svc.New(ctx, msgParams) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 139f7c2..c8e8769 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -47,8 +47,8 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { return true } -func (s *AnthropicMessagesStreamingInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { - return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(ctx, true) +func (s *AnthropicMessagesStreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, true) } // ProcessRequest handles a request to /v1/messages. @@ -75,7 +75,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) defer aibtrace.EndSpanErr(span, &outErr) // Allow us to interrupt watch via cancel. @@ -129,19 +129,20 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW isFirst := true newStream: for { + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) if err := streamCtx.Err(); err != nil { lastErr = fmt.Errorf("stream exit: %w", err) break } - stream := svc.NewStreaming(streamCtx, messages) + stream := i.traceNewStreaming(streamCtx, svc, messages) // traces svc.NewStreaming(streamCtx, messages) var message anthropic.Message var lastToolName string pendingToolCalls := make(map[string]string) - for i.traceStreamNext(ctx, stream) { // traces stream.Next() call + for stream.Next() { event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -521,9 +522,9 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, return buf.Bytes() } -func (s *AnthropicMessagesStreamingInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) bool { - _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) +func (s *AnthropicMessagesStreamingInterception) traceNewStreaming(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { + _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return stream.Next() + return svc.NewStreaming(ctx, messages) } diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index f81b245..8cf65cf 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -47,12 +47,13 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder i.mcpProxy = mcpProxy } -func (s *OpenAIChatInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue { +func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ - attribute.String(aibtrace.Provider, ProviderOpenAI), + attribute.String(aibtrace.RequestPath, r.URL.Path), attribute.String(aibtrace.InterceptionID, s.id.String()), + attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id), + attribute.String(aibtrace.Provider, ProviderOpenAI), attribute.String(aibtrace.Model, s.Model()), - attribute.String(aibtrace.UserID, actorFromContext(ctx).id), attribute.Bool(aibtrace.Streaming, streaming), } } diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index f4d1db3..ecd6995 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -44,8 +44,8 @@ func (s *OpenAIBlockingChatInterception) Streaming() bool { return false } -func (s *OpenAIBlockingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { - return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, false) +func (s *OpenAIBlockingChatInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(r, false) } func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { @@ -53,7 +53,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) defer aibtrace.EndSpanErr(span, &outErr) svc := i.newCompletionsService(i.baseURL, i.key) @@ -73,10 +73,11 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } for { + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = i.traceChatCompletionsNew(ctx, svc, opts) // traces svc.New call + completion, err = i.traceChatCompletionsNew(ctx, svc, opts) // traces svc.New(ctx, i.req.ChatCompletionNewParams, opts...) call if err != nil { break } @@ -239,9 +240,9 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return nil } -func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, client openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { - ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) +func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) - return client.New(ctx, i.req.ChatCompletionNewParams, opts...) + return svc.New(ctx, i.req.ChatCompletionNewParams, opts...) } diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index f1ec165..9e08c0e 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -46,8 +46,8 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { return true } -func (s *OpenAIStreamingChatInterception) TraceAttributes(ctx context.Context) []attribute.KeyValue { - return s.OpenAIChatInterceptionBase.baseTraceAttributes(ctx, true) +func (s *OpenAIStreamingChatInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(r, true) } // ProcessRequest handles a request to /v1/chat/completions. @@ -67,7 +67,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(r.Context())...)) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) defer aibtrace.EndSpanErr(span, &outErr) // Include token usage. @@ -115,12 +115,13 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, interceptionErr error ) for { - stream = svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) + stream = i.traceNewStreaming(streamCtx, svc) // traces svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) call processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName) var toolCall *openai.FinishedChatCompletionToolCall - for i.traceStreamNext(ctx, stream) { // traces stream.Next() call + for stream.Next() { chunk := stream.Current() canRelay := processor.process(chunk) @@ -347,11 +348,11 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte return buf.Bytes() } -func (i *OpenAIStreamingChatInterception) traceStreamNext(ctx context.Context, stream *ssestream.Stream[openai.ChatCompletionChunk]) bool { - _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) +func (i *OpenAIStreamingChatInterception) traceNewStreaming(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return stream.Next() + return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams) } type openAIStreamProcessor struct { diff --git a/interception.go b/interception.go index 7af954d..dc1ecfe 100644 --- a/interception.go +++ b/interception.go @@ -1,7 +1,6 @@ package aibridge import ( - "context" "errors" "fmt" "net/http" @@ -31,7 +30,7 @@ type Interceptor interface { // Specifies whether an interceptor handles streaming or not. Streaming() bool // TraceAttributes returns tracing attributes for this [Interceptor] - TraceAttributes(context.Context) []attribute.KeyValue + TraceAttributes(*http.Request) []attribute.KeyValue } var UnknownRoute = errors.New("unknown route") @@ -68,9 +67,9 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, return } - traceAttrs := interceptor.TraceAttributes(ctx) + traceAttrs := interceptor.TraceAttributes(r) span.SetAttributes(traceAttrs...) - ctx = aibtrace.WithTraceInterceptionAttributesInContext(ctx, traceAttrs) + ctx = aibtrace.WithInterceptionAttributesInContext(ctx, traceAttrs) r = r.WithContext(ctx) // Record usage in the background to not block request flow. diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index c72bb79..0ffef1d 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -307,10 +307,11 @@ func TestToolInjectionOrder(t *testing.T) { mcpSrv := httptest.NewServer(createMockMCPSrv(t)) t.Cleanup(mcpSrv.Close) + tracer := otel.Tracer("forTesting") // When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces. - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "coder", mcpSrv.URL, nil, nil, nil) require.NoError(t, err) - proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil) + proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "shmoder", mcpSrv.URL, nil, nil, nil) require.NoError(t, err) // Then: initialize both proxies. diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 5a5c092..67c2cdf 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -8,9 +8,12 @@ import ( "strings" "cdr.dev/slog" + "github.com/coder/aibridge/aibtrace" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "golang.org/x/exp/maps" ) @@ -21,12 +24,13 @@ type StreamableHTTPServerProxy struct { serverURL string client *client.Client logger slog.Logger + tracer trace.Tracer tools map[string]*Tool allowlistPattern, denylistPattern *regexp.Regexp } -func NewStreamableHTTPServerProxy(logger slog.Logger, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) { +func NewStreamableHTTPServerProxy(logger slog.Logger, tracer trace.Tracer, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) { var opts []transport.StreamableHTTPCOption if headers != nil { opts = append(opts, transport.WithHTTPHeaders(headers)) @@ -42,6 +46,7 @@ func NewStreamableHTTPServerProxy(logger slog.Logger, serverName, serverURL stri serverURL: serverURL, client: mcpClient, logger: logger, + tracer: tracer, allowlistPattern: allowlist, denylistPattern: denylist, }, nil @@ -51,7 +56,10 @@ func (p *StreamableHTTPServerProxy) Name() string { return p.serverName } -func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error { +func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { + ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init", trace.WithAttributes(p.traceAttributes()...)) + defer aibtrace.EndSpanErr(span, &outErr) + if err := p.client.Start(ctx); err != nil { return fmt.Errorf("start client: %w", err) } @@ -122,7 +130,10 @@ func (p *StreamableHTTPServerProxy) CallTool(ctx context.Context, name string, i }) } -func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (map[string]*Tool, error) { +func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[string]*Tool, outErr error) { + ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init.fetchTools", trace.WithAttributes(p.traceAttributes()...)) + defer aibtrace.EndSpanErr(span, &outErr) + tools, err := p.client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { return nil, fmt.Errorf("list MCP tools: %w", err) @@ -140,6 +151,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (map[string] Description: tool.Description, Params: tool.InputSchema.Properties, Required: tool.InputSchema.Required, + Logger: p.logger, } } return out, nil @@ -154,3 +166,11 @@ func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error { // it has an internal timeout of 5s, though. return p.client.Close() } + +func (p *StreamableHTTPServerProxy) traceAttributes() []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(aibtrace.MCPProxyName, p.Name()), + attribute.String(aibtrace.MCPServerName, p.serverName), + attribute.String(aibtrace.MCPServerURL, p.serverURL), + } +} diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index e11cd5e..64d9eff 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -10,7 +10,6 @@ import ( "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/utils" "github.com/mark3labs/mcp-go/mcp" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) @@ -54,11 +53,8 @@ func (s *ServerProxyManager) Init(ctx context.Context) (outErr error) { defer aibtrace.EndSpanErr(span, &outErr) cg := utils.NewConcurrentGroup() - for name, proxy := range s.proxiers { + for _, proxy := range s.proxiers { cg.Go(func() error { - ctx, span := s.tracer.Start(ctx, "ServerProxyManager.Init.Proxy", trace.WithAttributes(attribute.String(aibtrace.MCPProxyName, name))) - defer aibtrace.EndSpanErr(span, &outErr) - return proxy.Init(ctx) }) } diff --git a/mcp/tool.go b/mcp/tool.go index 23dffc4..a4aaa83 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -2,6 +2,7 @@ package mcp import ( "context" + "encoding/json" "errors" "regexp" "strings" @@ -14,6 +15,7 @@ import ( ) const ( + maxSpanInputAttrLen = 100 injectedToolPrefix = "bmcp" // "bridged MCP" injectedToolDelimiter = "_" ) @@ -35,6 +37,7 @@ type Tool struct { Description string Params map[string]any Required []string + Logger slog.Logger } func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp.CallToolResult, outErr error) { @@ -46,10 +49,22 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp } spanAttrs := append( - aibtrace.TraceInterceptionAttributesFromContext(ctx), - attribute.String(aibtrace.MCPToolName, t.Name), + aibtrace.InterceptionAttributesFromContext(ctx), + attribute.String(aibtrace.MCPServerName, t.ServerName), + attribute.String(aibtrace.MCPServerURL, t.ServerURL), ) - ctx, span := tracer.Start(ctx, "Intercept.RecordInterception.ToolCall", trace.WithAttributes(spanAttrs...)) + inputJson, err := json.Marshal(input) + if err != nil { + t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs: %v", err) + } else { + strJson := string(inputJson) + if len(strJson) > maxSpanInputAttrLen { + strJson = strJson[:100] + } + spanAttrs = append(spanAttrs, attribute.String(aibtrace.MCPInput, strJson)) + } + + ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...)) defer aibtrace.EndSpanErr(span, &outErr) return t.Client.CallTool(ctx, mcp.CallToolRequest{ diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 95bdd08..cf3201e 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -49,7 +49,7 @@ func TestMetrics_Interception(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) - srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -90,7 +90,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) // Make request in background. doneCh := make(chan struct{}) @@ -142,7 +142,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, metrics, defaultTracer) + srv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) require.NoError(t, err) @@ -171,7 +171,7 @@ func TestMetrics_PromptCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -199,7 +199,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, defaultTracer) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -237,11 +237,11 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. - tools := setupMCPServerProxiesForTest(t) - mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer) + tools := setupMCPServerProxiesForTest(t, testTracer) + mcpMgr := mcp.NewServerProxyManager(tools, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, defaultTracer, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, testTracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -283,7 +283,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m } wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, defaultTracer), metrics, tracer, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), metrics, tracer, logger) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) diff --git a/provider_anthropic.go b/provider_anthropic.go index 009c30f..cb7e675 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -84,7 +84,7 @@ func (p *AnthropicProvider) CreateInterceptor(tracer trace.Tracer, w http.Respon } else { interceptor = NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) } - span.SetAttributes(interceptor.TraceAttributes(r.Context())...) + span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil } diff --git a/provider_openai.go b/provider_openai.go index 7cc154a..0d71973 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -85,7 +85,7 @@ func (p *OpenAIProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseW } else { interceptor = NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key, tracer) } - span.SetAttributes(interceptor.TraceAttributes(r.Context())...) + span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil } diff --git a/recorder.go b/recorder.go index 5bbcec0..26475a1 100644 --- a/recorder.go +++ b/recorder.go @@ -26,7 +26,7 @@ type RecorderWrapper struct { } func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) client, err := r.clientFn() @@ -44,7 +44,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept } func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) client, err := r.clientFn() @@ -62,7 +62,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte } func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) client, err := r.clientFn() @@ -80,7 +80,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag } func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) client, err := r.clientFn() @@ -98,7 +98,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR } func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(aibtrace.TraceInterceptionAttributesFromContext(ctx)...)) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) defer aibtrace.EndSpanErr(span, &outErr) client, err := r.clientFn() diff --git a/trace_integration_test.go b/trace_integration_test.go index 2b0a159..f0618f0 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" @@ -44,27 +46,24 @@ func TestTraceAnthropic(t *testing.T) { } cases := []struct { - name string - streaming bool - bedrock bool - expectTraceCounts []expectTrace - expectSpanCount int + name string + streaming bool + bedrock bool + expect []expectTrace }{ { - name: "trace_anthr_non_streaming", - expectTraceCounts: expectNonStreaming, - expectSpanCount: 9, + name: "trace_anthr_non_streaming", + expect: expectNonStreaming, }, { - name: "trace_bedrock_non_streaming", - bedrock: true, - expectTraceCounts: expectNonStreaming, - expectSpanCount: 9, + name: "trace_bedrock_non_streaming", + bedrock: true, + expect: expectNonStreaming, }, { name: "trace_anthr_streaming", streaming: true, - expectTraceCounts: []expectTrace{ + expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, {"Intercept.RecordInterception", 1, codes.Unset}, @@ -73,15 +72,14 @@ func TestTraceAnthropic(t *testing.T) { {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.RecordTokenUsage", 2, codes.Unset}, {"Intercept.RecordToolUsage", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 9, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 18, }, { name: "trace_bedrock_streaming", streaming: true, bedrock: true, - expectTraceCounts: []expectTrace{ + expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, {"Intercept.RecordInterception", 1, codes.Unset}, @@ -90,7 +88,6 @@ func TestTraceAnthropic(t *testing.T) { {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 7, }, } @@ -148,7 +145,14 @@ func TestTraceAnthropic(t *testing.T) { if tc.bedrock { model = "beddel" } + + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, req.URL.Path), attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), attribute.String(aibtrace.Model, model), @@ -157,18 +161,17 @@ func TestTraceAnthropic(t *testing.T) { attribute.Bool(aibtrace.IsBedrock, tc.bedrock), } - require.Len(t, sr.Ended(), tc.expectSpanCount) - verifyTraces(t, sr, tc.expectTraceCounts, attrs) + require.Len(t, sr.Ended(), totalCount) + verifyTraces(t, sr, tc.expect, attrs) }) } } func TestTraceAnthropicErr(t *testing.T) { cases := []struct { - name string - streaming bool - expect []expectTrace - expectSpanCount int + name string + streaming bool + expect []expectTrace }{ { name: "trace_anthr_non_streaming_err", @@ -180,7 +183,6 @@ func TestTraceAnthropicErr(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, }, - expectSpanCount: 6, }, { name: "trace_anthr_streaming_err", @@ -193,9 +195,8 @@ func TestTraceAnthropicErr(t *testing.T) { {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 3, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 10, }, } @@ -250,7 +251,13 @@ func TestTraceAnthropicErr(t *testing.T) { require.Equal(t, 1, len(recorder.interceptions)) intcID := recorder.interceptions[0].ID + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, req.URL.Path), attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), @@ -259,7 +266,7 @@ func TestTraceAnthropicErr(t *testing.T) { attribute.Bool(aibtrace.IsBedrock, false), } - require.Len(t, sr.Ended(), tc.expectSpanCount) + require.Len(t, sr.Ended(), totalCount) verifyTraces(t, sr, tc.expect, attrs) }) } @@ -267,11 +274,10 @@ func TestTraceAnthropicErr(t *testing.T) { func TestTraceOpenAI(t *testing.T) { cases := []struct { - name string - fixture []byte - streaming bool - expect []expectTrace - expectSpanCount int + name string + fixture []byte + streaming bool + expect []expectTrace }{ { name: "trace_openai_streaming", @@ -285,9 +291,8 @@ func TestTraceOpenAI(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.RecordPromptUsage", 1, codes.Unset}, {"Intercept.RecordTokenUsage", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 242, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 249, }, { name: "trace_openai_non_streaming", @@ -303,7 +308,6 @@ func TestTraceOpenAI(t *testing.T) { {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 8, }, } @@ -345,15 +349,20 @@ func TestTraceOpenAI(t *testing.T) { require.Equal(t, 1, len(recorder.interceptions)) intcID := recorder.interceptions[0].ID + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + require.Len(t, sr.Ended(), totalCount) + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, req.URL.Path), attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), attribute.String(aibtrace.UserID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), } - - require.Len(t, sr.Ended(), tc.expectSpanCount) verifyTraces(t, sr, tc.expect, attrs) }) } @@ -361,10 +370,9 @@ func TestTraceOpenAI(t *testing.T) { func TestTraceOpenAIErr(t *testing.T) { cases := []struct { - name string - streaming bool - expect []expectTrace - expectSpanCount int + name string + streaming bool + expect []expectTrace }{ { name: "trace_openai_streaming_err", @@ -376,9 +384,8 @@ func TestTraceOpenAIErr(t *testing.T) { {"Intercept.ProcessRequest", 1, codes.Error}, {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 5, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, - expectSpanCount: 11, }, { name: "trace_openai_non_streaming_err", @@ -391,7 +398,6 @@ func TestTraceOpenAIErr(t *testing.T) { {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, }, - expectSpanCount: 6, }, } @@ -445,15 +451,20 @@ func TestTraceOpenAIErr(t *testing.T) { require.Equal(t, 1, len(recorder.interceptions)) intcID := recorder.interceptions[0].ID + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + require.Len(t, sr.Ended(), totalCount) + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, req.URL.Path), attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), attribute.String(aibtrace.UserID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), } - - require.Len(t, sr.Ended(), tc.expectSpanCount) verifyTraces(t, sr, tc.expect, attrs) }) } @@ -510,19 +521,29 @@ func TestNewServerProxyManagerTraces(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - tools := setupMCPServerProxiesForTest(t) + serverName := "serverName" + mcpSrv := httptest.NewServer(createMockMCPSrv(t)) + t.Cleanup(mcpSrv.Close) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, serverName, mcpSrv.URL, nil, nil, nil) + require.NoError(t, err) + tools := map[string]mcp.ServerProxier{"unusedValue": proxy} + mcpMgr := mcp.NewServerProxyManager(tools, tracer) - err := mcpMgr.Init(ctx) + err = mcpMgr.Init(ctx) require.NoError(t, err) - require.Len(t, sr.Ended(), 2) + require.Len(t, sr.Ended(), 3) verifyTraces(t, sr, []expectTrace{{"ServerProxyManager.Init", 1, codes.Unset}}, []attribute.KeyValue{}) - for name := range tools { - tc := []expectTrace{{"ServerProxyManager.Init.Proxy", 1, codes.Unset}} - a := []attribute.KeyValue{attribute.String(aibtrace.MCPProxyName, name)} - verifyTraces(t, sr, tc, a) + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.MCPProxyName, proxy.Name()), + attribute.String(aibtrace.MCPServerURL, mcpSrv.URL), + attribute.String(aibtrace.MCPServerName, serverName), } + verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init", 1, codes.Unset}}, attrs) + verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) } func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) bool { From a863d67e49dfce6152ec40086f66892ecf8b684b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 27 Nov 2025 21:50:05 +0000 Subject: [PATCH 06/16] fmt fix --- aibtrace/aibtrace.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index d43a61c..e201917 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -8,8 +8,10 @@ import ( "go.opentelemetry.io/otel/trace" ) -type traceInterceptionAttrsContextKey struct{} -type traceRequestBridgeAttrsContextKey struct{} +type ( + traceInterceptionAttrsContextKey struct{} + traceRequestBridgeAttrsContextKey struct{} +) const ( // trace attribute key constants From 5e81f3c6fc3c151fa2dbe8d3765b47908e03f117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 28 Nov 2025 10:54:29 +0000 Subject: [PATCH 07/16] re-add tool_name attr to tool.Call + toolCount attr in fetchTools trace --- aibtrace/aibtrace.go | 1 + mcp/proxy_streamable_http.go | 1 + mcp/tool.go | 1 + trace_integration_test.go | 2 ++ 4 files changed, 5 insertions(+) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index e201917..2eb01a4 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -32,6 +32,7 @@ const ( MCPToolName = "mcp_tool_name" MCPServerName = "mcp_server_name" MCPServerURL = "mcp_server_url" + MCPToolCount = "mcp_tool_count" APIKeyID = "api_key_id" ) diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 67c2cdf..59894ce 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -154,6 +154,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin Logger: p.logger, } } + span.SetAttributes(append(p.traceAttributes(), attribute.Int(aibtrace.MCPToolCount, len(out)))...) return out, nil } diff --git a/mcp/tool.go b/mcp/tool.go index a4aaa83..8fea35b 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -50,6 +50,7 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp spanAttrs := append( aibtrace.InterceptionAttributesFromContext(ctx), + attribute.String(aibtrace.MCPToolName, t.Name), attribute.String(aibtrace.MCPServerName, t.ServerName), attribute.String(aibtrace.MCPServerURL, t.ServerURL), ) diff --git a/trace_integration_test.go b/trace_integration_test.go index f0618f0..f43cfa6 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -543,6 +543,8 @@ func TestNewServerProxyManagerTraces(t *testing.T) { attribute.String(aibtrace.MCPServerName, serverName), } verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init", 1, codes.Unset}}, attrs) + + attrs = append(attrs, attribute.Int(aibtrace.MCPToolCount, len(proxy.ListTools()))) verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) } From 1fd90950c7a0474afa84fbf96e5d4f3589931763 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 28 Nov 2025 11:18:37 +0000 Subject: [PATCH 08/16] test fix --- bridge_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index aa3c7c6..b11d31a 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1214,6 +1214,7 @@ func TestErrorHandling(t *testing.T) { resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) + bridgeSrv.Close() tc.responseHandlerFn(resp) recorderClient.verifyAllInterceptionsEnded(t) From 7760d3149b0e918806a42ef28d5e2ff8b3dab4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 28 Nov 2025 15:47:46 +0000 Subject: [PATCH 09/16] add tool call trace test, fixed openAI tool call attrs? --- aibtrace/aibtrace.go | 2 +- bridge_integration_test.go | 19 +-- intercept_anthropic_messages_base.go | 2 +- intercept_openai_chat_base.go | 15 +- intercept_openai_chat_blocking.go | 14 +- intercept_openai_chat_streaming.go | 8 +- mcp/proxy_streamable_http.go | 12 +- trace_integration_test.go | 237 ++++++++++++++++++++++++--- 8 files changed, 238 insertions(+), 71 deletions(-) diff --git a/aibtrace/aibtrace.go b/aibtrace/aibtrace.go index 2eb01a4..2909008 100644 --- a/aibtrace/aibtrace.go +++ b/aibtrace/aibtrace.go @@ -18,7 +18,7 @@ const ( RequestPath = "request_path" InterceptionID = "interception_id" - UserID = "user_id" + InitiatorID = "user_id" Provider = "provider" Model = "model" Streaming = "streaming" diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b11d31a..264d17f 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -442,9 +442,9 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Len(t, recorderClient.toolUsages, 1) assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool) - require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args) + require.IsType(t, "", recorderClient.toolUsages[0].Args) require.Contains(t, recorderClient.toolUsages[0].Args, "path") - assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"]) + assert.Equal(t, "README.md", gjson.Get(recorderClient.toolUsages[0].Args.(string), "path").Str) require.Len(t, recorderClient.userPrompts, 1) assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt) @@ -765,7 +765,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -847,16 +847,13 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(recorderClient.toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) + expected := "{\"owner\":\"admin\"}" + require.EqualValues(t, expected, recorderClient.toolUsages[0].Args) var ( content *openai.ChatCompletionChoice @@ -932,7 +929,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. // Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, map[string]mcp.ServerProxier, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -1008,7 +1005,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, resp + return recorderClient, tools, resp } func TestErrorHandling(t *testing.T) { diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 071be96..66d00b3 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -67,7 +67,7 @@ func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, return []attribute.KeyValue{ attribute.String(aibtrace.RequestPath, r.URL.Path), attribute.String(aibtrace.InterceptionID, s.id.String()), - attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id), + attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id), attribute.String(aibtrace.Provider, ProviderAnthropic), attribute.String(aibtrace.Model, s.Model()), attribute.Bool(aibtrace.Streaming, streaming), diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 8cf65cf..a9a3bf9 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "strings" aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" @@ -51,7 +50,7 @@ func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, stream return []attribute.KeyValue{ attribute.String(aibtrace.RequestPath, r.URL.Path), attribute.String(aibtrace.InterceptionID, s.id.String()), - attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id), + attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id), attribute.String(aibtrace.Provider, ProviderOpenAI), attribute.String(aibtrace.Model, s.Model()), attribute.Bool(aibtrace.Streaming, streaming), @@ -105,18 +104,6 @@ func (i *OpenAIChatInterceptionBase) injectTools() { } } -func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) { - if len(strings.TrimSpace(in)) == 0 { - return args // An empty string will fail JSON unmarshaling. - } - - if err := json.Unmarshal([]byte(in), &args); err != nil { - i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err)) - } - - return args -} - // writeUpstreamError marshals and writes a given error. func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) { if oaiErr == nil { diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index ecd6995..98975f6 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -1,7 +1,6 @@ package aibridge import ( - "bytes" "context" "encoding/json" "fmt" @@ -120,7 +119,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r InterceptionID: i.ID().String(), MsgID: completion.ID, Tool: toolCall.Function.Name, - Args: i.unmarshalArgs(toolCall.Function.Arguments), + Args: toolCall.Function.Arguments, Injected: false, }) } @@ -151,20 +150,13 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r appendedPrevMsg = true } - var ( - args map[string]string - buf bytes.Buffer - ) - _ = json.NewEncoder(&buf).Encode(tc.Function.Arguments) - _ = json.NewDecoder(&buf).Decode(&args) - res, err := tool.Call(ctx, i.tracer, args) - + res, err := tool.Call(ctx, i.tracer, tc.Function.Arguments) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: i.unmarshalArgs(tc.Function.Arguments), + Args: tc.Function.Arguments, Injected: true, InvocationError: err, }) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 9e08c0e..3b2132a 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -154,7 +154,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), Tool: toolCall.Name, - Args: i.unmarshalArgs(toolCall.Arguments), + Args: toolCall.Arguments, Injected: false, }) toolCall = nil @@ -241,15 +241,13 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, i.req.Messages = append(i.req.Messages, processor.getLastCompletion().ToParam()) id := toolCall.ID - args := i.unmarshalArgs(toolCall.Arguments) - toolRes, toolErr := tool.Call(streamCtx, i.tracer, args) - + toolRes, toolErr := tool.Call(streamCtx, i.tracer, toolCall.Arguments) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: args, + Args: toolCall.Arguments, Injected: true, InvocationError: toolErr, }) diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 59894ce..04b5f38 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -20,14 +20,16 @@ import ( var _ ServerProxier = &StreamableHTTPServerProxy{} type StreamableHTTPServerProxy struct { + client *client.Client + logger slog.Logger + tracer trace.Tracer + + allowlistPattern *regexp.Regexp + denylistPattern *regexp.Regexp + serverName string serverURL string - client *client.Client - logger slog.Logger - tracer trace.Tracer tools map[string]*Tool - - allowlistPattern, denylistPattern *regexp.Regexp } func NewStreamableHTTPServerProxy(logger slog.Logger, tracer trace.Tracer, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) { diff --git a/trace_integration_test.go b/trace_integration_test.go index f43cfa6..c6ec906 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -118,14 +118,7 @@ func TestTraceAnthropic(t *testing.T) { var bedrockCfg *aibridge.AWSBedrockConfig if tc.bedrock { - bedrockCfg = &aibridge.AWSBedrockConfig{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - EndpointOverride: mockAPI.URL, - } + bedrockCfg = testBedrockCfg(mockAPI.URL) } provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) @@ -156,7 +149,7 @@ func TestTraceAnthropic(t *testing.T) { attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), attribute.String(aibtrace.Model, model), - attribute.String(aibtrace.UserID, userID), + attribute.String(aibtrace.InitiatorID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), attribute.Bool(aibtrace.IsBedrock, tc.bedrock), } @@ -168,32 +161,55 @@ func TestTraceAnthropic(t *testing.T) { } func TestTraceAnthropicErr(t *testing.T) { + expectNonStream := []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + } + cases := []struct { name string streaming bool + bedrock bool expect []expectTrace }{ { - name: "trace_anthr_non_streaming_err", + name: "anthr_non_streaming_err", + expect: expectNonStream, + }, + { + name: "anthr_streaming_err", + streaming: true, expect: []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, {"Intercept.RecordInterception", 1, codes.Unset}, {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, }, { - name: "trace_anthr_streaming_err", + name: "bedrock_non_streaming_err", + bedrock: true, + expect: expectNonStream, + }, + { + name: "bedrock_streaming_err", streaming: true, + bedrock: true, expect: []expectTrace{ - {"Intercept", 1, codes.Error}, + // RecordTokenUsage missing? + {"Intercept", 1, codes.Unset}, // TODO check why this is unset not Error {"Intercept.CreateInterceptor", 1, codes.Unset}, {"Intercept.RecordInterception", 1, codes.Unset}, - {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.ProcessRequest", 1, codes.Unset}, // TODO check why this is unset not Error {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.RecordTokenUsage", 1, codes.Unset}, {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, }, @@ -201,7 +217,7 @@ func TestTraceAnthropicErr(t *testing.T) { } for _, tc := range cases { - t.Run(t.Name(), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -233,7 +249,11 @@ func TestTraceAnthropicErr(t *testing.T) { mockAPI := newMockServer(ctx, t, files, nil) t.Cleanup(mockAPI.Close) - provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = testBedrockCfg(mockAPI.URL) + } + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) req := createAnthropicMessagesReq(t, srv.URL, reqBody) @@ -255,23 +275,126 @@ func TestTraceAnthropicErr(t *testing.T) { for _, e := range tc.expect { totalCount += e.count } + for _, s := range sr.Ended() { + t.Logf("SPAN: %v", s.Name()) + } + require.Len(t, sr.Ended(), totalCount) + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } attrs := []attribute.KeyValue{ attribute.String(aibtrace.RequestPath, req.URL.Path), attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), - attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(aibtrace.UserID, userID), + attribute.String(aibtrace.Model, model), + attribute.String(aibtrace.InitiatorID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), - attribute.Bool(aibtrace.IsBedrock, false), + attribute.Bool(aibtrace.IsBedrock, tc.bedrock), } - require.Len(t, sr.Ended(), totalCount) verifyTraces(t, sr, tc.expect, attrs) }) } } +func TestAnthropicInjectedToolsTrace(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + streaming bool + bedrock bool + }{ + { + name: "anthr_blocking", + streaming: false, + bedrock: false, + }, + { + name: "anthr_streaming", + streaming: true, + bedrock: false, + }, + { + name: "bedrock_blocking", + streaming: false, + bedrock: true, + }, + // TODO check why it fails + // { + // name: "bedrock_streaming", + // streaming: true, + // bedrock: true, + // }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = testBedrockCfg(addr) + } + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), bedrockCfg)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, tracer, logger) + } + + var reqBody string + var reqPath string + reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { + reqBody = string(input) + r := createAnthropicMessagesReq(t, baseURL, input) + reqPath = r.URL.Path + return r + } + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + + defer resp.Body.Close() + + require.Len(t, recorderClient.interceptions, 1) + intcID := recorderClient.interceptions[0].ID + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + + for _, proxy := range proxies { + require.NotEmpty(t, proxy.ListTools()) + tool := proxy.ListTools()[0] + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, reqPath), + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), + attribute.String(aibtrace.Model, model), + attribute.String(aibtrace.InitiatorID, userID), + attribute.String(aibtrace.MCPInput, "{\"owner\":\"admin\"}"), + attribute.String(aibtrace.MCPToolName, "coder_list_workspaces"), + attribute.String(aibtrace.MCPServerName, tool.ServerName), + attribute.String(aibtrace.MCPServerURL, tool.ServerURL), + attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + } + verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) + } + }) + } +} + func TestTraceOpenAI(t *testing.T) { cases := []struct { name string @@ -360,7 +483,7 @@ func TestTraceOpenAI(t *testing.T) { attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(aibtrace.UserID, userID), + attribute.String(aibtrace.InitiatorID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -462,7 +585,7 @@ func TestTraceOpenAIErr(t *testing.T) { attribute.String(aibtrace.InterceptionID, intcID), attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(aibtrace.UserID, userID), + attribute.String(aibtrace.InitiatorID, userID), attribute.Bool(aibtrace.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -470,6 +593,63 @@ func TestTraceOpenAIErr(t *testing.T) { } } +func TestOpenAIInjectedToolsTrace(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, tracer, logger) + } + + var reqBody string + var reqPath string + reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { + reqBody = string(input) + r := createOpenAIChatCompletionsReq(t, baseURL, input) + reqPath = r.URL.Path + return r + } + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + + defer resp.Body.Close() + + require.Len(t, recorderClient.interceptions, 1) + intcID := recorderClient.interceptions[0].ID + + for _, proxy := range proxies { + require.NotEmpty(t, proxy.ListTools()) + tool := proxy.ListTools()[0] + + attrs := []attribute.KeyValue{ + attribute.String(aibtrace.RequestPath, reqPath), + attribute.String(aibtrace.InterceptionID, intcID), + attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), + attribute.String(aibtrace.Model, gjson.Get(reqBody, "model").Str), + attribute.String(aibtrace.InitiatorID, userID), + attribute.String(aibtrace.MCPInput, "\"{\\\"owner\\\":\\\"admin\\\"}\""), + attribute.String(aibtrace.MCPToolName, "coder_list_workspaces"), + attribute.String(aibtrace.MCPServerName, tool.ServerName), + attribute.String(aibtrace.MCPServerURL, tool.ServerURL), + attribute.Bool(aibtrace.Streaming, streaming), + } + verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) + } + }) + } +} + func TestTracePassthrough(t *testing.T) { t.Parallel() @@ -573,3 +753,14 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e } } } + +func testBedrockCfg(url string) *aibridge.AWSBedrockConfig { + return &aibridge.AWSBedrockConfig{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + EndpointOverride: url, + } +} From 1ac08c13fe8f728e6c82c330d0857d77e2c3aee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 3 Dec 2025 11:40:39 +0000 Subject: [PATCH 10/16] rename aibridge -> tracing --- intercept_anthropic_messages_base.go | 16 +-- intercept_anthropic_messages_blocking.go | 10 +- intercept_anthropic_messages_streaming.go | 8 +- intercept_openai_chat_base.go | 14 +-- intercept_openai_chat_blocking.go | 10 +- intercept_openai_chat_streaming.go | 8 +- interception.go | 4 +- mcp/proxy_streamable_http.go | 14 +-- mcp/server_proxy_manager.go | 4 +- mcp/tool.go | 14 +-- passthrough.go | 6 +- provider_anthropic.go | 4 +- provider_openai.go | 4 +- recorder.go | 22 ++--- trace_integration_test.go | 108 ++++++++++----------- aibtrace/aibtrace.go => tracing/tracing.go | 2 +- 16 files changed, 124 insertions(+), 124 deletions(-) rename aibtrace/aibtrace.go => tracing/tracing.go (99%) diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 66d00b3..9fdb4c7 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -14,8 +14,8 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -65,13 +65,13 @@ func (i *AnthropicMessagesInterceptionBase) Model() string { func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, r.URL.Path), - attribute.String(aibtrace.InterceptionID, s.id.String()), - attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id), - attribute.String(aibtrace.Provider, ProviderAnthropic), - attribute.String(aibtrace.Model, s.Model()), - attribute.Bool(aibtrace.Streaming, streaming), - attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil), + attribute.String(tracing.RequestPath, r.URL.Path), + attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id), + attribute.String(tracing.Provider, ProviderAnthropic), + attribute.String(tracing.Model, s.Model()), + attribute.Bool(tracing.Streaming, streaming), + attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil), } } diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index c0207b5..d3d28fd 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -14,8 +14,8 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "cdr.dev/slog" ) @@ -53,8 +53,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) i.injectTools() @@ -298,8 +298,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr } func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { - ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) return svc.New(ctx, msgParams) } diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index c8e8769..76620fe 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -13,8 +13,8 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/attribute" @@ -75,8 +75,8 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(ctx) @@ -523,7 +523,7 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, } func (s *AnthropicMessagesStreamingInterception) traceNewStreaming(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { - _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) + _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() return svc.NewStreaming(ctx, messages) diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index a9a3bf9..1fd7fd9 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -5,8 +5,8 @@ import ( "encoding/json" "net/http" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -48,12 +48,12 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, r.URL.Path), - attribute.String(aibtrace.InterceptionID, s.id.String()), - attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id), - attribute.String(aibtrace.Provider, ProviderOpenAI), - attribute.String(aibtrace.Model, s.Model()), - attribute.Bool(aibtrace.Streaming, streaming), + attribute.String(tracing.RequestPath, r.URL.Path), + attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id), + attribute.String(tracing.Provider, ProviderOpenAI), + attribute.String(tracing.Model, s.Model()), + attribute.Bool(tracing.Streaming, streaming), } } diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 98975f6..fcb366f 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -8,8 +8,8 @@ import ( "strings" "time" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -52,8 +52,8 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) @@ -233,8 +233,8 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { - ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) return svc.New(ctx, i.req.ChatCompletionNewParams, opts...) } diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 3b2132a..a957d4c 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -10,8 +10,8 @@ import ( "strings" "time" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/ssestream" @@ -67,8 +67,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, return fmt.Errorf("developer error: req is nil") } - ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) // Include token usage. i.req.StreamOptions.IncludeUsage = openai.Bool(true) @@ -347,7 +347,7 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte } func (i *OpenAIStreamingChatInterception) traceNewStreaming(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { - _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams) diff --git a/interception.go b/interception.go index dc1ecfe..5f0948f 100644 --- a/interception.go +++ b/interception.go @@ -8,8 +8,8 @@ import ( "time" "cdr.dev/slog" - aibtrace "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -69,7 +69,7 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, traceAttrs := interceptor.TraceAttributes(r) span.SetAttributes(traceAttrs...) - ctx = aibtrace.WithInterceptionAttributesInContext(ctx, traceAttrs) + ctx = tracing.WithInterceptionAttributesInContext(ctx, traceAttrs) r = r.WithContext(ctx) // Record usage in the background to not block request flow. diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 04b5f38..4a7eaca 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -8,7 +8,7 @@ import ( "strings" "cdr.dev/slog" - "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" @@ -60,7 +60,7 @@ func (p *StreamableHTTPServerProxy) Name() string { func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init", trace.WithAttributes(p.traceAttributes()...)) - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) if err := p.client.Start(ctx); err != nil { return fmt.Errorf("start client: %w", err) @@ -134,7 +134,7 @@ func (p *StreamableHTTPServerProxy) CallTool(ctx context.Context, name string, i func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[string]*Tool, outErr error) { ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init.fetchTools", trace.WithAttributes(p.traceAttributes()...)) - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) tools, err := p.client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { @@ -156,7 +156,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin Logger: p.logger, } } - span.SetAttributes(append(p.traceAttributes(), attribute.Int(aibtrace.MCPToolCount, len(out)))...) + span.SetAttributes(append(p.traceAttributes(), attribute.Int(tracing.MCPToolCount, len(out)))...) return out, nil } @@ -172,8 +172,8 @@ func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error { func (p *StreamableHTTPServerProxy) traceAttributes() []attribute.KeyValue { return []attribute.KeyValue{ - attribute.String(aibtrace.MCPProxyName, p.Name()), - attribute.String(aibtrace.MCPServerName, p.serverName), - attribute.String(aibtrace.MCPServerURL, p.serverURL), + attribute.String(tracing.MCPProxyName, p.Name()), + attribute.String(tracing.MCPServerName, p.serverName), + attribute.String(tracing.MCPServerURL, p.serverURL), } } diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index 64d9eff..01c8790 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -7,7 +7,7 @@ import ( "strings" "sync" - "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "github.com/coder/aibridge/utils" "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/trace" @@ -50,7 +50,7 @@ func (s *ServerProxyManager) addTools(tools []*Tool) { // Init concurrently initializes all of its [ServerProxier]s. func (s *ServerProxyManager) Init(ctx context.Context) (outErr error) { ctx, span := s.tracer.Start(ctx, "ServerProxyManager.Init") - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) cg := utils.NewConcurrentGroup() for _, proxy := range s.proxiers { diff --git a/mcp/tool.go b/mcp/tool.go index 8fea35b..8606adb 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -8,7 +8,7 @@ import ( "strings" "cdr.dev/slog" - "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -49,10 +49,10 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp } spanAttrs := append( - aibtrace.InterceptionAttributesFromContext(ctx), - attribute.String(aibtrace.MCPToolName, t.Name), - attribute.String(aibtrace.MCPServerName, t.ServerName), - attribute.String(aibtrace.MCPServerURL, t.ServerURL), + tracing.InterceptionAttributesFromContext(ctx), + attribute.String(tracing.MCPToolName, t.Name), + attribute.String(tracing.MCPServerName, t.ServerName), + attribute.String(tracing.MCPServerURL, t.ServerURL), ) inputJson, err := json.Marshal(input) if err != nil { @@ -62,11 +62,11 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp if len(strJson) > maxSpanInputAttrLen { strJson = strJson[:100] } - spanAttrs = append(spanAttrs, attribute.String(aibtrace.MCPInput, strJson)) + spanAttrs = append(spanAttrs, attribute.String(tracing.MCPInput, strJson)) } ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...)) - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) return t.Client.CallTool(ctx, mcp.CallToolRequest{ Params: mcp.CallToolParams{ diff --git a/passthrough.go b/passthrough.go index 2c9f45b..3ce1631 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,7 +8,7 @@ import ( "time" "cdr.dev/slog" - "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) @@ -22,8 +22,8 @@ func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metric } ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( - attribute.String(aibtrace.PassthroughURL, r.URL.Path), - attribute.String(aibtrace.PassthroughMethod, r.Method), + attribute.String(tracing.PassthroughURL, r.URL.Path), + attribute.String(tracing.PassthroughMethod, r.Method), )) defer span.End() diff --git a/provider_anthropic.go b/provider_anthropic.go index cb7e675..d96e521 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -11,7 +11,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared" "github.com/anthropics/anthropic-sdk-go/shared/constant" - aibtrace "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" @@ -64,7 +64,7 @@ func (p *AnthropicProvider) PassthroughRoutes() []string { func (p *AnthropicProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) payload, err := io.ReadAll(r.Body) if err != nil { diff --git a/provider_openai.go b/provider_openai.go index 0d71973..8bf05cf 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,7 +7,7 @@ import ( "net/http" "os" - aibtrace "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" @@ -65,7 +65,7 @@ func (p *OpenAIProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseW id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") - defer aibtrace.EndSpanErr(span, &outErr) + defer tracing.EndSpanErr(span, &outErr) payload, err := io.ReadAll(r.Body) if err != nil { diff --git a/recorder.go b/recorder.go index 26475a1..3cad7a2 100644 --- a/recorder.go +++ b/recorder.go @@ -8,7 +8,7 @@ import ( "cdr.dev/slog" - aibtrace "github.com/coder/aibridge/aibtrace" + "github.com/coder/aibridge/tracing" "go.opentelemetry.io/otel/trace" ) @@ -26,8 +26,8 @@ type RecorderWrapper struct { } func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) client, err := r.clientFn() if err != nil { @@ -44,8 +44,8 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept } func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) client, err := r.clientFn() if err != nil { @@ -62,8 +62,8 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte } func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) client, err := r.clientFn() if err != nil { @@ -80,8 +80,8 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag } func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) client, err := r.clientFn() if err != nil { @@ -98,8 +98,8 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR } func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { - ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...)) - defer aibtrace.EndSpanErr(span, &outErr) + ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) client, err := r.clientFn() if err != nil { diff --git a/trace_integration_test.go b/trace_integration_test.go index c6ec906..40e20f1 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -11,8 +11,8 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/aibridge" - "github.com/coder/aibridge/aibtrace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" @@ -145,13 +145,13 @@ func TestTraceAnthropic(t *testing.T) { } attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, req.URL.Path), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), - attribute.String(aibtrace.Model, model), - attribute.String(aibtrace.InitiatorID, userID), - attribute.Bool(aibtrace.Streaming, tc.streaming), - attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), } require.Len(t, sr.Ended(), totalCount) @@ -286,13 +286,13 @@ func TestTraceAnthropicErr(t *testing.T) { } attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, req.URL.Path), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), - attribute.String(aibtrace.Model, model), - attribute.String(aibtrace.InitiatorID, userID), - attribute.Bool(aibtrace.Streaming, tc.streaming), - attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), } verifyTraces(t, sr, tc.expect, attrs) @@ -377,17 +377,17 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { tool := proxy.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, reqPath), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderAnthropic), - attribute.String(aibtrace.Model, model), - attribute.String(aibtrace.InitiatorID, userID), - attribute.String(aibtrace.MCPInput, "{\"owner\":\"admin\"}"), - attribute.String(aibtrace.MCPToolName, "coder_list_workspaces"), - attribute.String(aibtrace.MCPServerName, tool.ServerName), - attribute.String(aibtrace.MCPServerURL, tool.ServerURL), - attribute.Bool(aibtrace.Streaming, tc.streaming), - attribute.Bool(aibtrace.IsBedrock, tc.bedrock), + attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), + attribute.String(tracing.MCPToolName, "coder_list_workspaces"), + attribute.String(tracing.MCPServerName, tool.ServerName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), } verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) } @@ -479,12 +479,12 @@ func TestTraceOpenAI(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, req.URL.Path), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), - attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(aibtrace.InitiatorID, userID), - attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) }) @@ -581,12 +581,12 @@ func TestTraceOpenAIErr(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, req.URL.Path), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), - attribute.String(aibtrace.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(aibtrace.InitiatorID, userID), - attribute.Bool(aibtrace.Streaming, tc.streaming), + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) }) @@ -633,16 +633,16 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { tool := proxy.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(aibtrace.RequestPath, reqPath), - attribute.String(aibtrace.InterceptionID, intcID), - attribute.String(aibtrace.Provider, aibridge.ProviderOpenAI), - attribute.String(aibtrace.Model, gjson.Get(reqBody, "model").Str), - attribute.String(aibtrace.InitiatorID, userID), - attribute.String(aibtrace.MCPInput, "\"{\\\"owner\\\":\\\"admin\\\"}\""), - attribute.String(aibtrace.MCPToolName, "coder_list_workspaces"), - attribute.String(aibtrace.MCPServerName, tool.ServerName), - attribute.String(aibtrace.MCPServerURL, tool.ServerURL), - attribute.Bool(aibtrace.Streaming, streaming), + attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.MCPInput, "\"{\\\"owner\\\":\\\"admin\\\"}\""), + attribute.String(tracing.MCPToolName, "coder_list_workspaces"), + attribute.String(tracing.MCPServerName, tool.ServerName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), + attribute.Bool(tracing.Streaming, streaming), } verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) } @@ -685,8 +685,8 @@ func TestTracePassthrough(t *testing.T) { assert.Equal(t, spans[0].Name(), "Passthrough") attrs := []attribute.KeyValue{ - attribute.String(aibtrace.PassthroughURL, "/v1/models"), - attribute.String(aibtrace.PassthroughMethod, "GET"), + attribute.String(tracing.PassthroughURL, "/v1/models"), + attribute.String(tracing.PassthroughMethod, "GET"), } if attrDiff := cmp.Diff(spans[0].Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { t.Errorf("unexpectet attrs diff: %s", attrDiff) @@ -718,13 +718,13 @@ func TestNewServerProxyManagerTraces(t *testing.T) { verifyTraces(t, sr, []expectTrace{{"ServerProxyManager.Init", 1, codes.Unset}}, []attribute.KeyValue{}) attrs := []attribute.KeyValue{ - attribute.String(aibtrace.MCPProxyName, proxy.Name()), - attribute.String(aibtrace.MCPServerURL, mcpSrv.URL), - attribute.String(aibtrace.MCPServerName, serverName), + attribute.String(tracing.MCPProxyName, proxy.Name()), + attribute.String(tracing.MCPServerURL, mcpSrv.URL), + attribute.String(tracing.MCPServerName, serverName), } verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init", 1, codes.Unset}}, attrs) - attrs = append(attrs, attribute.Int(aibtrace.MCPToolCount, len(proxy.ListTools()))) + attrs = append(attrs, attribute.Int(tracing.MCPToolCount, len(proxy.ListTools()))) verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) } diff --git a/aibtrace/aibtrace.go b/tracing/tracing.go similarity index 99% rename from aibtrace/aibtrace.go rename to tracing/tracing.go index 2909008..0074d44 100644 --- a/aibtrace/aibtrace.go +++ b/tracing/tracing.go @@ -1,4 +1,4 @@ -package aibtrace +package tracing import ( "context" From 454cd8bd70b45ab2707fb0acb9e3c50de5413864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 3 Dec 2025 11:44:01 +0000 Subject: [PATCH 11/16] traceNewStreaming rename --- intercept_anthropic_messages_streaming.go | 5 +++-- intercept_openai_chat_streaming.go | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 76620fe..4b51d23 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -135,7 +135,7 @@ newStream: break } - stream := i.traceNewStreaming(streamCtx, svc, messages) // traces svc.NewStreaming(streamCtx, messages) + stream := i.newStream(streamCtx, svc, messages) var message anthropic.Message var lastToolName string @@ -522,7 +522,8 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, return buf.Bytes() } -func (s *AnthropicMessagesStreamingInterception) traceNewStreaming(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { +// newStream traces svc.NewStreaming(streamCtx, messages) +func (s *AnthropicMessagesStreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index a957d4c..d8051bb 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -116,7 +116,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, ) for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - stream = i.traceNewStreaming(streamCtx, svc) // traces svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) call + stream = i.newStream(streamCtx, svc) processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName) var toolCall *openai.FinishedChatCompletionToolCall @@ -346,7 +346,8 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte return buf.Bytes() } -func (i *OpenAIStreamingChatInterception) traceNewStreaming(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { +// newStream traces svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) call +func (i *OpenAIStreamingChatInterception) newStream(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() From dc0ad44816b265c784079f7a4d0c77b5d43f7fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 3 Dec 2025 20:38:26 +0000 Subject: [PATCH 12/16] fix openAI tool arguments consistency --- bridge_integration_test.go | 6 +++--- intercept_openai_chat_base.go | 13 +++++++++++++ intercept_openai_chat_blocking.go | 7 ++++--- intercept_openai_chat_streaming.go | 7 ++++--- mcp/tool.go | 8 ++++---- trace_integration_test.go | 2 +- 6 files changed, 29 insertions(+), 14 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 264d17f..5f957e8 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -442,9 +442,9 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Len(t, recorderClient.toolUsages, 1) assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool) - require.IsType(t, "", recorderClient.toolUsages[0].Args) + require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args) require.Contains(t, recorderClient.toolUsages[0].Args, "path") - assert.Equal(t, "README.md", gjson.Get(recorderClient.toolUsages[0].Args.(string), "path").Str) + assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"]) require.Len(t, recorderClient.userPrompts, 1) assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt) @@ -852,7 +852,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool) - expected := "{\"owner\":\"admin\"}" + expected := map[string]any{"owner": "admin"} require.EqualValues(t, expected, recorderClient.toolUsages[0].Args) var ( diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 1fd7fd9..8fb6c73 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net/http" + "strings" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/tracing" @@ -104,6 +105,18 @@ func (i *OpenAIChatInterceptionBase) injectTools() { } } +func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) { + if len(strings.TrimSpace(in)) == 0 { + return args // An empty string will fail JSON unmarshaling. + } + + if err := json.Unmarshal([]byte(in), &args); err != nil { + i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err)) + } + + return args +} + // writeUpstreamError marshals and writes a given error. func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) { if oaiErr == nil { diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index fcb366f..808b361 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -119,7 +119,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r InterceptionID: i.ID().String(), MsgID: completion.ID, Tool: toolCall.Function.Name, - Args: toolCall.Function.Arguments, + Args: i.unmarshalArgs(toolCall.Function.Arguments), Injected: false, }) } @@ -150,13 +150,14 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r appendedPrevMsg = true } - res, err := tool.Call(ctx, i.tracer, tc.Function.Arguments) + args := i.unmarshalArgs(tc.Function.Arguments) + res, err := tool.Call(ctx, i.tracer, args) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: tc.Function.Arguments, + Args: args, Injected: true, InvocationError: err, }) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index d8051bb..f6391c9 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -154,7 +154,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), Tool: toolCall.Name, - Args: toolCall.Arguments, + Args: i.unmarshalArgs(toolCall.Arguments), Injected: false, }) toolCall = nil @@ -241,13 +241,14 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, i.req.Messages = append(i.req.Messages, processor.getLastCompletion().ToParam()) id := toolCall.ID - toolRes, toolErr := tool.Call(streamCtx, i.tracer, toolCall.Arguments) + args := i.unmarshalArgs(toolCall.Arguments) + toolRes, toolErr := tool.Call(streamCtx, i.tracer, args) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: toolCall.Arguments, + Args: args, Injected: true, InvocationError: toolErr, }) diff --git a/mcp/tool.go b/mcp/tool.go index 8606adb..b856cc4 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -54,6 +54,9 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp attribute.String(tracing.MCPServerName, t.ServerName), attribute.String(tracing.MCPServerURL, t.ServerURL), ) + ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...)) + defer tracing.EndSpanErr(span, &outErr) + inputJson, err := json.Marshal(input) if err != nil { t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs: %v", err) @@ -62,12 +65,9 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp if len(strJson) > maxSpanInputAttrLen { strJson = strJson[:100] } - spanAttrs = append(spanAttrs, attribute.String(tracing.MCPInput, strJson)) + span.SetAttributes(attribute.String(tracing.MCPInput, strJson)) } - ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...)) - defer tracing.EndSpanErr(span, &outErr) - return t.Client.CallTool(ctx, mcp.CallToolRequest{ Params: mcp.CallToolParams{ Name: t.Name, diff --git a/trace_integration_test.go b/trace_integration_test.go index 40e20f1..c23f3ef 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -638,7 +638,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { attribute.String(tracing.Provider, aibridge.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str), attribute.String(tracing.InitiatorID, userID), - attribute.String(tracing.MCPInput, "\"{\\\"owner\\\":\\\"admin\\\"}\""), + attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), attribute.String(tracing.MCPToolName, "coder_list_workspaces"), attribute.String(tracing.MCPServerName, tool.ServerName), attribute.String(tracing.MCPServerURL, tool.ServerURL), From 3b3d54058c2d0884683bbecd83b5705e255959bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 4 Dec 2025 12:46:44 +0000 Subject: [PATCH 13/16] fix bedrock streaming tests --- bridge_integration_test.go | 2 +- trace_integration_test.go | 78 ++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 47 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 5f957e8..e2b1cc2 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1555,7 +1555,7 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp var reqMsg msg require.NoError(t, json.Unmarshal(body, &reqMsg)) - if !reqMsg.Stream { + if !reqMsg.Stream && !strings.HasSuffix(r.URL.Path, "invoke-with-response-stream") { resp := files[fixtureNonStreamingResponse] if responseMutatorFn != nil { resp = responseMutatorFn(ms.callCount.Load(), resp) diff --git a/trace_integration_test.go b/trace_integration_test.go index c23f3ef..7f678a8 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -45,6 +45,18 @@ func TestTraceAnthropic(t *testing.T) { {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, } + expectStreaming := []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 2, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + cases := []struct { name string streaming bool @@ -63,31 +75,13 @@ func TestTraceAnthropic(t *testing.T) { { name: "trace_anthr_streaming", streaming: true, - expect: []expectTrace{ - {"Intercept", 1, codes.Unset}, - {"Intercept.CreateInterceptor", 1, codes.Unset}, - {"Intercept.RecordInterception", 1, codes.Unset}, - {"Intercept.ProcessRequest", 1, codes.Unset}, - {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.RecordTokenUsage", 2, codes.Unset}, - {"Intercept.RecordToolUsage", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, - }, + expect: expectStreaming, }, { name: "trace_bedrock_streaming", streaming: true, bedrock: true, - expect: []expectTrace{ - {"Intercept", 1, codes.Unset}, - {"Intercept.CreateInterceptor", 1, codes.Unset}, - {"Intercept.RecordInterception", 1, codes.Unset}, - {"Intercept.ProcessRequest", 1, codes.Unset}, - {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, - }, + expect: expectStreaming, }, } @@ -170,6 +164,17 @@ func TestTraceAnthropicErr(t *testing.T) { {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, } + expectStreaming := []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + cases := []struct { name string streaming bool @@ -183,16 +188,7 @@ func TestTraceAnthropicErr(t *testing.T) { { name: "anthr_streaming_err", streaming: true, - expect: []expectTrace{ - {"Intercept", 1, codes.Error}, - {"Intercept.CreateInterceptor", 1, codes.Unset}, - {"Intercept.RecordInterception", 1, codes.Unset}, - {"Intercept.ProcessRequest", 1, codes.Error}, - {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.RecordTokenUsage", 1, codes.Unset}, - {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, - }, + expect: expectStreaming, }, { name: "bedrock_non_streaming_err", @@ -203,16 +199,7 @@ func TestTraceAnthropicErr(t *testing.T) { name: "bedrock_streaming_err", streaming: true, bedrock: true, - expect: []expectTrace{ - // RecordTokenUsage missing? - {"Intercept", 1, codes.Unset}, // TODO check why this is unset not Error - {"Intercept.CreateInterceptor", 1, codes.Unset}, - {"Intercept.RecordInterception", 1, codes.Unset}, - {"Intercept.ProcessRequest", 1, codes.Unset}, // TODO check why this is unset not Error - {"Intercept.RecordPromptUsage", 1, codes.Unset}, - {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, - {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, - }, + expect: expectStreaming, }, } @@ -323,12 +310,11 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { streaming: false, bedrock: true, }, - // TODO check why it fails - // { - // name: "bedrock_streaming", - // streaming: true, - // bedrock: true, - // }, + { + name: "bedrock_streaming", + streaming: true, + bedrock: true, + }, } for _, tc := range tests { From feded7a495c7bd4f66fb45621060400a679735db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 4 Dec 2025 13:36:12 +0000 Subject: [PATCH 14/16] make argument order consistent --- bridge.go | 4 +-- bridge_integration_test.go | 38 +++++++++++------------ intercept_anthropic_messages_blocking.go | 2 +- intercept_anthropic_messages_streaming.go | 2 +- intercept_openai_chat_base.go | 2 +- intercept_openai_chat_blocking.go | 2 +- intercept_openai_chat_streaming.go | 2 +- interception.go | 4 +-- mcp/mcp_test.go | 4 +-- mcp/proxy_streamable_http.go | 2 +- mcp/tool.go | 2 +- metrics_integration_test.go | 4 +-- provider.go | 2 +- provider_anthropic.go | 2 +- provider_openai.go | 2 +- trace_integration_test.go | 6 ++-- 16 files changed, 40 insertions(+), 40 deletions(-) diff --git a/bridge.go b/bridge.go index 0f723ea..9f2c424 100644 --- a/bridge.go +++ b/bridge.go @@ -48,13 +48,13 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. -func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) { +func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. diff --git a/bridge_integration_test.go b/bridge_integration_test.go index e2b1cc2..40c44dd 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -137,7 +137,7 @@ func TestAnthropicMessages(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -218,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -316,7 +316,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -404,7 +404,7 @@ func TestOpenAIChatCompletions(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -472,7 +472,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -511,7 +511,7 @@ func TestSimple(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -643,7 +643,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -654,7 +654,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -733,7 +733,7 @@ func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) map[string] t.Cleanup(mcpSrv.Close) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) // Initialize MCP client, fetch tools, and inject into bridge @@ -761,7 +761,7 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) } // Build the requirements & make the assertions which are common to all providers. @@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) } // Build the requirements & make the assertions which are common to all providers. @@ -1027,7 +1027,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1045,7 +1045,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1134,7 +1134,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1153,7 +1153,7 @@ func TestErrorHandling(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1241,7 +1241,7 @@ func TestStableRequestEncoding(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, }, { @@ -1250,7 +1250,7 @@ func TestStableRequestEncoding(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, }, } @@ -1357,7 +1357,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1371,7 +1371,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index d3d28fd..a37f724 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -177,7 +177,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr continue } - res, err := tool.Call(ctx, i.tracer, tc.Input) + res, err := tool.Call(ctx, tc.Input, i.tracer) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 4b51d23..9a90d3b 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -282,7 +282,7 @@ newStream: continue } - res, err := tool.Call(streamCtx, i.tracer, input) + res, err := tool.Call(streamCtx, input, i.tracer) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 8fb6c73..bb7c31e 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -24,8 +24,8 @@ type OpenAIChatInterceptionBase struct { baseURL string key string - tracer trace.Tracer logger slog.Logger + tracer trace.Tracer recorder Recorder mcpProxy mcp.ServerProxier diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 808b361..2666367 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -151,7 +151,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } args := i.unmarshalArgs(tc.Function.Arguments) - res, err := tool.Call(ctx, i.tracer, args) + res, err := tool.Call(ctx, args, i.tracer) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index f6391c9..51cc624 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -242,7 +242,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, id := toolCall.ID args := i.unmarshalArgs(toolCall.Arguments) - toolRes, toolErr := tool.Call(streamCtx, i.tracer, args) + toolRes, toolErr := tool.Call(streamCtx, args, i.tracer) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), diff --git a/interception.go b/interception.go index 5f0948f..46ec7bd 100644 --- a/interception.go +++ b/interception.go @@ -40,12 +40,12 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() - interceptor, err := p.CreateInterceptor(tracer, w, r.WithContext(ctx)) + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) logger.Warn(ctx, "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 0ffef1d..fa79456 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -309,9 +309,9 @@ func TestToolInjectionOrder(t *testing.T) { tracer := otel.Tracer("forTesting") // When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces. - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) - proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "shmoder", mcpSrv.URL, nil, nil, nil) + proxy2, err := mcp.NewStreamableHTTPServerProxy("shmoder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) // Then: initialize both proxies. diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 4a7eaca..c32da1d 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -32,7 +32,7 @@ type StreamableHTTPServerProxy struct { tools map[string]*Tool } -func NewStreamableHTTPServerProxy(logger slog.Logger, tracer trace.Tracer, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) { +func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) { var opts []transport.StreamableHTTPCOption if headers != nil { opts = append(opts, transport.WithHTTPHeaders(headers)) diff --git a/mcp/tool.go b/mcp/tool.go index b856cc4..d5314a5 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -40,7 +40,7 @@ type Tool struct { Logger slog.Logger } -func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp.CallToolResult, outErr error) { +func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp.CallToolResult, outErr error) { if t == nil { return nil, errors.New("nil tool") } diff --git a/metrics_integration_test.go b/metrics_integration_test.go index cf3201e..6058cb2 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -241,7 +241,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { mcpMgr := mcp.NewServerProxyManager(tools, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, testTracer, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -283,7 +283,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m } wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), metrics, tracer, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) diff --git a/provider.go b/provider.go index 78bd771..20f8f52 100644 --- a/provider.go +++ b/provider.go @@ -16,7 +16,7 @@ type Provider interface { // CreateInterceptor starts a new [Interceptor] which is responsible for intercepting requests, // communicating with the upstream provider and formulating a response to be sent to the requesting client. - CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (Interceptor, error) + CreateInterceptor(http.ResponseWriter, *http.Request, trace.Tracer) (Interceptor, error) // BridgedRoutes returns a slice of [http.ServeMux]-compatible routes which will have special handling. // See https://pkg.go.dev/net/http#hdr-Patterns-ServeMux. diff --git a/provider_anthropic.go b/provider_anthropic.go index d96e521..fb5d10b 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -61,7 +61,7 @@ func (p *AnthropicProvider) PassthroughRoutes() []string { } } -func (p *AnthropicProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { +func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ Interceptor, outErr error) { id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) diff --git a/provider_openai.go b/provider_openai.go index 8bf05cf..68777e7 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -61,7 +61,7 @@ func (p *OpenAIProvider) PassthroughRoutes() []string { } } -func (p *OpenAIProvider) CreateInterceptor(tracer trace.Tracer, w http.ResponseWriter, r *http.Request) (_ Interceptor, outErr error) { +func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ Interceptor, outErr error) { id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") diff --git a/trace_integration_test.go b/trace_integration_test.go index 7f678a8..b5c3520 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -333,7 +333,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { bedrockCfg = testBedrockCfg(addr) } providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), bedrockCfg)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, tracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) } var reqBody string @@ -594,7 +594,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, tracer, logger) + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) } var reqBody string @@ -692,7 +692,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) { t.Cleanup(mcpSrv.Close) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, serverName, mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy(serverName, mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) tools := map[string]mcp.ServerProxier{"unusedValue": proxy} From 6f1bc8c5ad117c502f507a430c88e6e2354bf3ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 4 Dec 2025 16:03:46 +0000 Subject: [PATCH 15/16] review 2 --- bridge_integration_test.go | 8 ++++---- go.mod | 1 - intercept_anthropic_messages_blocking.go | 4 ++-- intercept_openai_chat_blocking.go | 4 ++-- mcp/tool.go | 4 ++-- trace_integration_test.go | 23 +++++++++++------------ tracing/tracing.go | 11 +++++++++++ 7 files changed, 32 insertions(+), 23 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index ee9c4f0..42fdd45 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -974,11 +974,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} - // Setup MCP tools. - tools := setupMCPServerProxiesForTest(t, testTracer) + // Setup MCP proxies. + proxies := setupMCPServerProxiesForTest(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools, testTracer) + mcpMgr := mcp.NewServerProxyManager(proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1005,7 +1005,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, tools, resp + return recorderClient, proxies, resp } func TestErrorHandling(t *testing.T) { diff --git a/go.mod b/go.mod index cfe3a4a..f49f41f 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( ) require ( - github.com/google/go-cmp v0.7.0 go.opentelemetry.io/otel v1.38.0 go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 0541df9..9de51ad 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -88,7 +88,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - resp, err = i.traceNewMessage(ctx, svc, messages) // traces svc.New(ctx, msgParams) call + resp, err = i.newMessage(ctx, svc, messages) if err != nil { if isConnError(err) { // Can't write a response, just error out. @@ -298,7 +298,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return nil } -func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { +func (i *AnthropicMessagesBlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 2666367..3b4df34 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -76,7 +76,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = i.traceChatCompletionsNew(ctx, svc, opts) // traces svc.New(ctx, i.req.ChatCompletionNewParams, opts...) call + completion, err = i.newChatCompletion(ctx, svc, opts) if err != nil { break } @@ -233,7 +233,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return nil } -func (i *OpenAIBlockingChatInterception) traceChatCompletionsNew(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { +func (i *OpenAIBlockingChatInterception) newChatCompletion(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) diff --git a/mcp/tool.go b/mcp/tool.go index d5314a5..1bbb053 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -15,7 +15,7 @@ import ( ) const ( - maxSpanInputAttrLen = 100 + maxSpanInputAttrLen = 100 // truncates tool.Call span input attribute to first `maxSpanInputAttrLen` letters injectedToolPrefix = "bmcp" // "bridged MCP" injectedToolDelimiter = "_" ) @@ -63,7 +63,7 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp } else { strJson := string(inputJson) if len(strJson) > maxSpanInputAttrLen { - strJson = strJson[:100] + strJson = strJson[:maxSpanInputAttrLen] } span.SetAttributes(attribute.String(tracing.MCPInput, strJson)) } diff --git a/trace_integration_test.go b/trace_integration_test.go index b5c3520..9067ed8 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "slices" + "strings" "testing" "time" @@ -13,8 +15,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/tracing" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -670,13 +670,12 @@ func TestTracePassthrough(t *testing.T) { require.Len(t, spans, 1) assert.Equal(t, spans[0].Name(), "Passthrough") - attrs := []attribute.KeyValue{ - attribute.String(tracing.PassthroughURL, "/v1/models"), + want := []attribute.KeyValue{ attribute.String(tracing.PassthroughMethod, "GET"), + attribute.String(tracing.PassthroughURL, "/v1/models"), } - if attrDiff := cmp.Diff(spans[0].Attributes(), attrs, cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { - t.Errorf("unexpectet attrs diff: %s", attrDiff) - } + got := slices.SortedFunc(slices.Values(spans[0].Attributes()), cmpAttrKeyVal) + require.Equal(t, want, got) } func TestNewServerProxyManagerTraces(t *testing.T) { @@ -714,8 +713,8 @@ func TestNewServerProxyManagerTraces(t *testing.T) { verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) } -func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) bool { - return a.Key < b.Key +func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) int { + return strings.Compare(string(a.Key), string(b.Key)) } // checks counts of traces with given name, status and attributes @@ -729,9 +728,9 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e continue } found++ - if attrDiff := cmp.Diff(s.Attributes(), attrs, cmpopts.EquateEmpty(), cmpopts.EquateComparable(attribute.KeyValue{}), cmpopts.SortSlices(cmpAttrKeyVal)); attrDiff != "" { - t.Errorf("unexpectet attrs for span named: %v, diff: %s", e.name, attrDiff) - } + want := slices.SortedFunc(slices.Values(attrs), cmpAttrKeyVal) + got := slices.SortedFunc(slices.Values(s.Attributes()), cmpAttrKeyVal) + require.Equal(t, want, got) assert.Equalf(t, e.status, s.Status().Code, "unexpected status for trace naned: %v got: %v want: %v", e.name, s.Status().Code, e.status) } if found != e.count { diff --git a/tracing/tracing.go b/tracing/tracing.go index 0074d44..aef819b 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -63,6 +63,17 @@ func RequestBridgeAttributesFromContext(ctx context.Context) []attribute.KeyValu return attrs } +// EndSpanErr ends given span and sets Error status if error is not nil +// uses pointer to error because defer evaluates function arguments +// when defer statement is executed not when deferred function is called +// +// example usage: +// +// func Example() (result any, outErr error) { +// _, span := tracer.Start(...) +// defer tracing.EndSpanErr(span, &outErr) +// +// } func EndSpanErr(span trace.Span, err *error) { if span == nil { return From ea859c0b5847e275d7654c56a93f1e4a923e8739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 5 Dec 2025 10:30:12 +0000 Subject: [PATCH 16/16] go.mod comment --- go.mod | 1 + 1 file changed, 1 insertion(+) diff --git a/go.mod b/go.mod index f49f41f..9a62089 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/openai/openai-go/v2 v2.7.0 ) +// Tracing-related libs. require ( go.opentelemetry.io/otel v1.38.0 go.opentelemetry.io/otel/sdk v1.38.0