Skip to content

Commit 86db7bb

Browse files
committed
add test
1 parent 2704ca5 commit 86db7bb

File tree

16 files changed

+200
-65
lines changed

16 files changed

+200
-65
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,17 +542,17 @@ The RequestContext object provides capabilities to interact with the client, suc
542542
mcpServer.AddTool(mcp.NewTool(
543543
"test-RequestContent",
544544
mcp.WithDescription("test RequestContent"),
545-
), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
546-
// you could invoke `reqContext.IsLoggingNotificationSupported()` first the check if server supports logging notification
545+
), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
546+
// you could invoke `requestContext.IsLoggingNotificationSupported()` first the check if server supports logging notification
547547
// ff server does not support logging notification, this method will do nothing.
548-
_ = reqContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{
548+
_ = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{
549549
"testLog": "test send log notification",
550550
})
551551

552552
// server should send progress notification if request metadata includes a progressToken
553553
total := float64(100)
554554
progressMessage := "human readable progress information"
555-
_ = reqContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage)
555+
_ = requestContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage)
556556

557557
return &mcp.CallToolResult{
558558
Content: []mcp.Content{

client/inprocess_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestInProcessMCPClient(t *testing.T) {
2727
mcp.WithDestructiveHintAnnotation(false),
2828
mcp.WithIdempotentHintAnnotation(true),
2929
mcp.WithOpenWorldHintAnnotation(false),
30-
), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
30+
), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
3131
return &mcp.CallToolResult{
3232
Content: []mcp.Content{
3333
mcp.TextContent{
@@ -48,7 +48,7 @@ func TestInProcessMCPClient(t *testing.T) {
4848
URI: "resource://testresource",
4949
Name: "My Resource",
5050
},
51-
func(ctx context.Context, reqContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
51+
func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
5252
return []mcp.ResourceContents{
5353
mcp.TextResourceContents{
5454
URI: "resource://testresource",
@@ -70,7 +70,7 @@ func TestInProcessMCPClient(t *testing.T) {
7070
},
7171
},
7272
},
73-
func(ctx context.Context, reqContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
73+
func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
7474
return &mcp.GetPromptResult{
7575
Messages: []mcp.PromptMessage{
7676
{

client/sse_test.go

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"fmt"
56
"github.com/stretchr/testify/assert"
67
"net/http"
78
"testing"
@@ -70,10 +71,9 @@ func TestSSEMCPClient(t *testing.T) {
7071
"test-tool-for-sending-notification",
7172
mcp.WithDescription("Test tool for sending log notification, and the log level is warn"),
7273
), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
73-
74-
totalProgreddValue := float64(100)
74+
totalProgressValue := float64(100)
7575
startFuncMessage := "start func"
76-
err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgreddValue, &startFuncMessage)
76+
err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage)
7777
if err != nil {
7878
return nil, err
7979
}
@@ -92,7 +92,7 @@ func TestSSEMCPClient(t *testing.T) {
9292
}
9393

9494
startFuncMessage = "end func"
95-
err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgreddValue, &startFuncMessage)
95+
err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage)
9696
if err != nil {
9797
return nil, err
9898
}
@@ -106,6 +106,48 @@ func TestSSEMCPClient(t *testing.T) {
106106
},
107107
}, nil
108108
})
109+
mcpServer.AddPrompt(mcp.Prompt{
110+
Name: "prompt_get_for_server_notification",
111+
Description: "Test prompt",
112+
}, func(ctx context.Context, requestContext server.RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
113+
totalProgressValue := float64(100)
114+
startFuncMessage := "start get prompt"
115+
err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage)
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{
121+
"filtered_log_message": "will be filtered by log level",
122+
})
123+
if err != nil {
124+
return nil, err
125+
}
126+
err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{
127+
"log_message": "log message value",
128+
})
129+
if err != nil {
130+
return nil, err
131+
}
132+
133+
startFuncMessage = "end get prompt"
134+
err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage)
135+
if err != nil {
136+
return nil, err
137+
}
138+
139+
return &mcp.GetPromptResult{
140+
Messages: []mcp.PromptMessage{
141+
{
142+
Role: mcp.RoleAssistant,
143+
Content: mcp.TextContent{
144+
Type: "text",
145+
Text: "prompt value",
146+
},
147+
},
148+
},
149+
}, nil
150+
})
109151

110152
// Initialize
111153
testServer := server.NewTestServer(mcpServer,
@@ -380,6 +422,7 @@ func TestSSEMCPClient(t *testing.T) {
380422
t.Fatalf("Failed to create client: %v", err)
381423
}
382424

425+
notificationNum := 0
383426
var messageNotification *mcp.JSONRPCNotification
384427
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
385428
client.OnNotification(func(notification mcp.JSONRPCNotification) {
@@ -388,6 +431,7 @@ func TestSSEMCPClient(t *testing.T) {
388431
} else if notification.Method == string(mcp.MethodNotificationProgress) {
389432
progressNotifications = append(progressNotifications, &notification)
390433
}
434+
notificationNum += 1
391435
})
392436
defer client.Close()
393437

@@ -434,6 +478,7 @@ func TestSSEMCPClient(t *testing.T) {
434478

435479
time.Sleep(time.Millisecond * 200)
436480

481+
assert.Equal(t, notificationNum, 3)
437482
assert.NotNil(t, messageNotification)
438483
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))
439484
assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error")
@@ -504,4 +549,93 @@ func TestSSEMCPClient(t *testing.T) {
504549

505550
assert.Len(t, notifications, 0)
506551
})
552+
553+
t.Run("GetPrompt for testing log and progress notification", func(t *testing.T) {
554+
client, err := NewSSEMCPClient(testServer.URL + "/sse")
555+
if err != nil {
556+
t.Fatalf("Failed to create client: %v", err)
557+
}
558+
559+
var messageNotification *mcp.JSONRPCNotification
560+
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
561+
notificationNum := 0
562+
client.OnNotification(func(notification mcp.JSONRPCNotification) {
563+
println(notification.Method)
564+
if notification.Method == string(mcp.MethodNotificationMessage) {
565+
messageNotification = &notification
566+
} else if notification.Method == string(mcp.MethodNotificationProgress) {
567+
progressNotifications = append(progressNotifications, &notification)
568+
}
569+
notificationNum += 1
570+
})
571+
defer client.Close()
572+
573+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
574+
defer cancel()
575+
576+
if err := client.Start(ctx); err != nil {
577+
t.Fatalf("Failed to start client: %v", err)
578+
}
579+
580+
// Initialize
581+
initRequest := mcp.InitializeRequest{}
582+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
583+
initRequest.Params.ClientInfo = mcp.Implementation{
584+
Name: "test-client",
585+
Version: "1.0.0",
586+
}
587+
588+
_, err = client.Initialize(ctx, initRequest)
589+
if err != nil {
590+
t.Fatalf("Failed to initialize: %v", err)
591+
}
592+
593+
setLevelRequest := mcp.SetLevelRequest{}
594+
setLevelRequest.Params.Level = mcp.LoggingLevelWarning
595+
err = client.SetLevel(ctx, setLevelRequest)
596+
if err != nil {
597+
t.Errorf("SetLevel failed: %v", err)
598+
}
599+
600+
request := mcp.GetPromptRequest{}
601+
request.Params.Name = "prompt_get_for_server_notification"
602+
request.Params.Meta = &mcp.Meta{
603+
ProgressToken: "progress_token",
604+
}
605+
606+
result, err := client.GetPrompt(ctx, request)
607+
if err != nil {
608+
t.Fatalf("GetPrompt failed: %v", err)
609+
}
610+
assert.NotNil(t, result)
611+
assert.Len(t, result.Messages, 1)
612+
assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant)
613+
assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Type, "text")
614+
assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Text, "prompt value")
615+
616+
println(fmt.Sprintf("%v", result))
617+
618+
time.Sleep(time.Millisecond * 200)
619+
620+
assert.Equal(t, notificationNum, 3)
621+
assert.NotNil(t, messageNotification)
622+
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))
623+
assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error")
624+
assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{
625+
"log_message": "log message value",
626+
})
627+
628+
assert.Len(t, progressNotifications, 2)
629+
assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method)
630+
assert.Equal(t, "start get prompt", progressNotifications[0].Params.AdditionalFields["message"])
631+
assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"])
632+
assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"])
633+
assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"])
634+
635+
assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method)
636+
assert.Equal(t, "end get prompt", progressNotifications[1].Params.AdditionalFields["message"])
637+
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"])
638+
assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"])
639+
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"])
640+
})
507641
}

examples/custom_context/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func makeRequest(ctx context.Context, message, token string) (*response, error)
7979
// using the token from the context.
8080
func handleMakeAuthenticatedRequestTool(
8181
ctx context.Context,
82-
reqContext server.RequestContext,
82+
requestContext server.RequestContext,
8383
request mcp.CallToolRequest,
8484
) (*mcp.CallToolResult, error) {
8585
message, ok := request.GetArguments()["message"].(string)

examples/dynamic_path/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func main() {
1919
mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0")
2020

2121
// Add a trivial tool for demonstration
22-
mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, reqContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
22+
mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, requestContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
2323
return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil
2424
})
2525

examples/everything/main.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func generateResources() []mcp.Resource {
191191

192192
func handleReadResource(
193193
ctx context.Context,
194-
reqContext server.RequestContext,
194+
requestContext server.RequestContext,
195195
request mcp.ReadResourceRequest,
196196
) ([]mcp.ResourceContents, error) {
197197
return []mcp.ResourceContents{
@@ -205,7 +205,7 @@ func handleReadResource(
205205

206206
func handleResourceTemplate(
207207
ctx context.Context,
208-
reqContext server.RequestContext,
208+
requestContext server.RequestContext,
209209
request mcp.ReadResourceRequest,
210210
) ([]mcp.ResourceContents, error) {
211211
return []mcp.ResourceContents{
@@ -219,7 +219,7 @@ func handleResourceTemplate(
219219

220220
func handleGeneratedResource(
221221
ctx context.Context,
222-
reqContext server.RequestContext,
222+
requestContext server.RequestContext,
223223
request mcp.ReadResourceRequest,
224224
) ([]mcp.ResourceContents, error) {
225225
uri := request.Params.URI
@@ -257,7 +257,7 @@ func handleGeneratedResource(
257257

258258
func handleSimplePrompt(
259259
ctx context.Context,
260-
reqContext server.RequestContext,
260+
requestContext server.RequestContext,
261261
request mcp.GetPromptRequest,
262262
) (*mcp.GetPromptResult, error) {
263263
return &mcp.GetPromptResult{
@@ -276,7 +276,7 @@ func handleSimplePrompt(
276276

277277
func handleComplexPrompt(
278278
ctx context.Context,
279-
reqContext server.RequestContext,
279+
requestContext server.RequestContext,
280280
request mcp.GetPromptRequest,
281281
) (*mcp.GetPromptResult, error) {
282282
arguments := request.Params.Arguments
@@ -315,7 +315,7 @@ func handleComplexPrompt(
315315

316316
func handleEchoTool(
317317
ctx context.Context,
318-
reqContext server.RequestContext,
318+
requestContext server.RequestContext,
319319
request mcp.CallToolRequest,
320320
) (*mcp.CallToolResult, error) {
321321
arguments := request.GetArguments()
@@ -335,7 +335,7 @@ func handleEchoTool(
335335

336336
func handleAddTool(
337337
ctx context.Context,
338-
reqContext server.RequestContext,
338+
requestContext server.RequestContext,
339339
request mcp.CallToolRequest,
340340
) (*mcp.CallToolResult, error) {
341341
arguments := request.GetArguments()
@@ -357,7 +357,7 @@ func handleAddTool(
357357

358358
func handleSendNotification(
359359
ctx context.Context,
360-
reqContext server.RequestContext,
360+
requestContext server.RequestContext,
361361
request mcp.CallToolRequest,
362362
) (*mcp.CallToolResult, error) {
363363

@@ -388,7 +388,7 @@ func handleSendNotification(
388388

389389
func handleLongRunningOperationTool(
390390
ctx context.Context,
391-
reqContext server.RequestContext,
391+
requestContext server.RequestContext,
392392
request mcp.CallToolRequest,
393393
) (*mcp.CallToolResult, error) {
394394
arguments := request.GetArguments()
@@ -454,7 +454,7 @@ func handleLongRunningOperationTool(
454454

455455
func handleGetTinyImageTool(
456456
ctx context.Context,
457-
reqContext server.RequestContext,
457+
requestContext server.RequestContext,
458458
request mcp.CallToolRequest,
459459
) (*mcp.CallToolResult, error) {
460460
return &mcp.CallToolResult{

mcp/prompts.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type GetPromptRequest struct {
2424
Name string `json:"name"`
2525
// Arguments to use for templating the prompt.
2626
Arguments map[string]string `json:"arguments,omitempty"`
27+
Meta *Meta `json:"_meta,omitempty"`
2728
} `json:"params"`
2829
}
2930

mcptest/mcptest_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func TestServer(t *testing.T) {
5050
}
5151
}
5252

53-
func helloWorldHandler(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
53+
func helloWorldHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
5454
// Extract name from request arguments
5555
name, ok := request.GetArguments()["name"].(string)
5656
if !ok {

0 commit comments

Comments
 (0)