Skip to content

Commit 6cd58d8

Browse files
semistrictJBUinfo
authored andcommitted
Revert sampling formatting changes
1 parent e0d70dd commit 6cd58d8

File tree

7 files changed

+58
-58
lines changed

7 files changed

+58
-58
lines changed

client/transport/streamable_http_sampling_test.go

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,27 @@ import (
1616

1717
// TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport
1818
func TestStreamableHTTP_SamplingFlow(t *testing.T) {
19-
// Create simple test server
19+
// Create simple test server
2020
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2121
// Just respond OK to any requests
2222
w.WriteHeader(http.StatusOK)
2323
}))
2424
defer server.Close()
25-
25+
2626
// Create HTTP client transport
2727
client, err := NewStreamableHTTP(server.URL)
2828
if err != nil {
2929
t.Fatalf("Failed to create client: %v", err)
3030
}
3131
defer client.Close()
32-
32+
3333
// Set up sampling request handler
3434
var handledRequest *JSONRPCRequest
3535
handlerCalled := make(chan struct{})
3636
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
3737
handledRequest = &request
3838
close(handlerCalled)
39-
39+
4040
// Simulate sampling handler response
4141
result := map[string]any{
4242
"role": "assistant",
@@ -47,25 +47,25 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
4747
"model": "test-model",
4848
"stopReason": "stop_sequence",
4949
}
50-
50+
5151
resultBytes, _ := json.Marshal(result)
52-
52+
5353
return &JSONRPCResponse{
5454
JSONRPC: "2.0",
5555
ID: request.ID,
5656
Result: resultBytes,
5757
}, nil
5858
})
59-
59+
6060
// Start the client
6161
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
6262
defer cancel()
63-
63+
6464
err = client.Start(ctx)
6565
if err != nil {
6666
t.Fatalf("Failed to start client: %v", err)
6767
}
68-
68+
6969
// Test direct request handling (simulating a sampling request)
7070
samplingRequest := JSONRPCRequest{
7171
JSONRPC: "2.0",
@@ -83,23 +83,23 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
8383
},
8484
},
8585
}
86-
86+
8787
// Directly test request handling
8888
client.handleIncomingRequest(ctx, samplingRequest)
89-
89+
9090
// Wait for handler to be called
9191
select {
9292
case <-handlerCalled:
9393
// Handler was called
9494
case <-time.After(1 * time.Second):
9595
t.Fatal("Handler was not called within timeout")
9696
}
97-
97+
9898
// Verify the request was handled
9999
if handledRequest == nil {
100100
t.Fatal("Sampling request was not handled")
101101
}
102-
102+
103103
if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) {
104104
t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method)
105105
}
@@ -109,7 +109,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
109109
func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
110110
var errorHandled sync.WaitGroup
111111
errorHandled.Add(1)
112-
112+
113113
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
114114
if r.Method == http.MethodPost {
115115
var body map[string]any
@@ -118,7 +118,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
118118
w.WriteHeader(http.StatusOK)
119119
return
120120
}
121-
121+
122122
// Check if this is an error response
123123
if errorField, ok := body["error"]; ok {
124124
errorMap := errorField.(map[string]any)
@@ -132,36 +132,36 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
132132
w.WriteHeader(http.StatusOK)
133133
}))
134134
defer server.Close()
135-
135+
136136
client, err := NewStreamableHTTP(server.URL)
137137
if err != nil {
138138
t.Fatalf("Failed to create client: %v", err)
139139
}
140140
defer client.Close()
141-
141+
142142
// Set up request handler that returns an error
143143
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
144144
return nil, fmt.Errorf("sampling failed")
145145
})
146-
146+
147147
// Start the client
148148
ctx := context.Background()
149149
err = client.Start(ctx)
150150
if err != nil {
151151
t.Fatalf("Failed to start client: %v", err)
152152
}
153-
153+
154154
// Simulate incoming sampling request
155155
samplingRequest := JSONRPCRequest{
156156
JSONRPC: "2.0",
157157
ID: mcp.NewRequestId(1),
158158
Method: string(mcp.MethodSamplingCreateMessage),
159159
Params: map[string]any{},
160160
}
161-
161+
162162
// This should trigger error handling
163163
client.handleIncomingRequest(ctx, samplingRequest)
164-
164+
165165
// Wait for error to be handled
166166
errorHandled.Wait()
167167
}
@@ -170,7 +170,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
170170
func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
171171
var errorReceived bool
172172
errorReceivedChan := make(chan struct{})
173-
173+
174174
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
175175
if r.Method == http.MethodPost {
176176
var body map[string]any
@@ -179,12 +179,12 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
179179
w.WriteHeader(http.StatusOK)
180180
return
181181
}
182-
182+
183183
// Check if this is an error response with method not found
184184
if errorField, ok := body["error"]; ok {
185185
errorMap := errorField.(map[string]any)
186186
if code, ok := errorMap["code"].(float64); ok && code == -32601 {
187-
if message, ok := errorMap["message"].(string); ok &&
187+
if message, ok := errorMap["message"].(string); ok &&
188188
strings.Contains(message, "no handler configured") {
189189
errorReceived = true
190190
close(errorReceivedChan)
@@ -195,40 +195,40 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
195195
w.WriteHeader(http.StatusOK)
196196
}))
197197
defer server.Close()
198-
198+
199199
client, err := NewStreamableHTTP(server.URL)
200200
if err != nil {
201201
t.Fatalf("Failed to create client: %v", err)
202202
}
203203
defer client.Close()
204-
204+
205205
// Don't set any request handler
206-
206+
207207
ctx := context.Background()
208208
err = client.Start(ctx)
209209
if err != nil {
210210
t.Fatalf("Failed to start client: %v", err)
211211
}
212-
212+
213213
// Simulate incoming sampling request
214214
samplingRequest := JSONRPCRequest{
215215
JSONRPC: "2.0",
216216
ID: mcp.NewRequestId(1),
217217
Method: string(mcp.MethodSamplingCreateMessage),
218218
Params: map[string]any{},
219219
}
220-
220+
221221
// This should trigger "method not found" error
222222
client.handleIncomingRequest(ctx, samplingRequest)
223-
223+
224224
// Wait for error to be received
225225
select {
226226
case <-errorReceivedChan:
227227
// Error was received
228228
case <-time.After(1 * time.Second):
229229
t.Fatal("Method not found error was not received within timeout")
230230
}
231-
231+
232232
if !errorReceived {
233233
t.Error("Expected method not found error, but didn't receive it")
234234
}
@@ -241,13 +241,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
241241
t.Fatalf("Failed to create client: %v", err)
242242
}
243243
defer client.Close()
244-
244+
245245
// Verify it implements BidirectionalInterface
246246
_, ok := any(client).(BidirectionalInterface)
247247
if !ok {
248248
t.Error("StreamableHTTP should implement BidirectionalInterface")
249249
}
250-
250+
251251
// Test SetRequestHandler
252252
handlerSet := false
253253
handlerSetChan := make(chan struct{})
@@ -256,23 +256,23 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
256256
close(handlerSetChan)
257257
return nil, nil
258258
})
259-
259+
260260
// Verify handler was set by triggering it
261261
ctx := context.Background()
262262
client.handleIncomingRequest(ctx, JSONRPCRequest{
263263
JSONRPC: "2.0",
264264
ID: mcp.NewRequestId(1),
265265
Method: "test",
266266
})
267-
267+
268268
// Wait for handler to be called
269269
select {
270270
case <-handlerSetChan:
271271
// Handler was called
272272
case <-time.After(1 * time.Second):
273273
t.Fatal("Handler was not called within timeout")
274274
}
275-
275+
276276
if !handlerSet {
277277
t.Error("Request handler was not properly set or called")
278278
}
@@ -315,16 +315,16 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
315315
// Track which requests have been received and their completion order
316316
var requestOrder []int
317317
var orderMutex sync.Mutex
318-
318+
319319
// Set up request handler that simulates different processing times
320320
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
321321
// Extract request ID to determine processing time
322322
requestIDValue := request.ID.Value()
323-
323+
324324
var delay time.Duration
325325
var responseText string
326326
var requestNum int
327-
327+
328328
// First request (ID 1) takes longer, second request (ID 2) completes faster
329329
if requestIDValue == int64(1) {
330330
delay = 100 * time.Millisecond
@@ -341,7 +341,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
341341

342342
// Simulate processing time
343343
time.Sleep(delay)
344-
344+
345345
// Record completion order
346346
orderMutex.Lock()
347347
requestOrder = append(requestOrder, requestNum)
@@ -428,7 +428,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
428428
// Verify completion order: request 2 should complete first
429429
orderMutex.Lock()
430430
defer orderMutex.Unlock()
431-
431+
432432
if len(requestOrder) != 2 {
433433
t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder))
434434
}
@@ -493,4 +493,4 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
493493
}
494494
}
495495
}
496-
}
496+
}

examples/sampling_client/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ func main() {
9595
// Setup graceful shutdown
9696
sigChan := make(chan os.Signal, 1)
9797
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
98-
98+
9999
// Create a context that cancels on signal
100100
ctx, cancel := context.WithCancel(ctx)
101101
go func() {
102102
<-sigChan
103103
log.Println("Received shutdown signal, closing client...")
104104
cancel()
105105
}()
106-
106+
107107
// Move defer after error checking
108108
defer func() {
109109
if err := mcpClient.Close(); err != nil {

examples/sampling_http_client/main.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func main() {
6363
log.Fatalf("Failed to create HTTP transport: %v", err)
6464
}
6565
defer httpTransport.Close()
66-
66+
6767
// Create client with sampling support
6868
mcpClient := client.NewClient(
6969
httpTransport,
@@ -81,7 +81,7 @@ func main() {
8181
initRequest := mcp.InitializeRequest{
8282
Params: mcp.InitializeParams{
8383
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
84-
Capabilities: mcp.ClientCapabilities{
84+
Capabilities: mcp.ClientCapabilities{
8585
// Sampling capability will be automatically added by the client
8686
},
8787
ClientInfo: mcp.Implementation{
@@ -90,7 +90,7 @@ func main() {
9090
},
9191
},
9292
}
93-
93+
9494
_, err = mcpClient.Initialize(ctx, initRequest)
9595
if err != nil {
9696
log.Fatalf("Failed to initialize MCP session: %v", err)
@@ -102,7 +102,7 @@ func main() {
102102

103103
// In a real application, you would keep the client running to handle sampling requests
104104
// For this example, we'll just demonstrate that it's working
105-
105+
106106
// Keep the client running (in a real app, you'd have your main application logic here)
107107
sigChan := make(chan os.Signal, 1)
108108
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
@@ -113,4 +113,4 @@ func main() {
113113
case <-sigChan:
114114
log.Println("Received shutdown signal")
115115
}
116-
}
116+
}

examples/sampling_http_server/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,4 @@ func main() {
147147
if err := httpServer.Start(":8080"); err != nil {
148148
log.Fatalf("Server failed to start: %v", err)
149149
}
150-
}
150+
}

server/sampling.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
func (s *MCPServer) EnableSampling() {
1313
s.capabilitiesMu.Lock()
1414
defer s.capabilitiesMu.Unlock()
15-
15+
1616
enabled := true
1717
s.capabilities.sampling = &enabled
1818
}

0 commit comments

Comments
 (0)