Skip to content

Commit c5b90a8

Browse files
committed
fix data race in test
1 parent 5fc3c83 commit c5b90a8

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

client/sse_test.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package client
22

33
import (
44
"context"
5-
"fmt"
65
"github.com/stretchr/testify/assert"
76
"net/http"
7+
"sync"
88
"testing"
99
"time"
1010

@@ -460,10 +460,13 @@ func TestSSEMCPClient(t *testing.T) {
460460
t.Fatalf("Failed to create client: %v", err)
461461
}
462462

463+
mu := sync.Mutex{}
463464
notificationNum := 0
464465
var messageNotification *mcp.JSONRPCNotification
465466
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
466467
client.OnNotification(func(notification mcp.JSONRPCNotification) {
468+
mu.Lock()
469+
defer mu.Unlock()
467470
if notification.Method == string(mcp.MethodNotificationMessage) {
468471
messageNotification = &notification
469472
} else if notification.Method == string(mcp.MethodNotificationProgress) {
@@ -514,8 +517,10 @@ func TestSSEMCPClient(t *testing.T) {
514517
t.Errorf("Expected 1 content item, got %d", len(result.Content))
515518
}
516519

517-
time.Sleep(time.Millisecond * 200)
520+
time.Sleep(time.Millisecond * 500)
518521

522+
mu.Lock()
523+
defer mu.Unlock()
519524
assert.Equal(t, notificationNum, 3)
520525
assert.NotNil(t, messageNotification)
521526
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))
@@ -537,6 +542,7 @@ func TestSSEMCPClient(t *testing.T) {
537542
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"])
538543
assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"])
539544
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"])
545+
540546
})
541547

542548
t.Run("Ensure the server does not send notifications", func(t *testing.T) {
@@ -545,8 +551,11 @@ func TestSSEMCPClient(t *testing.T) {
545551
t.Fatalf("Failed to create client: %v", err)
546552
}
547553

554+
mu := sync.Mutex{}
548555
notifications := make([]*mcp.JSONRPCNotification, 0)
549556
client.OnNotification(func(notification mcp.JSONRPCNotification) {
557+
mu.Lock()
558+
defer mu.Unlock()
550559
notifications = append(notifications, &notification)
551560
})
552561
defer client.Close()
@@ -583,8 +592,10 @@ func TestSSEMCPClient(t *testing.T) {
583592
request.Params.Name = "test-tool-for-sending-notification"
584593

585594
_, _ = client.CallTool(ctx, request)
586-
time.Sleep(time.Millisecond * 200)
595+
time.Sleep(time.Millisecond * 500)
587596

597+
mu.Lock()
598+
defer mu.Unlock()
588599
assert.Len(t, notifications, 0)
589600
})
590601

@@ -594,11 +605,13 @@ func TestSSEMCPClient(t *testing.T) {
594605
t.Fatalf("Failed to create client: %v", err)
595606
}
596607

608+
mu := sync.Mutex{}
597609
var messageNotification *mcp.JSONRPCNotification
598610
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
599611
notificationNum := 0
600612
client.OnNotification(func(notification mcp.JSONRPCNotification) {
601-
println(notification.Method)
613+
mu.Lock()
614+
defer mu.Unlock()
602615
if notification.Method == string(mcp.MethodNotificationMessage) {
603616
messageNotification = &notification
604617
} else if notification.Method == string(mcp.MethodNotificationProgress) {
@@ -650,11 +663,10 @@ func TestSSEMCPClient(t *testing.T) {
650663
assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant)
651664
assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Type, "text")
652665
assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Text, "prompt value")
666+
time.Sleep(time.Millisecond * 500)
653667

654-
println(fmt.Sprintf("%v", result))
655-
656-
time.Sleep(time.Millisecond * 200)
657-
668+
mu.Lock()
669+
defer mu.Unlock()
658670
assert.Equal(t, notificationNum, 3)
659671
assert.NotNil(t, messageNotification)
660672
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))
@@ -683,11 +695,13 @@ func TestSSEMCPClient(t *testing.T) {
683695
t.Fatalf("Failed to create client: %v", err)
684696
}
685697

698+
mu := sync.Mutex{}
686699
var messageNotification *mcp.JSONRPCNotification
687700
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
688701
notificationNum := 0
689702
client.OnNotification(func(notification mcp.JSONRPCNotification) {
690-
println(notification.Method)
703+
mu.Lock()
704+
defer mu.Unlock()
691705
if notification.Method == string(mcp.MethodNotificationMessage) {
692706
messageNotification = &notification
693707
} else if notification.Method == string(mcp.MethodNotificationProgress) {
@@ -741,8 +755,10 @@ func TestSSEMCPClient(t *testing.T) {
741755
assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).MIMEType, "text/plain")
742756
assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).Text, "test content")
743757

744-
time.Sleep(time.Millisecond * 200)
758+
time.Sleep(time.Millisecond * 500)
745759

760+
mu.Lock()
761+
defer mu.Unlock()
746762
assert.Equal(t, notificationNum, 3)
747763
assert.NotNil(t, messageNotification)
748764
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))

server/sse.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
519519
var message string
520520
if eventData, err := json.Marshal(response); err != nil {
521521
// If there is an error marshalling the response, send a generic error response
522-
log.Printf("failed to marshal response: %v", err)
522+
marshal, _ := json.Marshal(response)
523+
log.Printf("failed to marshal response: %v, response %s", err, string(marshal))
523524
message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
524525
} else {
525526
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)

0 commit comments

Comments
 (0)