Skip to content

Commit 7760d31

Browse files
committed
add tool call trace test, fixed openAI tool call attrs?
1 parent 1fd9095 commit 7760d31

8 files changed

+238
-71
lines changed

aibtrace/aibtrace.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ const (
1818
RequestPath = "request_path"
1919

2020
InterceptionID = "interception_id"
21-
UserID = "user_id"
21+
InitiatorID = "user_id"
2222
Provider = "provider"
2323
Model = "model"
2424
Streaming = "streaming"

bridge_integration_test.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
442442

443443
require.Len(t, recorderClient.toolUsages, 1)
444444
assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool)
445-
require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args)
445+
require.IsType(t, "", recorderClient.toolUsages[0].Args)
446446
require.Contains(t, recorderClient.toolUsages[0].Args, "path")
447-
assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"])
447+
assert.Equal(t, "README.md", gjson.Get(recorderClient.toolUsages[0].Args.(string), "path").Str)
448448

449449
require.Len(t, recorderClient.userPrompts, 1)
450450
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
@@ -765,7 +765,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
765765
}
766766

767767
// Build the requirements & make the assertions which are common to all providers.
768-
recorderClient, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
768+
recorderClient, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
769769

770770
// Ensure expected tool was invoked with expected input.
771771
require.Len(t, recorderClient.toolUsages, 1)
@@ -847,16 +847,13 @@ func TestOpenAIInjectedTools(t *testing.T) {
847847
}
848848

849849
// Build the requirements & make the assertions which are common to all providers.
850-
recorderClient, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
850+
recorderClient, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
851851

852852
// Ensure expected tool was invoked with expected input.
853853
require.Len(t, recorderClient.toolUsages, 1)
854854
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
855-
expected, err := json.Marshal(map[string]any{"owner": "admin"})
856-
require.NoError(t, err)
857-
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
858-
require.NoError(t, err)
859-
require.EqualValues(t, expected, actual)
855+
expected := "{\"owner\":\"admin\"}"
856+
require.EqualValues(t, expected, recorderClient.toolUsages[0].Args)
860857

861858
var (
862859
content *openai.ChatCompletionChoice
@@ -932,7 +929,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
932929

933930
// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
934931
// Kinda fugly right now, we can refactor this later.
935-
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *http.Response) {
932+
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, map[string]mcp.ServerProxier, *http.Response) {
936933
t.Helper()
937934

938935
arc := txtar.Parse(fixture)
@@ -1008,7 +1005,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
10081005
return mockSrv.callCount.Load() == 2
10091006
}, time.Second*10, time.Millisecond*50)
10101007

1011-
return recorderClient, resp
1008+
return recorderClient, tools, resp
10121009
}
10131010

10141011
func TestErrorHandling(t *testing.T) {

intercept_anthropic_messages_base.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request,
6767
return []attribute.KeyValue{
6868
attribute.String(aibtrace.RequestPath, r.URL.Path),
6969
attribute.String(aibtrace.InterceptionID, s.id.String()),
70-
attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id),
70+
attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id),
7171
attribute.String(aibtrace.Provider, ProviderAnthropic),
7272
attribute.String(aibtrace.Model, s.Model()),
7373
attribute.Bool(aibtrace.Streaming, streaming),

intercept_openai_chat_base.go

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"net/http"
7-
"strings"
87

98
aibtrace "github.com/coder/aibridge/aibtrace"
109
"github.com/coder/aibridge/mcp"
@@ -51,7 +50,7 @@ func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, stream
5150
return []attribute.KeyValue{
5251
attribute.String(aibtrace.RequestPath, r.URL.Path),
5352
attribute.String(aibtrace.InterceptionID, s.id.String()),
54-
attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id),
53+
attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id),
5554
attribute.String(aibtrace.Provider, ProviderOpenAI),
5655
attribute.String(aibtrace.Model, s.Model()),
5756
attribute.Bool(aibtrace.Streaming, streaming),
@@ -105,18 +104,6 @@ func (i *OpenAIChatInterceptionBase) injectTools() {
105104
}
106105
}
107106

108-
func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) {
109-
if len(strings.TrimSpace(in)) == 0 {
110-
return args // An empty string will fail JSON unmarshaling.
111-
}
112-
113-
if err := json.Unmarshal([]byte(in), &args); err != nil {
114-
i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err))
115-
}
116-
117-
return args
118-
}
119-
120107
// writeUpstreamError marshals and writes a given error.
121108
func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) {
122109
if oaiErr == nil {

intercept_openai_chat_blocking.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package aibridge
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/json"
76
"fmt"
@@ -120,7 +119,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
120119
InterceptionID: i.ID().String(),
121120
MsgID: completion.ID,
122121
Tool: toolCall.Function.Name,
123-
Args: i.unmarshalArgs(toolCall.Function.Arguments),
122+
Args: toolCall.Function.Arguments,
124123
Injected: false,
125124
})
126125
}
@@ -151,20 +150,13 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
151150
appendedPrevMsg = true
152151
}
153152

154-
var (
155-
args map[string]string
156-
buf bytes.Buffer
157-
)
158-
_ = json.NewEncoder(&buf).Encode(tc.Function.Arguments)
159-
_ = json.NewDecoder(&buf).Decode(&args)
160-
res, err := tool.Call(ctx, i.tracer, args)
161-
153+
res, err := tool.Call(ctx, i.tracer, tc.Function.Arguments)
162154
_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
163155
InterceptionID: i.ID().String(),
164156
MsgID: completion.ID,
165157
ServerURL: &tool.ServerURL,
166158
Tool: tool.Name,
167-
Args: i.unmarshalArgs(tc.Function.Arguments),
159+
Args: tc.Function.Arguments,
168160
Injected: true,
169161
InvocationError: err,
170162
})

intercept_openai_chat_streaming.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
154154
InterceptionID: i.ID().String(),
155155
MsgID: processor.getMsgID(),
156156
Tool: toolCall.Name,
157-
Args: i.unmarshalArgs(toolCall.Arguments),
157+
Args: toolCall.Arguments,
158158
Injected: false,
159159
})
160160
toolCall = nil
@@ -241,15 +241,13 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
241241
i.req.Messages = append(i.req.Messages, processor.getLastCompletion().ToParam())
242242

243243
id := toolCall.ID
244-
args := i.unmarshalArgs(toolCall.Arguments)
245-
toolRes, toolErr := tool.Call(streamCtx, i.tracer, args)
246-
244+
toolRes, toolErr := tool.Call(streamCtx, i.tracer, toolCall.Arguments)
247245
_ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{
248246
InterceptionID: i.ID().String(),
249247
MsgID: processor.getMsgID(),
250248
ServerURL: &tool.ServerURL,
251249
Tool: tool.Name,
252-
Args: args,
250+
Args: toolCall.Arguments,
253251
Injected: true,
254252
InvocationError: toolErr,
255253
})

mcp/proxy_streamable_http.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ import (
2020
var _ ServerProxier = &StreamableHTTPServerProxy{}
2121

2222
type StreamableHTTPServerProxy struct {
23+
client *client.Client
24+
logger slog.Logger
25+
tracer trace.Tracer
26+
27+
allowlistPattern *regexp.Regexp
28+
denylistPattern *regexp.Regexp
29+
2330
serverName string
2431
serverURL string
25-
client *client.Client
26-
logger slog.Logger
27-
tracer trace.Tracer
2832
tools map[string]*Tool
29-
30-
allowlistPattern, denylistPattern *regexp.Regexp
3133
}
3234

3335
func NewStreamableHTTPServerProxy(logger slog.Logger, tracer trace.Tracer, serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp) (*StreamableHTTPServerProxy, error) {

0 commit comments

Comments
 (0)