@@ -2,6 +2,7 @@ package client
22
33import (
44 "context"
5+ "fmt"
56 "github.com/stretchr/testify/assert"
67 "net/http"
78 "testing"
@@ -70,10 +71,9 @@ func TestSSEMCPClient(t *testing.T) {
7071 "test-tool-for-sending-notification" ,
7172 mcp .WithDescription ("Test tool for sending log notification, and the log level is warn" ),
7273 ), func (ctx context.Context , requestContext server.RequestContext , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
73-
74- totalProgreddValue := float64 (100 )
74+ totalProgressValue := float64 (100 )
7575 startFuncMessage := "start func"
76- err := requestContext .SendProgressNotification (ctx , float64 (0 ), & totalProgreddValue , & startFuncMessage )
76+ err := requestContext .SendProgressNotification (ctx , float64 (0 ), & totalProgressValue , & startFuncMessage )
7777 if err != nil {
7878 return nil , err
7979 }
@@ -92,7 +92,7 @@ func TestSSEMCPClient(t *testing.T) {
9292 }
9393
9494 startFuncMessage = "end func"
95- err = requestContext .SendProgressNotification (ctx , float64 (100 ), & totalProgreddValue , & startFuncMessage )
95+ err = requestContext .SendProgressNotification (ctx , float64 (100 ), & totalProgressValue , & startFuncMessage )
9696 if err != nil {
9797 return nil , err
9898 }
@@ -106,6 +106,48 @@ func TestSSEMCPClient(t *testing.T) {
106106 },
107107 }, nil
108108 })
109+ mcpServer .AddPrompt (mcp.Prompt {
110+ Name : "prompt_get_for_server_notification" ,
111+ Description : "Test prompt" ,
112+ }, func (ctx context.Context , requestContext server.RequestContext , req mcp.GetPromptRequest ) (* mcp.GetPromptResult , error ) {
113+ totalProgressValue := float64 (100 )
114+ startFuncMessage := "start get prompt"
115+ err := requestContext .SendProgressNotification (ctx , float64 (0 ), & totalProgressValue , & startFuncMessage )
116+ if err != nil {
117+ return nil , err
118+ }
119+
120+ err = requestContext .SendLoggingNotification (ctx , mcp .LoggingLevelInfo , map [string ]any {
121+ "filtered_log_message" : "will be filtered by log level" ,
122+ })
123+ if err != nil {
124+ return nil , err
125+ }
126+ err = requestContext .SendLoggingNotification (ctx , mcp .LoggingLevelError , map [string ]any {
127+ "log_message" : "log message value" ,
128+ })
129+ if err != nil {
130+ return nil , err
131+ }
132+
133+ startFuncMessage = "end get prompt"
134+ err = requestContext .SendProgressNotification (ctx , float64 (100 ), & totalProgressValue , & startFuncMessage )
135+ if err != nil {
136+ return nil , err
137+ }
138+
139+ return & mcp.GetPromptResult {
140+ Messages : []mcp.PromptMessage {
141+ {
142+ Role : mcp .RoleAssistant ,
143+ Content : mcp.TextContent {
144+ Type : "text" ,
145+ Text : "prompt value" ,
146+ },
147+ },
148+ },
149+ }, nil
150+ })
109151
110152 // Initialize
111153 testServer := server .NewTestServer (mcpServer ,
@@ -380,6 +422,7 @@ func TestSSEMCPClient(t *testing.T) {
380422 t .Fatalf ("Failed to create client: %v" , err )
381423 }
382424
425+ notificationNum := 0
383426 var messageNotification * mcp.JSONRPCNotification
384427 progressNotifications := make ([]* mcp.JSONRPCNotification , 0 )
385428 client .OnNotification (func (notification mcp.JSONRPCNotification ) {
@@ -388,6 +431,7 @@ func TestSSEMCPClient(t *testing.T) {
388431 } else if notification .Method == string (mcp .MethodNotificationProgress ) {
389432 progressNotifications = append (progressNotifications , & notification )
390433 }
434+ notificationNum += 1
391435 })
392436 defer client .Close ()
393437
@@ -434,6 +478,7 @@ func TestSSEMCPClient(t *testing.T) {
434478
435479 time .Sleep (time .Millisecond * 200 )
436480
481+ assert .Equal (t , notificationNum , 3 )
437482 assert .NotNil (t , messageNotification )
438483 assert .Equal (t , messageNotification .Method , string (mcp .MethodNotificationMessage ))
439484 assert .Equal (t , messageNotification .Params .AdditionalFields ["level" ], "error" )
@@ -504,4 +549,93 @@ func TestSSEMCPClient(t *testing.T) {
504549
505550 assert .Len (t , notifications , 0 )
506551 })
552+
553+ t .Run ("GetPrompt for testing log and progress notification" , func (t * testing.T ) {
554+ client , err := NewSSEMCPClient (testServer .URL + "/sse" )
555+ if err != nil {
556+ t .Fatalf ("Failed to create client: %v" , err )
557+ }
558+
559+ var messageNotification * mcp.JSONRPCNotification
560+ progressNotifications := make ([]* mcp.JSONRPCNotification , 0 )
561+ notificationNum := 0
562+ client .OnNotification (func (notification mcp.JSONRPCNotification ) {
563+ println (notification .Method )
564+ if notification .Method == string (mcp .MethodNotificationMessage ) {
565+ messageNotification = & notification
566+ } else if notification .Method == string (mcp .MethodNotificationProgress ) {
567+ progressNotifications = append (progressNotifications , & notification )
568+ }
569+ notificationNum += 1
570+ })
571+ defer client .Close ()
572+
573+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
574+ defer cancel ()
575+
576+ if err := client .Start (ctx ); err != nil {
577+ t .Fatalf ("Failed to start client: %v" , err )
578+ }
579+
580+ // Initialize
581+ initRequest := mcp.InitializeRequest {}
582+ initRequest .Params .ProtocolVersion = mcp .LATEST_PROTOCOL_VERSION
583+ initRequest .Params .ClientInfo = mcp.Implementation {
584+ Name : "test-client" ,
585+ Version : "1.0.0" ,
586+ }
587+
588+ _ , err = client .Initialize (ctx , initRequest )
589+ if err != nil {
590+ t .Fatalf ("Failed to initialize: %v" , err )
591+ }
592+
593+ setLevelRequest := mcp.SetLevelRequest {}
594+ setLevelRequest .Params .Level = mcp .LoggingLevelWarning
595+ err = client .SetLevel (ctx , setLevelRequest )
596+ if err != nil {
597+ t .Errorf ("SetLevel failed: %v" , err )
598+ }
599+
600+ request := mcp.GetPromptRequest {}
601+ request .Params .Name = "prompt_get_for_server_notification"
602+ request .Params .Meta = & mcp.Meta {
603+ ProgressToken : "progress_token" ,
604+ }
605+
606+ result , err := client .GetPrompt (ctx , request )
607+ if err != nil {
608+ t .Fatalf ("GetPrompt failed: %v" , err )
609+ }
610+ assert .NotNil (t , result )
611+ assert .Len (t , result .Messages , 1 )
612+ assert .Equal (t , result .Messages [0 ].Role , mcp .RoleAssistant )
613+ assert .Equal (t , result .Messages [0 ].Content .(mcp.TextContent ).Type , "text" )
614+ assert .Equal (t , result .Messages [0 ].Content .(mcp.TextContent ).Text , "prompt value" )
615+
616+ println (fmt .Sprintf ("%v" , result ))
617+
618+ time .Sleep (time .Millisecond * 200 )
619+
620+ assert .Equal (t , notificationNum , 3 )
621+ assert .NotNil (t , messageNotification )
622+ assert .Equal (t , messageNotification .Method , string (mcp .MethodNotificationMessage ))
623+ assert .Equal (t , messageNotification .Params .AdditionalFields ["level" ], "error" )
624+ assert .Equal (t , messageNotification .Params .AdditionalFields ["data" ], map [string ]any {
625+ "log_message" : "log message value" ,
626+ })
627+
628+ assert .Len (t , progressNotifications , 2 )
629+ assert .Equal (t , string (mcp .MethodNotificationProgress ), progressNotifications [0 ].Method )
630+ assert .Equal (t , "start get prompt" , progressNotifications [0 ].Params .AdditionalFields ["message" ])
631+ assert .EqualValues (t , 0 , progressNotifications [0 ].Params .AdditionalFields ["progress" ])
632+ assert .Equal (t , "progress_token" , progressNotifications [0 ].Params .AdditionalFields ["progressToken" ])
633+ assert .EqualValues (t , 100 , progressNotifications [0 ].Params .AdditionalFields ["total" ])
634+
635+ assert .Equal (t , string (mcp .MethodNotificationProgress ), progressNotifications [1 ].Method )
636+ assert .Equal (t , "end get prompt" , progressNotifications [1 ].Params .AdditionalFields ["message" ])
637+ assert .EqualValues (t , 100 , progressNotifications [1 ].Params .AdditionalFields ["progress" ])
638+ assert .Equal (t , "progress_token" , progressNotifications [1 ].Params .AdditionalFields ["progressToken" ])
639+ assert .EqualValues (t , 100 , progressNotifications [1 ].Params .AdditionalFields ["total" ])
640+ })
507641}
0 commit comments