Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 58 additions & 14 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,11 +727,12 @@ func TestFallthrough(t *testing.T) {
}

// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools
func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
func setupMCPServerProxiesForTest(t *testing.T) (map[string]mcp.ServerProxier, *callAccumulator) {
t.Helper()

// Setup Coder MCP integration
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
srv, acc := createMockMCPSrv(t)
mcpSrv := httptest.NewServer(srv)
t.Cleanup(mcpSrv.Close)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
Expand All @@ -745,7 +746,7 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
tools := proxy.ListTools()
require.NotEmpty(t, tools)

return map[string]mcp.ServerProxier{proxy.Name(): proxy}
return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc
}

type (
Expand All @@ -766,7 +767,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
recorderClient, mcpCalls, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
Expand All @@ -776,6 +777,11 @@ func TestAnthropicInjectedTools(t *testing.T) {
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
require.Len(t, invocations, 1)
actual, err = json.Marshal(invocations[0])
require.NoError(t, err)
require.EqualValues(t, expected, actual)

var (
content *anthropic.ContentBlockUnion
Expand Down Expand Up @@ -847,7 +853,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
recorderClient, mcpCalls, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
Expand All @@ -857,6 +863,11 @@ func TestOpenAIInjectedTools(t *testing.T) {
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
require.Len(t, invocations, 1)
actual, err = json.Marshal(invocations[0])
require.NoError(t, err)
require.EqualValues(t, expected, actual)

var (
content *openai.ChatCompletionChoice
Expand Down Expand Up @@ -932,7 +943,7 @@ func TestOpenAIInjectedTools(t *testing.T) {

// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
// Kinda fugly right now, we can refactor this later.
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *http.Response) {
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, *http.Response) {
t.Helper()

arc := txtar.Parse(fixture)
Expand Down Expand Up @@ -977,11 +988,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu

recorderClient := &mockRecorderClient{}

// Setup MCP tools.
tools := setupMCPServerProxiesForTest(t)
// Setup MCP mcpProxiers.
mcpProxiers, acc := setupMCPServerProxiesForTest(t)

// Configure the bridge with injected tools.
mcpMgr := mcp.NewServerProxyManager(tools)
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
require.NoError(t, mcpMgr.Init(ctx))
b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr)
require.NoError(t, err)
Expand All @@ -1008,7 +1019,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
return mockSrv.callCount.Load() == 2
}, time.Second*10, time.Millisecond*50)

return recorderClient, resp
return recorderClient, acc, resp
}

func TestErrorHandling(t *testing.T) {
Expand Down Expand Up @@ -1259,10 +1270,10 @@ func TestStableRequestEncoding(t *testing.T) {
t.Cleanup(cancel)

// Setup MCP tools.
tools := setupMCPServerProxiesForTest(t)
mcpProxiers, _ := setupMCPServerProxiesForTest(t)

// Configure the bridge with injected tools.
mcpMgr := mcp.NewServerProxyManager(tools)
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
require.NoError(t, mcpMgr.Init(ctx))

arc := txtar.Parse(tc.fixture)
Expand Down Expand Up @@ -1669,7 +1680,36 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {

const mockToolName = "coder_list_workspaces"

func createMockMCPSrv(t *testing.T) http.Handler {
// callAccumulator tracks all tool invocations by name and each instance's arguments.
type callAccumulator struct {
calls map[string][]any
callsMu sync.Mutex
}

func newCallAccumulator() *callAccumulator {
return &callAccumulator{
calls: make(map[string][]any),
}
}

func (a *callAccumulator) addCall(tool string, args any) {
a.callsMu.Lock()
defer a.callsMu.Unlock()

a.calls[tool] = append(a.calls[tool], args)
}

func (a *callAccumulator) getCallsByTool(name string) []any {
a.callsMu.Lock()
defer a.callsMu.Unlock()

// Protect against concurrent access of the slice.
result := make([]any, len(a.calls[name]))
copy(result, a.calls[name])
return result
}

func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
t.Helper()

s := server.NewMCPServer(
Expand All @@ -1678,16 +1718,20 @@ func createMockMCPSrv(t *testing.T) http.Handler {
server.WithToolCapabilities(true),
)

// Accumulate tool calls & their arguments.
acc := newCallAccumulator()

for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s)
return server.NewStreamableHTTPServer(s), acc
}

func openaiCfg(url, key string) aibridge.OpenAIConfig {
Expand Down
10 changes: 2 additions & 8 deletions intercept_openai_chat_blocking.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aibridge

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -139,20 +138,15 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
appendedPrevMsg = true
}

var (
args map[string]string
buf bytes.Buffer
)
_ = json.NewEncoder(&buf).Encode(tc.Function.Arguments)
_ = json.NewDecoder(&buf).Decode(&args)
args := i.unmarshalArgs(tc.Function.Arguments)
res, err := tool.Call(ctx, args)

_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: completion.ID,
ServerURL: &tool.ServerURL,
Tool: tool.Name,
Args: i.unmarshalArgs(tc.Function.Arguments),
Args: args,
Injected: true,
InvocationError: err,
})
Expand Down
4 changes: 2 additions & 2 deletions metrics_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil)

// Setup mocked MCP server & tools.
tools := setupMCPServerProxiesForTest(t)
mcpMgr := mcp.NewServerProxyManager(tools)
mcpProxiers, _ := setupMCPServerProxiesForTest(t)
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
require.NoError(t, mcpMgr.Init(ctx))

bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger)
Expand Down