@@ -1015,3 +1015,302 @@ func postJSON(url string, bodyObject any) (*http.Response, error) {
10151015 req .Header .Set ("Content-Type" , "application/json" )
10161016 return http .DefaultClient .Do (req )
10171017}
1018+
1019+ func TestStreamableHTTP_SessionValidation (t * testing.T ) {
1020+ mcpServer := NewMCPServer ("test-server" , "1.0.0" )
1021+ mcpServer .AddTool (mcp .NewTool ("time" ,
1022+ mcp .WithDescription ("Get the current time" )), func (ctx context.Context , req mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1023+ return mcp .NewToolResultText ("2024-01-01T00:00:00Z" ), nil
1024+ })
1025+
1026+ server := NewTestStreamableHTTPServer (mcpServer )
1027+ defer server .Close ()
1028+
1029+ t .Run ("Reject tool call with fake session ID" , func (t * testing.T ) {
1030+ toolCallRequest := map [string ]any {
1031+ "jsonrpc" : "2.0" ,
1032+ "id" : 1 ,
1033+ "method" : "tools/call" ,
1034+ "params" : map [string ]any {
1035+ "name" : "time" ,
1036+ },
1037+ }
1038+
1039+ jsonBody , _ := json .Marshal (toolCallRequest )
1040+ req , _ := http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1041+ req .Header .Set ("Content-Type" , "application/json" )
1042+ req .Header .Set (HeaderKeySessionID , "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff" )
1043+
1044+ resp , err := server .Client ().Do (req )
1045+ if err != nil {
1046+ t .Fatalf ("Failed to send request: %v" , err )
1047+ }
1048+ defer resp .Body .Close ()
1049+
1050+ if resp .StatusCode != http .StatusBadRequest {
1051+ t .Errorf ("Expected status 400, got %d" , resp .StatusCode )
1052+ }
1053+
1054+ body , _ := io .ReadAll (resp .Body )
1055+ if ! strings .Contains (string (body ), "Invalid session ID" ) {
1056+ t .Errorf ("Expected 'Invalid session ID' error, got: %s" , string (body ))
1057+ }
1058+ })
1059+
1060+ t .Run ("Reject tool call with malformed session ID" , func (t * testing.T ) {
1061+ toolCallRequest := map [string ]any {
1062+ "jsonrpc" : "2.0" ,
1063+ "id" : 1 ,
1064+ "method" : "tools/call" ,
1065+ "params" : map [string ]any {
1066+ "name" : "time" ,
1067+ },
1068+ }
1069+
1070+ jsonBody , _ := json .Marshal (toolCallRequest )
1071+ req , _ := http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1072+ req .Header .Set ("Content-Type" , "application/json" )
1073+ req .Header .Set (HeaderKeySessionID , "invalid-session-id" )
1074+
1075+ resp , err := server .Client ().Do (req )
1076+ if err != nil {
1077+ t .Fatalf ("Failed to send request: %v" , err )
1078+ }
1079+ defer resp .Body .Close ()
1080+
1081+ if resp .StatusCode != http .StatusBadRequest {
1082+ t .Errorf ("Expected status 400, got %d" , resp .StatusCode )
1083+ }
1084+
1085+ body , _ := io .ReadAll (resp .Body )
1086+ if ! strings .Contains (string (body ), "Invalid session ID" ) {
1087+ t .Errorf ("Expected 'Invalid session ID' error, got: %s" , string (body ))
1088+ }
1089+ })
1090+
1091+ t .Run ("Accept tool call with valid session ID from initialize" , func (t * testing.T ) {
1092+ jsonBody , _ := json .Marshal (initRequest )
1093+ req , _ := http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1094+ req .Header .Set ("Content-Type" , "application/json" )
1095+
1096+ resp , err := server .Client ().Do (req )
1097+ if err != nil {
1098+ t .Fatalf ("Failed to initialize: %v" , err )
1099+ }
1100+ defer resp .Body .Close ()
1101+
1102+ sessionID := resp .Header .Get (HeaderKeySessionID )
1103+ if sessionID == "" {
1104+ t .Fatal ("Expected session ID in response header" )
1105+ }
1106+
1107+ toolCallRequest := map [string ]any {
1108+ "jsonrpc" : "2.0" ,
1109+ "id" : 2 ,
1110+ "method" : "tools/call" ,
1111+ "params" : map [string ]any {
1112+ "name" : "time" ,
1113+ },
1114+ }
1115+
1116+ jsonBody , _ = json .Marshal (toolCallRequest )
1117+ req , _ = http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1118+ req .Header .Set ("Content-Type" , "application/json" )
1119+ req .Header .Set (HeaderKeySessionID , sessionID )
1120+
1121+ resp , err = server .Client ().Do (req )
1122+ if err != nil {
1123+ t .Fatalf ("Failed to call tool: %v" , err )
1124+ }
1125+ defer resp .Body .Close ()
1126+
1127+ if resp .StatusCode != http .StatusOK {
1128+ body , _ := io .ReadAll (resp .Body )
1129+ t .Errorf ("Expected status 200, got %d. Body: %s" , resp .StatusCode , string (body ))
1130+ }
1131+ })
1132+
1133+ t .Run ("Reject tool call with terminated session ID" , func (t * testing.T ) {
1134+ jsonBody , _ := json .Marshal (initRequest )
1135+ req , _ := http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1136+ req .Header .Set ("Content-Type" , "application/json" )
1137+
1138+ resp , err := server .Client ().Do (req )
1139+ if err != nil {
1140+ t .Fatalf ("Failed to initialize: %v" , err )
1141+ }
1142+ resp .Body .Close ()
1143+
1144+ sessionID := resp .Header .Get (HeaderKeySessionID )
1145+ if sessionID == "" {
1146+ t .Fatal ("Expected session ID in response header" )
1147+ }
1148+
1149+ req , _ = http .NewRequest (http .MethodDelete , server .URL , nil )
1150+ req .Header .Set (HeaderKeySessionID , sessionID )
1151+
1152+ resp , err = server .Client ().Do (req )
1153+ if err != nil {
1154+ t .Fatalf ("Failed to terminate session: %v" , err )
1155+ }
1156+ resp .Body .Close ()
1157+
1158+ if resp .StatusCode != http .StatusOK {
1159+ t .Errorf ("Expected status 200 for termination, got %d" , resp .StatusCode )
1160+ }
1161+
1162+ toolCallRequest := map [string ]any {
1163+ "jsonrpc" : "2.0" ,
1164+ "id" : 2 ,
1165+ "method" : "tools/call" ,
1166+ "params" : map [string ]any {
1167+ "name" : "time" ,
1168+ },
1169+ }
1170+
1171+ jsonBody , _ = json .Marshal (toolCallRequest )
1172+ req , _ = http .NewRequest (http .MethodPost , server .URL , bytes .NewBuffer (jsonBody ))
1173+ req .Header .Set ("Content-Type" , "application/json" )
1174+ req .Header .Set (HeaderKeySessionID , sessionID )
1175+
1176+ resp , err = server .Client ().Do (req )
1177+ if err != nil {
1178+ t .Fatalf ("Failed to send request: %v" , err )
1179+ }
1180+ defer resp .Body .Close ()
1181+
1182+ if resp .StatusCode != http .StatusNotFound {
1183+ body , _ := io .ReadAll (resp .Body )
1184+ t .Errorf ("Expected status 404, got %d. Body: %s" , resp .StatusCode , string (body ))
1185+ }
1186+ })
1187+ }
1188+
1189+ func TestInsecureStatefulSessionIdManager (t * testing.T ) {
1190+ t .Run ("Generate creates valid session ID" , func (t * testing.T ) {
1191+ manager := & InsecureStatefulSessionIdManager {}
1192+ sessionID := manager .Generate ()
1193+
1194+ if ! strings .HasPrefix (sessionID , idPrefix ) {
1195+ t .Errorf ("Expected session ID to start with %s, got %s" , idPrefix , sessionID )
1196+ }
1197+
1198+ isTerminated , err := manager .Validate (sessionID )
1199+ if err != nil {
1200+ t .Errorf ("Expected valid session ID, got error: %v" , err )
1201+ }
1202+ if isTerminated {
1203+ t .Error ("Expected session to not be terminated" )
1204+ }
1205+ })
1206+
1207+ t .Run ("Validate rejects non-existent session ID" , func (t * testing.T ) {
1208+ manager := & InsecureStatefulSessionIdManager {}
1209+ fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"
1210+
1211+ isTerminated , err := manager .Validate (fakeSessionID )
1212+ if err == nil {
1213+ t .Error ("Expected error for non-existent session ID" )
1214+ }
1215+ if isTerminated {
1216+ t .Error ("Expected isTerminated to be false for invalid session" )
1217+ }
1218+ if ! strings .Contains (err .Error (), "session not found" ) {
1219+ t .Errorf ("Expected 'session not found' error, got: %v" , err )
1220+ }
1221+ })
1222+
1223+ t .Run ("Validate rejects malformed session ID" , func (t * testing.T ) {
1224+ manager := & InsecureStatefulSessionIdManager {}
1225+ invalidSessionID := "invalid-session-id"
1226+
1227+ _ , err := manager .Validate (invalidSessionID )
1228+ if err == nil {
1229+ t .Error ("Expected error for malformed session ID" )
1230+ }
1231+ if ! strings .Contains (err .Error (), "invalid session id" ) {
1232+ t .Errorf ("Expected 'invalid session id' error, got: %v" , err )
1233+ }
1234+ })
1235+
1236+ t .Run ("Terminate marks session as terminated" , func (t * testing.T ) {
1237+ manager := & InsecureStatefulSessionIdManager {}
1238+ sessionID := manager .Generate ()
1239+
1240+ isNotAllowed , err := manager .Terminate (sessionID )
1241+ if err != nil {
1242+ t .Errorf ("Expected no error on termination, got: %v" , err )
1243+ }
1244+ if isNotAllowed {
1245+ t .Error ("Expected termination to be allowed" )
1246+ }
1247+
1248+ isTerminated , err := manager .Validate (sessionID )
1249+ if ! isTerminated {
1250+ t .Error ("Expected session to be marked as terminated" )
1251+ }
1252+ if err != nil {
1253+ t .Errorf ("Expected no error for terminated session, got: %v" , err )
1254+ }
1255+ })
1256+
1257+ t .Run ("Terminate is idempotent for non-existent session ID" , func (t * testing.T ) {
1258+ manager := & InsecureStatefulSessionIdManager {}
1259+ fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"
1260+
1261+ isNotAllowed , err := manager .Terminate (fakeSessionID )
1262+ if err != nil {
1263+ t .Errorf ("Expected no error when terminating non-existent session, got: %v" , err )
1264+ }
1265+ if isNotAllowed {
1266+ t .Error ("Expected isNotAllowed to be false" )
1267+ }
1268+ })
1269+
1270+ t .Run ("Terminate is idempotent for already-terminated session" , func (t * testing.T ) {
1271+ manager := & InsecureStatefulSessionIdManager {}
1272+ sessionID := manager .Generate ()
1273+
1274+ isNotAllowed , err := manager .Terminate (sessionID )
1275+ if err != nil {
1276+ t .Errorf ("Expected no error on first termination, got: %v" , err )
1277+ }
1278+ if isNotAllowed {
1279+ t .Error ("Expected termination to be allowed" )
1280+ }
1281+
1282+ isNotAllowed , err = manager .Terminate (sessionID )
1283+ if err != nil {
1284+ t .Errorf ("Expected no error on second termination (idempotent), got: %v" , err )
1285+ }
1286+ if isNotAllowed {
1287+ t .Error ("Expected termination to be allowed on retry" )
1288+ }
1289+ })
1290+
1291+ t .Run ("Concurrent generate and validate" , func (t * testing.T ) {
1292+ manager := & InsecureStatefulSessionIdManager {}
1293+ var wg sync.WaitGroup
1294+ sessionIDs := make ([]string , 100 )
1295+
1296+ for i := 0 ; i < 100 ; i ++ {
1297+ wg .Add (1 )
1298+ go func (index int ) {
1299+ defer wg .Done ()
1300+ sessionIDs [index ] = manager .Generate ()
1301+ }(i )
1302+ }
1303+
1304+ wg .Wait ()
1305+
1306+ for _ , sessionID := range sessionIDs {
1307+ isTerminated , err := manager .Validate (sessionID )
1308+ if err != nil {
1309+ t .Errorf ("Expected valid session ID %s, got error: %v" , sessionID , err )
1310+ }
1311+ if isTerminated {
1312+ t .Errorf ("Expected session %s to not be terminated" , sessionID )
1313+ }
1314+ }
1315+ })
1316+ }
0 commit comments