diff --git a/bridge.go b/bridge.go index 4f5428d..9f2c424 100644 --- a/bridge.go +++ b/bridge.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" + "go.opentelemetry.io/otel/trace" "github.com/hashicorp/go-multierror" ) @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. -func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) { +func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. // // We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be // configured, so we should just reverse-proxy known-safe routes. 
- ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics) + ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer) for _, path := range provider.PassthroughRoutes() { prefix := fmt.Sprintf("/%s", provider.Name()) route := fmt.Sprintf("%s%s", prefix, path) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 73f77a5..b4ea460 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -20,23 +20,23 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" - "go.uber.org/goleak" - "golang.org/x/tools/txtar" - "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/openai/openai-go/v2" + oaissestream "github.com/openai/openai-go/v2/packages/ssestream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - - "github.com/openai/openai-go/v2" - oaissestream "github.com/openai/openai-go/v2/packages/ssestream" - - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "go.uber.org/goleak" + "golang.org/x/tools/txtar" ) var ( @@ -65,6 +65,8 @@ var ( oaiMidStreamErr []byte //go:embed fixtures/openai/non_stream_error.txtar oaiNonStreamErr []byte + + testTracer = otel.Tracer("forTesting") ) const ( @@ -90,8 +92,9 @@ func TestAnthropicMessages(t *testing.T) { t.Parallel() cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int + streaming bool + expectedInputTokens int + expectedOutputTokens int }{ { streaming: true, @@ -133,7 +136,8 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := 
&mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)} + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -214,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -312,7 +316,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil), nil, logger) + recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -399,7 +403,8 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}, recorderClient, 
mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))} + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -466,7 +471,8 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -504,7 +510,8 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -618,17 +625,8 @@ func TestSimple(t *testing.T) { } func setJSON(in []byte, key string, val bool) ([]byte, error) { - 
var body map[string]any - err := json.Unmarshal(in, &body) - if err != nil { - return nil, err - } - body[key] = val - out, err := json.Marshal(body) - if err != nil { - return nil, err - } - return out, nil + out, err := sjson.Set(string(in), key, val) + return []byte(out), err } func TestFallthrough(t *testing.T) { @@ -645,7 +643,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -656,7 +654,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil), nil, logger) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -727,7 +725,7 @@ func TestFallthrough(t *testing.T) { } // setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T) (map[string]mcp.ServerProxier, *callAccumulator) { +func setupMCPServerProxiesForTest(t *testing.T, 
tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) { t.Helper() // Setup Coder MCP integration @@ -736,7 +734,7 @@ func setupMCPServerProxiesForTest(t *testing.T) (map[string]mcp.ServerProxier, * t.Cleanup(mcpSrv.Close) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) // Initialize MCP client, fetch tools, and inject into bridge @@ -763,11 +761,12 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. 
require.Len(t, recorderClient.toolUsages, 1) @@ -849,11 +848,12 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -943,7 +943,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. // Kinda fugly right now, we can refactor this later. 
-func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -989,10 +989,10 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} // Setup MCP mcpProxiers. - mcpProxiers, acc := setupMCPServerProxiesForTest(t) + mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1019,7 +1019,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, acc, resp + return recorderClient, acc, mcpProxiers, resp } func TestErrorHandling(t *testing.T) { @@ -1040,7 +1040,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return 
aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1057,7 +1058,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1106,7 +1108,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. 
@@ -1145,7 +1147,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1163,7 +1166,8 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1206,7 +1210,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. 
@@ -1221,6 +1225,7 @@ func TestErrorHandling(t *testing.T) { resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) + bridgeSrv.Close() tc.responseHandlerFn(resp) recorderClient.verifyAllInterceptionsEnded(t) @@ -1249,7 +1254,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: antSimple, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, }, { @@ -1257,7 +1263,8 @@ func TestStableRequestEncoding(t *testing.T) { fixture: oaiSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, srvProxyMgr, nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, }, } @@ -1270,10 +1277,10 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t) + mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) // Configure the bridge with injected tools. 
- mcpMgr := mcp.NewServerProxyManager(mcpProxiers) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) @@ -1363,7 +1370,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1376,7 +1384,8 @@ func TestEnvironmentDoNotLeak(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, client, mcp.NewServerProxyManager(nil), nil, logger) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ @@ -1560,7 +1569,7 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp var reqMsg msg require.NoError(t, json.Unmarshal(body, &reqMsg)) - if !reqMsg.Stream { + if !reqMsg.Stream && !strings.HasSuffix(r.URL.Path, "invoke-with-response-stream") { resp := files[fixtureNonStreamingResponse] if 
responseMutatorFn != nil { resp = responseMutatorFn(ms.callCount.Load(), resp) diff --git a/go.mod b/go.mod index 47fd45d..9a62089 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,13 @@ require ( github.com/openai/openai-go/v2 v2.7.0 ) +// Tracing-related libs. +require ( + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/sdk v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 +) + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect @@ -46,6 +53,8 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/lipgloss v0.7.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect @@ -61,14 +70,13 @@ require ( github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.opentelemetry.io/otel v1.33.0 // indirect - go.opentelemetry.io/otel/trace v1.33.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.34.0 // indirect diff --git a/go.sum b/go.sum index d0b79c8..385345d 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -130,14 +131,16 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= -go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= -go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= -go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= -go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE= -go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= -go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= -go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= 
+go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 5049e54..9fdb4c7 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -15,7 +15,10 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -27,6 +30,7 @@ type AnthropicMessagesInterceptionBase struct { cfg AnthropicConfig bedrockCfg *AWSBedrockConfig + tracer trace.Tracer logger slog.Logger recorder Recorder @@ -59,6 +63,18 @@ func (i *AnthropicMessagesInterceptionBase) Model() string { return string(i.req.Model) } +func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r 
*http.Request, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(tracing.RequestPath, r.URL.Path), + attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id), + attribute.String(tracing.Provider, ProviderAnthropic), + attribute.String(tracing.Model, s.Model()), + attribute.Bool(tracing.Streaming, streaming), + attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil), + } +} + func (i *AnthropicMessagesInterceptionBase) injectTools() { if i.req == nil || i.mcpProxy == nil { return diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 236bd89..9de51ad 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "fmt" "net/http" "time" @@ -10,8 +11,11 @@ import ( "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" // TODO: abstract this away so callers need no knowledge of underlying lib. 
"github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "cdr.dev/slog" ) @@ -22,29 +26,35 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } -func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { - s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) { + i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +} + +func (i *AnthropicMessagesBlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, false) } func (s *AnthropicMessagesBlockingInterception) Streaming() bool { return false } -func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", 
trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) i.injectTools() @@ -77,7 +87,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr var cumulativeUsage anthropic.Usage for { - resp, err = svc.New(ctx, messages) + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) + resp, err = i.newMessage(ctx, svc, messages) if err != nil { if isConnError(err) { // Can't write a response, just error out. @@ -166,7 +177,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr continue } - res, err := tool.Call(ctx, tc.Input) + res, err := tool.Call(ctx, tc.Input, i.tracer) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -286,3 +297,10 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr return nil } + +func (i *AnthropicMessagesBlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + + return svc.New(ctx, msgParams) +} diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 5f2393a..7263bce 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -11,11 +11,15 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" mcplib "github.com/mark3labs/mcp-go/mcp" "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" 
"cdr.dev/slog" ) @@ -26,12 +30,13 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, cfg: cfg, bedrockCfg: bedrockCfg, + tracer: tracer, }} } @@ -43,6 +48,10 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { return true } +func (s *AnthropicMessagesStreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, true) +} + // ProcessRequest handles a request to /v1/messages. // This API has a state-machine behind it, which is described in https://docs.claude.com/en/docs/build-with-claude/streaming#event-types. // @@ -62,13 +71,16 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. 
-func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) + // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. @@ -118,12 +130,13 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW isFirst := true newStream: for { + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) if err := streamCtx.Err(); err != nil { lastErr = fmt.Errorf("stream exit: %w", err) break } - stream := svc.NewStreaming(streamCtx, messages) + stream := i.newStream(streamCtx, svc, messages) var message anthropic.Message var lastToolName string @@ -270,7 +283,7 @@ newStream: continue } - res, err := tool.Call(streamCtx, input) + res, err := tool.Call(streamCtx, input, i.tracer) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), @@ -522,3 +535,11 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte, buf.WriteString("\n\n") return buf.Bytes() } + +// newStream traces svc.NewStreaming(streamCtx, messages) +func (s *AnthropicMessagesStreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { + _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return 
svc.NewStreaming(ctx, messages) +} diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 20db323..bb7c31e 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -7,20 +7,25 @@ import ( "strings" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" "github.com/openai/openai-go/v2/shared" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) type OpenAIChatInterceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper + id uuid.UUID + req *ChatCompletionNewParamsWrapper + baseURL string + key string - baseURL, key string - logger slog.Logger + logger slog.Logger + tracer trace.Tracer recorder Recorder mcpProxy mcp.ServerProxier @@ -42,6 +47,17 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder i.mcpProxy = mcpProxy } +func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(tracing.RequestPath, r.URL.Path), + attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id), + attribute.String(tracing.Provider, ProviderOpenAI), + attribute.String(tracing.Model, s.Model()), + attribute.Bool(tracing.Streaming, streaming), + } +} + func (i *OpenAIChatInterceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index d1fa255..3b4df34 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -1,6 +1,7 @@ package aibridge import ( + "context" "encoding/json" "fmt" "net/http" @@ -8,9 +9,12 @@ import ( "time" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" 
"github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -21,12 +25,13 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -38,12 +43,18 @@ func (s *OpenAIBlockingChatInterception) Streaming() bool { return false } -func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (s *OpenAIBlockingChatInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(r, false) +} + +func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } - ctx := r.Context() + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) + svc := i.newCompletionsService(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) @@ -61,10 +72,11 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } for { + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout - completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...) 
+ completion, err = i.newChatCompletion(ctx, svc, opts) if err != nil { break } @@ -139,8 +151,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } args := i.unmarshalArgs(tc.Function.Arguments) - res, err := tool.Call(ctx, args) - + res, err := tool.Call(ctx, args, i.tracer) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, @@ -221,3 +232,10 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r return nil } + +func (i *OpenAIBlockingChatInterception) newChatCompletion(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) { + ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + + return svc.New(ctx, i.req.ChatCompletionNewParams, opts...) +} diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index ccabb35..51cc624 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -11,10 +11,13 @@ import ( "time" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/ssestream" "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "cdr.dev/slog" ) @@ -25,12 +28,13 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: 
OpenAIChatInterceptionBase{ id: id, req: req, baseURL: baseURL, key: key, + tracer: tracer, }} } @@ -42,6 +46,10 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { return true } +func (s *OpenAIStreamingChatInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return s.OpenAIChatInterceptionBase.baseTraceAttributes(r, true) +} + // ProcessRequest handles a request to /v1/chat/completions. // See https://platform.openai.com/docs/api-reference/chat-streaming/streaming. // @@ -54,18 +62,21 @@ func (i *OpenAIStreamingChatInterception) Streaming() bool { // b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. -func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error { +func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { return fmt.Errorf("developer error: req is nil") } + ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) + defer tracing.EndSpanErr(span, &outErr) + // Include token usage. i.req.StreamOptions.IncludeUsage = openai.Bool(true) i.injectTools() // Allow us to interrupt watch via cancel. - ctx, cancel := context.WithCancel(r.Context()) + ctx, cancel := context.WithCancel(ctx) defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. 
@@ -104,7 +115,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, interceptionErr error ) for { - stream = svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) + // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) + stream = i.newStream(streamCtx, svc) processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName) var toolCall *openai.FinishedChatCompletionToolCall @@ -230,8 +242,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, id := toolCall.ID args := i.unmarshalArgs(toolCall.Arguments) - toolRes, toolErr := tool.Call(streamCtx, args) - + toolRes, toolErr := tool.Call(streamCtx, args, i.tracer) _ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), @@ -336,6 +347,14 @@ func (i *OpenAIStreamingChatInterception) encodeForStream(payload []byte) []byte return buf.Bytes() } +// newStream traces svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) call +func (i *OpenAIStreamingChatInterception) newStream(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer span.End() + + return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams) +} + type openAIStreamProcessor struct { ctx context.Context logger slog.Logger diff --git a/interception.go b/interception.go index 8210c41..46ec7bd 100644 --- a/interception.go +++ b/interception.go @@ -9,7 +9,11 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) // Interceptor describes a (potentially) stateful interaction with an AI provider. 
@@ -25,6 +29,8 @@ type Interceptor interface { ProcessRequest(w http.ResponseWriter, r *http.Request) error // Specifies whether an interceptor handles streaming or not. Streaming() bool + // TraceAttributes returns tracing attributes for this [Interceptor] + TraceAttributes(*http.Request) []attribute.KeyValue } var UnknownRoute = errors.New("unknown route") @@ -34,11 +40,15 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - interceptor, err := p.CreateInterceptor(w, r) + ctx, span := tracer.Start(r.Context(), "Intercept") + defer span.End() + + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { - logger.Warn(r.Context(), "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) + logger.Warn(ctx, "failed to create interceptor", slog.Error(err), slog.F("path", r.URL.Path)) http.Error(w, fmt.Sprintf("failed to create %q interceptor", r.URL.Path), http.StatusInternalServerError) return } @@ -50,13 +60,18 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, }() } - actor := actorFromContext(r.Context()) + actor := actorFromContext(ctx) if actor == nil { - logger.Warn(r.Context(), "no actor found in context") + logger.Warn(ctx, "no actor found in context") http.Error(w, "no actor found", http.StatusBadRequest) return } + traceAttrs := 
interceptor.TraceAttributes(r) + span.SetAttributes(traceAttrs...) + ctx = tracing.WithInterceptionAttributesInContext(ctx, traceAttrs) + r = r.WithContext(ctx) + // Record usage in the background to not block request flow. asyncRecorder := NewAsyncRecorder(logger, recorder, recordingTimeout) asyncRecorder.WithMetrics(metrics) @@ -65,14 +80,15 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, asyncRecorder.WithInitiatorID(actor.id) interceptor.Setup(logger, asyncRecorder, mcpProxy) - if err := recorder.RecordInterception(r.Context(), &InterceptionRecord{ + if err := recorder.RecordInterception(ctx, &InterceptionRecord{ ID: interceptor.ID().String(), Metadata: actor.metadata, InitiatorID: actor.id, Provider: p.Name(), Model: interceptor.Model(), }); err != nil { - logger.Warn(r.Context(), "failed to record interception", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) + logger.Warn(ctx, "failed to record interception", slog.Error(err)) http.Error(w, "failed to record interception", http.StatusInternalServerError) return } @@ -86,27 +102,27 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, slog.F("streaming", interceptor.Streaming()), ) - log.Debug(r.Context(), "interception started") + log.Debug(ctx, "interception started") if metrics != nil { metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) + defer func() { + metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + }() } if err := interceptor.ProcessRequest(w, r); err != nil { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusFailed, route, r.Method, actor.id).Add(1) } - log.Warn(r.Context(), "interception failed", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) + log.Warn(ctx, "interception failed", 
slog.Error(err)) } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } - log.Debug(r.Context(), "interception ended") - } - asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()}) - - if metrics != nil { - metrics.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + log.Debug(ctx, "interception ended") } + asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4d6e2d3..fa79456 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -13,6 +13,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "go.opentelemetry.io/otel" "go.uber.org/goleak" "github.com/coder/aibridge/mcp" @@ -306,10 +307,11 @@ func TestToolInjectionOrder(t *testing.T) { mcpSrv := httptest.NewServer(createMockMCPSrv(t)) t.Cleanup(mcpSrv.Close) + tracer := otel.Tracer("forTesting") // When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces. - proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) - proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil) + proxy2, err := mcp.NewStreamableHTTPServerProxy("shmoder", mcpSrv.URL, nil, nil, nil, logger, tracer) require.NoError(t, err) // Then: initialize both proxies. 
@@ -324,7 +326,7 @@ func TestToolInjectionOrder(t *testing.T) { mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{ "coder": proxy, "shmoder": proxy2, - }) + }, otel.GetTracerProvider().Tracer("test")) require.NoError(t, mgr.Init(ctx)) // Then: the tools from both servers should be collectively sorted stably. diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 5a5c092..c32da1d 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -8,25 +8,31 @@ import ( "strings" "cdr.dev/slog" + "github.com/coder/aibridge/tracing" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "golang.org/x/exp/maps" ) var _ ServerProxier = &StreamableHTTPServerProxy{} type StreamableHTTPServerProxy struct { + client *client.Client + logger slog.Logger + tracer trace.Tracer + + allowlistPattern *regexp.Regexp + denylistPattern *regexp.Regexp + serverName string serverURL string - client *client.Client - logger slog.Logger tools map[string]*Tool - - allowlistPattern, denylistPattern *regexp.Regexp } -func NewStreamableHTTPServerProxy(logger slog.Logger, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) { +func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) { var opts []transport.StreamableHTTPCOption if headers != nil { opts = append(opts, transport.WithHTTPHeaders(headers)) @@ -42,6 +48,7 @@ func NewStreamableHTTPServerProxy(logger slog.Logger, serverName, serverURL stri serverURL: serverURL, client: mcpClient, logger: logger, + tracer: tracer, allowlistPattern: allowlist, denylistPattern: denylist, }, nil @@ -51,7 +58,10 @@ func (p *StreamableHTTPServerProxy) 
Name() string { return p.serverName } -func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error { +func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { + ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init", trace.WithAttributes(p.traceAttributes()...)) + defer tracing.EndSpanErr(span, &outErr) + if err := p.client.Start(ctx); err != nil { return fmt.Errorf("start client: %w", err) } @@ -122,7 +132,10 @@ func (p *StreamableHTTPServerProxy) CallTool(ctx context.Context, name string, i }) } -func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (map[string]*Tool, error) { +func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[string]*Tool, outErr error) { + ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init.fetchTools", trace.WithAttributes(p.traceAttributes()...)) + defer tracing.EndSpanErr(span, &outErr) + tools, err := p.client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { return nil, fmt.Errorf("list MCP tools: %w", err) @@ -140,8 +153,10 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (map[string] Description: tool.Description, Params: tool.InputSchema.Properties, Required: tool.InputSchema.Required, + Logger: p.logger, } } + span.SetAttributes(append(p.traceAttributes(), attribute.Int(tracing.MCPToolCount, len(out)))...) return out, nil } @@ -154,3 +169,11 @@ func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error { // it has an internal timeout of 5s, though. 
return p.client.Close() } + +func (p *StreamableHTTPServerProxy) traceAttributes() []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(tracing.MCPProxyName, p.Name()), + attribute.String(tracing.MCPServerName, p.serverName), + attribute.String(tracing.MCPServerURL, p.serverURL), + } +} diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index 732f1a0..01c8790 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -7,8 +7,10 @@ import ( "strings" "sync" + "github.com/coder/aibridge/tracing" "github.com/coder/aibridge/utils" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel/trace" ) var _ ServerProxier = &ServerProxyManager{} @@ -18,14 +20,18 @@ var _ ServerProxier = &ServerProxyManager{} // for the purpose of injection into bridged requests and invocation. type ServerProxyManager struct { proxiers map[string]ServerProxier + tracer trace.Tracer // Protects access to the tools map. toolsMu sync.RWMutex tools map[string]*Tool } -func NewServerProxyManager(proxiers map[string]ServerProxier) *ServerProxyManager { - return &ServerProxyManager{proxiers: proxiers} +func NewServerProxyManager(proxiers map[string]ServerProxier, tracer trace.Tracer) *ServerProxyManager { + return &ServerProxyManager{ + proxiers: proxiers, + tracer: tracer, + } } func (s *ServerProxyManager) addTools(tools []*Tool) { @@ -42,7 +48,10 @@ func (s *ServerProxyManager) addTools(tools []*Tool) { } // Init concurrently initializes all of its [ServerProxier]s. 
-func (s *ServerProxyManager) Init(ctx context.Context) error {
+func (s *ServerProxyManager) Init(ctx context.Context) (outErr error) {
+	ctx, span := s.tracer.Start(ctx, "ServerProxyManager.Init")
+	defer tracing.EndSpanErr(span, &outErr)
+
 	cg := utils.NewConcurrentGroup()
 
 	for _, proxy := range s.proxiers {
 		cg.Go(func() error {
diff --git a/mcp/tool.go b/mcp/tool.go
index 2c01535..1bbb053 100644
--- a/mcp/tool.go
+++ b/mcp/tool.go
@@ -2,15 +2,20 @@ package mcp
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"regexp"
 	"strings"
 
 	"cdr.dev/slog"
+	"github.com/coder/aibridge/tracing"
 	"github.com/mark3labs/mcp-go/mcp"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/trace"
 )
 
 const (
+	maxSpanInputAttrLen   = 100    // truncates tool.Call span input attribute to its first `maxSpanInputAttrLen` bytes
 	injectedToolPrefix    = "bmcp" // "bridged MCP"
 	injectedToolDelimiter = "_"
 )
@@ -32,14 +37,35 @@ type Tool struct {
 	Description string
 	Params      map[string]any
 	Required    []string
+	Logger      slog.Logger
 }
 
-func (t *Tool) Call(ctx context.Context, input any) (*mcp.CallToolResult, error) {
+func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp.CallToolResult, outErr error) {
 	if t == nil {
-		return nil, errors.New("nil tool!")
+		return nil, errors.New("nil tool")
 	}
 	if t.Client == nil {
-		return nil, errors.New("nil client!")
+		return nil, errors.New("nil client")
+	}
+
+	spanAttrs := append(
+		tracing.InterceptionAttributesFromContext(ctx),
+		attribute.String(tracing.MCPToolName, t.Name),
+		attribute.String(tracing.MCPServerName, t.ServerName),
+		attribute.String(tracing.MCPServerURL, t.ServerURL),
+	)
+	ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...))
+	defer tracing.EndSpanErr(span, &outErr)
+
+	inputJson, err := json.Marshal(input)
+	if err != nil {
+		t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs", slog.Error(err))
+	} else {
+		strJson := string(inputJson)
+		if 
len(strJson) > maxSpanInputAttrLen { + strJson = strJson[:maxSpanInputAttrLen] + } + span.SetAttributes(attribute.String(tracing.MCPInput, strJson)) } return t.Client.CallTool(ctx, mcp.CallToolRequest{ diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 8e23ead..f326dec 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" "golang.org/x/tools/txtar" ) @@ -48,7 +49,7 @@ func TestMetrics_Interception(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -89,7 +90,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv := newTestSrv(t, ctx, provider, metrics) + bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) // Make request in background. 
doneCh := make(chan struct{}) @@ -141,7 +142,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) - srv := newTestSrv(t, t.Context(), provider, metrics) + srv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) require.NoError(t, err) @@ -170,7 +171,7 @@ func TestMetrics_PromptCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -198,7 +199,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) - srv := newTestSrv(t, ctx, provider, metrics) + srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) resp, err := http.DefaultClient.Do(req) @@ -236,11 +237,11 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. 
- mcpProxiers, _ := setupMCPServerProxiesForTest(t) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers) + mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -272,13 +273,17 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { require.Equal(t, 1.0, count) } -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics) *httptest.Server { +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *aibridge.Metrics, tracer trace.Tracer) (*httptest.Server, *mockRecorderClient) { t.Helper() - recorder := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + mockRecorder := &mockRecorderClient{} + clientFn := func() (aibridge.Recorder, error) { + return mockRecorder, nil + } + wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcp.NewServerProxyManager(nil), metrics, logger) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -288,5 +293,5 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m srv.Start() t.Cleanup(srv.Close) - return srv + return srv, mockRecorder } diff --git a/passthrough.go b/passthrough.go index 6788672..3ce1631 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,19 +8,28 @@ import ( "time" "cdr.dev/slog" + 
"github.com/coder/aibridge/tracing" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [Provider]. -func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics) http.HandlerFunc { +func newPassthroughRouter(provider Provider, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if metrics != nil { metrics.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1) } + ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( + attribute.String(tracing.PassthroughURL, r.URL.Path), + attribute.String(tracing.PassthroughMethod, r.Method), + )) + defer span.End() + upURL, err := url.Parse(provider.BaseURL()) if err != nil { - logger.Warn(r.Context(), "failed to parse provider base URL", slog.Error(err)) + logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) http.Error(w, "request error", http.StatusBadGateway) return } diff --git a/provider.go b/provider.go index 3787af6..20f8f52 100644 --- a/provider.go +++ b/provider.go @@ -2,6 +2,8 @@ package aibridge import ( "net/http" + + "go.opentelemetry.io/otel/trace" ) // Provider describes an AI provider client's behaviour. @@ -14,7 +16,7 @@ type Provider interface { // CreateInterceptor starts a new [Interceptor] which is responsible for intercepting requests, // communicating with the upstream provider and formulating a response to be sent to the requesting client. - CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) + CreateInterceptor(http.ResponseWriter, *http.Request, trace.Tracer) (Interceptor, error) // BridgedRoutes returns a slice of [http.ServeMux]-compatible routes which will have special handling. // See https://pkg.go.dev/net/http#hdr-Patterns-ServeMux. 
diff --git a/provider_anthropic.go b/provider_anthropic.go index 7e9c99f..fb5d10b 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -11,7 +11,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared" "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &AnthropicProvider{} @@ -58,14 +61,16 @@ func (p *AnthropicProvider) PassthroughRoutes() []string { } } -func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ Interceptor, outErr error) { + id := uuid.New() + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer tracing.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeMessages: var req MessageNewParamsWrapper @@ -73,13 +78,17 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req return nil, fmt.Errorf("failed to unmarshal request: %w", err) } + var interceptor Interceptor if req.Stream { - return NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg), nil + interceptor = NewAnthropicMessagesStreamingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) + } else { + interceptor = NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg, tracer) } - - return NewAnthropicMessagesBlockingInterception(id, &req, p.cfg, p.bedrockCfg), nil + span.SetAttributes(interceptor.TraceAttributes(r)...) 
+ return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/provider_openai.go b/provider_openai.go index 0fc31a6..68777e7 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,7 +7,10 @@ import ( "net/http" "os" + "github.com/coder/aibridge/tracing" "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var _ Provider = &OpenAIProvider{} @@ -58,14 +61,17 @@ func (p *OpenAIProvider) PassthroughRoutes() []string { } } -func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (Interceptor, error) { +func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ Interceptor, outErr error) { + id := uuid.New() + + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer tracing.EndSpanErr(span, &outErr) + payload, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("read body: %w", err) } - id := uuid.New() - switch r.URL.Path { case routeChatCompletions: var req ChatCompletionNewParamsWrapper @@ -73,13 +79,17 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques return nil, fmt.Errorf("unmarshal request body: %w", err) } + var interceptor Interceptor if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key, tracer) } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + interceptor = NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key, tracer) } + span.SetAttributes(interceptor.TraceAttributes(r)...) 
+ return interceptor, nil } + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) return nil, UnknownRoute } diff --git a/recorder.go b/recorder.go index edd68f7..3cad7a2 100644 --- a/recorder.go +++ b/recorder.go @@ -7,18 +7,28 @@ import ( "time" "cdr.dev/slog" + + "github.com/coder/aibridge/tracing" + "go.opentelemetry.io/otel/trace" ) -var _ Recorder = &RecorderWrapper{} +var ( + _ Recorder = &RecorderWrapper{} + _ Recorder = &AsyncRecorder{} +) // RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method. // It also sets the start/creation time of each record. type RecorderWrapper struct { logger slog.Logger + tracer trace.Tracer clientFn func() (Recorder, error) } -func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) error { +func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -33,7 +43,10 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } -func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { +func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -48,7 +61,10 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte return err } -func (r *RecorderWrapper) 
RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { +func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -63,7 +79,10 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag return err } -func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { +func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -78,7 +97,10 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR return err } -func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { +func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { + ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + client, err := r.clientFn() if err != nil { return fmt.Errorf("acquire client: %w", err) @@ -93,12 +115,14 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec return err } -func NewRecorder(logger slog.Logger, clientFn func() (Recorder, error)) *RecorderWrapper { - return &RecorderWrapper{logger: logger, clientFn: clientFn} +func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper { + 
return &RecorderWrapper{ + logger: logger, + tracer: tracer, + clientFn: clientFn, + } } -var _ Recorder = &AsyncRecorder{} - // AsyncRecorder calls [Recorder] methods asynchronously and logs any errors which may occur. type AsyncRecorder struct { logger slog.Logger @@ -141,7 +165,7 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordInterceptionEnded(timedCtx, req) @@ -153,11 +177,11 @@ func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *Interc return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error { +func (a *AsyncRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordPromptUsage(timedCtx, req) @@ -173,11 +197,11 @@ func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRec return nil // Caller is not interested in error. } -func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecord) error { +func (a *AsyncRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordTokenUsage(timedCtx, req) @@ -197,11 +221,11 @@ func (a *AsyncRecorder) RecordTokenUsage(_ context.Context, req *TokenUsageRecor return nil // Caller is not interested in error. 
} -func (a *AsyncRecorder) RecordToolUsage(_ context.Context, req *ToolUsageRecord) error { +func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { a.wg.Add(1) go func() { defer a.wg.Done() - timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout) defer cancel() err := a.wrapped.RecordToolUsage(timedCtx, req) diff --git a/trace_integration_test.go b/trace_integration_test.go new file mode 100644 index 0000000..ee6574d --- /dev/null +++ b/trace_integration_test.go @@ -0,0 +1,752 @@ +package aibridge_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/tracing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "golang.org/x/tools/txtar" +) + +// expect 'count' amount of traces named 'name' with status 'status' +type expectTrace struct { + name string + count int + status codes.Code +} + +func TestTraceAnthropic(t *testing.T) { + expectNonStreaming := []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + + expectStreaming := []expectTrace{ + {"Intercept", 1, codes.Unset}, + 
{"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 2, codes.Unset}, + {"Intercept.RecordToolUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + + cases := []struct { + name string + streaming bool + bedrock bool + expect []expectTrace + }{ + { + name: "trace_anthr_non_streaming", + expect: expectNonStreaming, + }, + { + name: "trace_bedrock_non_streaming", + bedrock: true, + expect: expectNonStreaming, + }, + { + name: "trace_anthr_streaming", + streaming: true, + expect: expectStreaming, + }, + { + name: "trace_bedrock_streaming", + streaming: true, + bedrock: true, + expect: expectStreaming, + }, + } + + arc := txtar.Parse(antSingleBuiltinTool) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := files[fixtureRequest] + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = testBedrockCfg(mockAPI.URL) + } + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := 
createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), + } + + require.Len(t, sr.Ended(), totalCount) + verifyTraces(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceAnthropicErr(t *testing.T) { + expectNonStream := []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + } + + expectStreaming := []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + } + + cases := []struct { + name string + streaming bool + bedrock bool + expect []expectTrace + }{ + { + name: "anthr_non_streaming_err", + expect: expectNonStream, + }, + { + name: 
"anthr_streaming_err", + streaming: true, + expect: expectStreaming, + }, + { + name: "bedrock_non_streaming_err", + bedrock: true, + expect: expectNonStream, + }, + { + name: "bedrock_streaming_err", + streaming: true, + bedrock: true, + expect: expectStreaming, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(antMidStreamErr) + } else { + arc = txtar.Parse(antNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = testBedrockCfg(mockAPI.URL) + } + provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createAnthropicMessagesReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else { + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + for _, s := range sr.Ended() { + 
t.Logf("SPAN: %v", s.Name()) + } + require.Len(t, sr.Ended(), totalCount) + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), + } + + verifyTraces(t, sr, tc.expect, attrs) + }) + } +} + +func TestAnthropicInjectedToolsTrace(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + streaming bool + bedrock bool + }{ + { + name: "anthr_blocking", + streaming: false, + bedrock: false, + }, + { + name: "anthr_streaming", + streaming: true, + bedrock: false, + }, + { + name: "bedrock_blocking", + streaming: false, + bedrock: true, + }, + { + name: "bedrock_streaming", + streaming: true, + bedrock: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + var bedrockCfg *aibridge.AWSBedrockConfig + if tc.bedrock { + bedrockCfg = testBedrockCfg(addr) + } + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), bedrockCfg)} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) + } + + var reqBody string + var reqPath string + reqFunc := func(t *testing.T, baseURL string, input []byte) 
*http.Request { + reqBody = string(input) + r := createAnthropicMessagesReq(t, baseURL, input) + reqPath = r.URL.Path + return r + } + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + + defer resp.Body.Close() + + require.Len(t, recorderClient.interceptions, 1) + intcID := recorderClient.interceptions[0].ID + + model := gjson.Get(string(reqBody), "model").Str + if tc.bedrock { + model = "beddel" + } + + for _, proxy := range proxies { + require.NotEmpty(t, proxy.ListTools()) + tool := proxy.ListTools()[0] + + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderAnthropic), + attribute.String(tracing.Model, model), + attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), + attribute.String(tracing.MCPToolName, "coder_list_workspaces"), + attribute.String(tracing.MCPServerName, tool.ServerName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), + attribute.Bool(tracing.Streaming, tc.streaming), + attribute.Bool(tracing.IsBedrock, tc.bedrock), + } + verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) + } + }) + } +} + +func TestTraceOpenAI(t *testing.T) { + cases := []struct { + name string + fixture []byte + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming", + fixture: oaiSimple, + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 
1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming", + fixture: oaiSimple, + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Unset}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Unset}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.RecordTokenUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + arc := txtar.Parse(tc.fixture) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + require.Len(t, sr.Ended(), totalCount) + + attrs := 
[]attribute.KeyValue{ + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + } + verifyTraces(t, sr, tc.expect, attrs) + }) + } +} + +func TestTraceOpenAIErr(t *testing.T) { + cases := []struct { + name string + streaming bool + expect []expectTrace + }{ + { + name: "trace_openai_streaming_err", + streaming: true, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.RecordPromptUsage", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Unset}, + }, + }, + { + name: "trace_openai_non_streaming_err", + streaming: false, + expect: []expectTrace{ + {"Intercept", 1, codes.Error}, + {"Intercept.CreateInterceptor", 1, codes.Unset}, + {"Intercept.RecordInterception", 1, codes.Unset}, + {"Intercept.ProcessRequest", 1, codes.Error}, + {"Intercept.RecordInterceptionEnded", 1, codes.Unset}, + {"Intercept.ProcessRequest.Upstream", 1, codes.Error}, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var arc *txtar.Archive + if tc.streaming { + arc = txtar.Parse(oaiMidStreamErr) + } else { + arc = txtar.Parse(oaiNonStreamErr) + } + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + if tc.streaming { + require.Contains(t, files, fixtureStreamingResponse) + } else { + require.Contains(t, files, fixtureNonStreamingResponse) + } + + fixtureReqBody := files[fixtureRequest] + + sr := tracetest.NewSpanRecorder() + tp := 
sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) + + mockAPI := newMockServer(ctx, t, files, nil) + t.Cleanup(mockAPI.Close) + provider := aibridge.NewOpenAIProvider(openaiCfg(mockAPI.URL, apiKey)) + srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + + req := createOpenAIChatCompletionsReq(t, srv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + if tc.streaming { + require.Equal(t, http.StatusOK, resp.StatusCode) + } else { + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + } + defer resp.Body.Close() + srv.Close() + + require.Equal(t, 1, len(recorder.interceptions)) + intcID := recorder.interceptions[0].ID + + totalCount := 0 + for _, e := range tc.expect { + totalCount += e.count + } + require.Len(t, sr.Ended(), totalCount) + + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, req.URL.Path), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.Bool(tracing.Streaming, tc.streaming), + } + verifyTraces(t, sr, tc.expect, attrs) + }) + } +} + +func TestOpenAIInjectedToolsTrace(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, 
&slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))} + return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) + } + + var reqBody string + var reqPath string + reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { + reqBody = string(input) + r := createOpenAIChatCompletionsReq(t, baseURL, input) + reqPath = r.URL.Path + return r + } + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + + defer resp.Body.Close() + + require.Len(t, recorderClient.interceptions, 1) + intcID := recorderClient.interceptions[0].ID + + for _, proxy := range proxies { + require.NotEmpty(t, proxy.ListTools()) + tool := proxy.ListTools()[0] + + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, reqPath), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, aibridge.ProviderOpenAI), + attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str), + attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), + attribute.String(tracing.MCPToolName, "coder_list_workspaces"), + attribute.String(tracing.MCPServerName, tool.ServerName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), + attribute.Bool(tracing.Streaming, streaming), + } + verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) + } + }) + } +} + +func TestTracePassthrough(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(oaiFallthrough) + files := filesMap(arc) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = 
w.Write(files[fixtureResponse]) + })) + t.Cleanup(upstream.Close) + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + provider := aibridge.NewOpenAIProvider(openaiCfg(upstream.URL, apiKey)) + srv, _ := newTestSrv(t, t.Context(), provider, nil, tracer) + + req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + srv.Close() + + spans := sr.Ended() + require.Len(t, spans, 1) + + assert.Equal(t, spans[0].Name(), "Passthrough") + want := []attribute.KeyValue{ + attribute.String(tracing.PassthroughMethod, "GET"), + attribute.String(tracing.PassthroughURL, "/v1/models"), + } + got := slices.SortedFunc(slices.Values(spans[0].Attributes()), cmpAttrKeyVal) + require.Equal(t, want, got) +} + +func TestNewServerProxyManagerTraces(t *testing.T) { + ctx := t.Context() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + tracer := tp.Tracer(t.Name()) + defer func() { _ = tp.Shutdown(t.Context()) }() + + serverName := "serverName" + srv, _ := createMockMCPSrv(t) + mcpSrv := httptest.NewServer(srv) + t.Cleanup(mcpSrv.Close) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + proxy, err := mcp.NewStreamableHTTPServerProxy(serverName, mcpSrv.URL, nil, nil, nil, logger, tracer) + require.NoError(t, err) + tools := map[string]mcp.ServerProxier{"unusedValue": proxy} + + mcpMgr := mcp.NewServerProxyManager(tools, tracer) + err = mcpMgr.Init(ctx) + require.NoError(t, err) + + require.Len(t, sr.Ended(), 3) + verifyTraces(t, sr, []expectTrace{{"ServerProxyManager.Init", 1, codes.Unset}}, []attribute.KeyValue{}) + + attrs := []attribute.KeyValue{ + 
attribute.String(tracing.MCPProxyName, proxy.Name()), + attribute.String(tracing.MCPServerURL, mcpSrv.URL), + attribute.String(tracing.MCPServerName, serverName), + } + verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init", 1, codes.Unset}}, attrs) + + attrs = append(attrs, attribute.Int(tracing.MCPToolCount, len(proxy.ListTools()))) + verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) +} + +func cmpAttrKeyVal(a attribute.KeyValue, b attribute.KeyValue) int { + return strings.Compare(string(a.Key), string(b.Key)) +} + +// checks counts of traces with given name, status and attributes +func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []expectTrace, attrs []attribute.KeyValue) { + spans := spanRecorder.Ended() + + for _, e := range expect { + found := 0 + for _, s := range spans { + if s.Name() != e.name || s.Status().Code != e.status { + continue + } + found++ + want := slices.SortedFunc(slices.Values(attrs), cmpAttrKeyVal) + got := slices.SortedFunc(slices.Values(s.Attributes()), cmpAttrKeyVal) + require.Equal(t, want, got) + assert.Equalf(t, e.status, s.Status().Code, "unexpected status for trace named: %v got: %v want: %v", e.name, s.Status().Code, e.status) + } + if found != e.count { + t.Errorf("found unexpected number of spans named: %v with status %v, got: %v want: %v", e.name, e.status, found, e.count) + } + } +} + +func testBedrockCfg(url string) *aibridge.AWSBedrockConfig { + return &aibridge.AWSBedrockConfig{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. 
+ EndpointOverride: url, + } +} diff --git a/tracing/tracing.go b/tracing/tracing.go new file mode 100644 index 0000000..aef819b --- /dev/null +++ b/tracing/tracing.go @@ -0,0 +1,86 @@ +package tracing + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +type ( + traceInterceptionAttrsContextKey struct{} + traceRequestBridgeAttrsContextKey struct{} +) + +const ( + // trace attribute key constants + RequestPath = "request_path" + + InterceptionID = "interception_id" + InitiatorID = "user_id" + Provider = "provider" + Model = "model" + Streaming = "streaming" + IsBedrock = "aws_bedrock" + + PassthroughURL = "passthrough_url" + PassthroughMethod = "passthrough_method" + + MCPInput = "mcp_input" + MCPProxyName = "mcp_proxy_name" + MCPToolName = "mcp_tool_name" + MCPServerName = "mcp_server_name" + MCPServerURL = "mcp_server_url" + MCPToolCount = "mcp_tool_count" + + APIKeyID = "api_key_id" +) + +func WithInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { + return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs) +} + +func InterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue { + attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue) + if !ok { + return nil + } + + return attrs +} + +func WithRequestBridgeAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context { + return context.WithValue(ctx, traceRequestBridgeAttrsContextKey{}, traceAttrs) +} + +func RequestBridgeAttributesFromContext(ctx context.Context) []attribute.KeyValue { + attrs, ok := ctx.Value(traceRequestBridgeAttrsContextKey{}).([]attribute.KeyValue) + if !ok { + return nil + } + + return attrs +} + +// EndSpanErr ends given span and sets Error status if error is not nil +// uses pointer to error because defer evaluates function arguments +// when defer statement is 
executed, not when the deferred function is called. +// +// Example usage: +// +// func Example() (result any, outErr error) { +// _, span := tracer.Start(...) +// defer tracing.EndSpanErr(span, &outErr) +// +// } +func EndSpanErr(span trace.Span, err *error) { + if span == nil { + return + } + + if err != nil && *err != nil { + span.SetStatus(codes.Error, (*err).Error()) + } + span.End() +}