From e63eff3ae062d1b3c5559601435a8754d21c33c3 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 5 Dec 2025 10:08:05 +0200 Subject: [PATCH 1/2] fix: openai blocking requests do not call injected tools correctly Signed-off-by: Danny Kopping --- bridge_integration_test.go | 70 +++++++++++++++++++++++++------ intercept_openai_chat_blocking.go | 10 +---- metrics_integration_test.go | 2 +- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 7fb7885..cc2d727 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -727,11 +727,12 @@ func TestFallthrough(t *testing.T) { } // setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier { +func setupMCPServerProxiesForTest(t *testing.T) (map[string]mcp.ServerProxier, *callAccumulator) { t.Helper() // Setup Coder MCP integration - mcpSrv := httptest.NewServer(createMockMCPSrv(t)) + srv, acc := createMockMCPSrv(t) + mcpSrv := httptest.NewServer(srv) t.Cleanup(mcpSrv.Close) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -745,7 +746,7 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier { tools := proxy.ListTools() require.NotEmpty(t, tools) - return map[string]mcp.ServerProxier{proxy.Name(): proxy} + return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc } type ( @@ -766,7 +767,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpCalls, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -776,6 +777,11 @@ func TestAnthropicInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) + invocations := mcpCalls.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) var ( content *anthropic.ContentBlockUnion @@ -847,7 +853,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpCalls, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -857,6 +863,11 @@ func TestOpenAIInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) + invocations := mcpCalls.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) var ( content *openai.ChatCompletionChoice @@ -932,7 +943,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. // Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -977,11 +988,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} - // Setup MCP tools. - tools := setupMCPServerProxiesForTest(t) + // Setup MCP mcpProxiers. + mcpProxiers, acc := setupMCPServerProxiesForTest(t) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1008,7 +1019,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, resp + return recorderClient, acc, resp } func TestErrorHandling(t *testing.T) { @@ -1259,7 +1270,7 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - tools := setupMCPServerProxiesForTest(t) + tools, _ := setupMCPServerProxiesForTest(t) // Configure the bridge with injected tools. mcpMgr := mcp.NewServerProxyManager(tools) @@ -1669,7 +1680,36 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { const mockToolName = "coder_list_workspaces" -func createMockMCPSrv(t *testing.T) http.Handler { +// callAccumulator tracks all tool invocations by name and each instance's arguments. +type callAccumulator struct { + calls map[string][]any + callsMu sync.Mutex +} + +func newCallAccumulator() *callAccumulator { + return &callAccumulator{ + calls: make(map[string][]any), + } +} + +func (a *callAccumulator) addCall(tool string, args any) { + a.callsMu.Lock() + defer a.callsMu.Unlock() + + a.calls[tool] = append(a.calls[tool], args) +} + +func (a *callAccumulator) getCallsByTool(name string) []any { + a.callsMu.Lock() + defer a.callsMu.Unlock() + + // Protect against concurrent access of the slice. + result := make([]any, len(a.calls[name])) + copy(result, a.calls[name]) + return result +} + +func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { t.Helper() s := server.NewMCPServer( @@ -1678,16 +1718,20 @@ func createMockMCPSrv(t *testing.T) http.Handler { server.WithToolCapabilities(true), ) + // Accumulate tool calls & their arguments. + acc := newCallAccumulator() + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + acc.addCall(request.Params.Name, request.Params.Arguments) return mcplib.NewToolResultText("mock"), nil }) } - return server.NewStreamableHTTPServer(s) + return server.NewStreamableHTTPServer(s), acc } func openaiCfg(url, key string) aibridge.OpenAIConfig { diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 757c933..d1fa255 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -1,7 +1,6 @@ package aibridge import ( - "bytes" "encoding/json" "fmt" "net/http" @@ -139,12 +138,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r appendedPrevMsg = true } - var ( - args map[string]string - buf bytes.Buffer - ) - _ = json.NewEncoder(&buf).Encode(tc.Function.Arguments) - _ = json.NewDecoder(&buf).Decode(&args) + args := i.unmarshalArgs(tc.Function.Arguments) res, err := tool.Call(ctx, args) _ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{ @@ -152,7 +146,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r MsgID: completion.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, - Args: i.unmarshalArgs(tc.Function.Arguments), + Args: args, Injected: true, InvocationError: err, }) diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 3696de2..0ea2343 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -236,7 +236,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. - tools := setupMCPServerProxiesForTest(t) + tools, _ := setupMCPServerProxiesForTest(t) mcpMgr := mcp.NewServerProxyManager(tools) require.NoError(t, mcpMgr.Init(ctx)) From 68cf38ab4d2727709ac01697954840209823bc33 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 5 Dec 2025 10:12:20 +0200 Subject: [PATCH 2/2] chore: drive-by renaming Signed-off-by: Danny Kopping --- bridge_integration_test.go | 4 ++-- metrics_integration_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index cc2d727..73f77a5 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1270,10 +1270,10 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - tools, _ := setupMCPServerProxiesForTest(t) + mcpProxiers, _ := setupMCPServerProxiesForTest(t) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(tools) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) diff --git a/metrics_integration_test.go b/metrics_integration_test.go index 0ea2343..8e23ead 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -236,8 +236,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. - tools, _ := setupMCPServerProxiesForTest(t) - mcpMgr := mcp.NewServerProxyManager(tools) + mcpProxiers, _ := setupMCPServerProxiesForTest(t) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers) require.NoError(t, mcpMgr.Init(ctx)) bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger)