Skip to content

Commit 4a76607

Browse files
authored
fix: prevent tools invocation without valid session initialization (#607)
* fix: prevent tools invocation without valid session initialization - Modified InsecureStatefulSessionIdManager to track active sessions using sync.Map - Added session existence validation in addition to format validation - Implemented proper session termination tracking - Added comprehensive regression tests for session validation scenarios - Updated sampling tests to use stateless mode for compatibility This fixes a security issue where tools could be invoked with any well-formatted session ID without going through proper initialization, allowing unauthorized access. * fix: make session termination idempotent to prevent retry failures
1 parent d2b01f6 commit 4a76607

File tree

3 files changed

+326
-9
lines changed

3 files changed

+326
-9
lines changed

server/streamable_http.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,29 +1015,47 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo
10151015
return false, nil
10161016
}
10171017

1018-
// InsecureStatefulSessionIdManager generate id with uuid
1019-
// It won't validate the id indeed, so it could be fake.
1018+
// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions.
1019+
// It validates both format and existence of session IDs.
10201020
// For more secure session id, use a more complex generator, like a JWT.
1021-
type InsecureStatefulSessionIdManager struct{}
1021+
type InsecureStatefulSessionIdManager struct {
1022+
sessions sync.Map
1023+
terminated sync.Map
1024+
}
10221025

10231026
const idPrefix = "mcp-session-"
10241027

10251028
func (s *InsecureStatefulSessionIdManager) Generate() string {
1026-
return idPrefix + uuid.New().String()
1029+
sessionID := idPrefix + uuid.New().String()
1030+
s.sessions.Store(sessionID, true)
1031+
return sessionID
10271032
}
10281033

10291034
func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
1030-
// validate the session id is a valid uuid
10311035
if !strings.HasPrefix(sessionID, idPrefix) {
10321036
return false, fmt.Errorf("invalid session id: %s", sessionID)
10331037
}
10341038
if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
10351039
return false, fmt.Errorf("invalid session id: %s", sessionID)
10361040
}
1041+
if _, exists := s.terminated.Load(sessionID); exists {
1042+
return true, nil
1043+
}
1044+
if _, exists := s.sessions.Load(sessionID); !exists {
1045+
return false, fmt.Errorf("session not found: %s", sessionID)
1046+
}
10371047
return false, nil
10381048
}
10391049

10401050
func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
1051+
if _, exists := s.terminated.Load(sessionID); exists {
1052+
return false, nil
1053+
}
1054+
if _, exists := s.sessions.Load(sessionID); !exists {
1055+
return false, nil
1056+
}
1057+
s.terminated.Store(sessionID, true)
1058+
s.sessions.Delete(sessionID)
10411059
return false, nil
10421060
}
10431061

server/streamable_http_sampling_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
4545
mcpServer := NewMCPServer("test-server", "1.0.0")
4646
mcpServer.EnableSampling()
4747

48-
httpServer := NewStreamableHTTPServer(mcpServer)
48+
httpServer := NewStreamableHTTPServer(mcpServer, WithStateLess(true))
4949
testServer := httptest.NewServer(httpServer)
5050
defer testServer.Close()
5151

@@ -76,7 +76,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
7676
},
7777
{
7878
name: "invalid request ID",
79-
sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000",
79+
sessionID: "any-session-id",
8080
body: map[string]any{
8181
"jsonrpc": "2.0",
8282
"id": "invalid-id",
@@ -92,13 +92,13 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
9292
},
9393
{
9494
name: "malformed result",
95-
sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000",
95+
sessionID: "any-session-id",
9696
body: map[string]any{
9797
"jsonrpc": "2.0",
9898
"id": 1,
9999
"result": "invalid-result",
100100
},
101-
expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session
101+
expectedStatus: http.StatusInternalServerError,
102102
},
103103
}
104104

server/streamable_http_test.go

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,302 @@ func postJSON(url string, bodyObject any) (*http.Response, error) {
10151015
req.Header.Set("Content-Type", "application/json")
10161016
return http.DefaultClient.Do(req)
10171017
}
1018+
1019+
func TestStreamableHTTP_SessionValidation(t *testing.T) {
1020+
mcpServer := NewMCPServer("test-server", "1.0.0")
1021+
mcpServer.AddTool(mcp.NewTool("time",
1022+
mcp.WithDescription("Get the current time")), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1023+
return mcp.NewToolResultText("2024-01-01T00:00:00Z"), nil
1024+
})
1025+
1026+
server := NewTestStreamableHTTPServer(mcpServer)
1027+
defer server.Close()
1028+
1029+
t.Run("Reject tool call with fake session ID", func(t *testing.T) {
1030+
toolCallRequest := map[string]any{
1031+
"jsonrpc": "2.0",
1032+
"id": 1,
1033+
"method": "tools/call",
1034+
"params": map[string]any{
1035+
"name": "time",
1036+
},
1037+
}
1038+
1039+
jsonBody, _ := json.Marshal(toolCallRequest)
1040+
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1041+
req.Header.Set("Content-Type", "application/json")
1042+
req.Header.Set(HeaderKeySessionID, "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff")
1043+
1044+
resp, err := server.Client().Do(req)
1045+
if err != nil {
1046+
t.Fatalf("Failed to send request: %v", err)
1047+
}
1048+
defer resp.Body.Close()
1049+
1050+
if resp.StatusCode != http.StatusBadRequest {
1051+
t.Errorf("Expected status 400, got %d", resp.StatusCode)
1052+
}
1053+
1054+
body, _ := io.ReadAll(resp.Body)
1055+
if !strings.Contains(string(body), "Invalid session ID") {
1056+
t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body))
1057+
}
1058+
})
1059+
1060+
t.Run("Reject tool call with malformed session ID", func(t *testing.T) {
1061+
toolCallRequest := map[string]any{
1062+
"jsonrpc": "2.0",
1063+
"id": 1,
1064+
"method": "tools/call",
1065+
"params": map[string]any{
1066+
"name": "time",
1067+
},
1068+
}
1069+
1070+
jsonBody, _ := json.Marshal(toolCallRequest)
1071+
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1072+
req.Header.Set("Content-Type", "application/json")
1073+
req.Header.Set(HeaderKeySessionID, "invalid-session-id")
1074+
1075+
resp, err := server.Client().Do(req)
1076+
if err != nil {
1077+
t.Fatalf("Failed to send request: %v", err)
1078+
}
1079+
defer resp.Body.Close()
1080+
1081+
if resp.StatusCode != http.StatusBadRequest {
1082+
t.Errorf("Expected status 400, got %d", resp.StatusCode)
1083+
}
1084+
1085+
body, _ := io.ReadAll(resp.Body)
1086+
if !strings.Contains(string(body), "Invalid session ID") {
1087+
t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body))
1088+
}
1089+
})
1090+
1091+
t.Run("Accept tool call with valid session ID from initialize", func(t *testing.T) {
1092+
jsonBody, _ := json.Marshal(initRequest)
1093+
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1094+
req.Header.Set("Content-Type", "application/json")
1095+
1096+
resp, err := server.Client().Do(req)
1097+
if err != nil {
1098+
t.Fatalf("Failed to initialize: %v", err)
1099+
}
1100+
defer resp.Body.Close()
1101+
1102+
sessionID := resp.Header.Get(HeaderKeySessionID)
1103+
if sessionID == "" {
1104+
t.Fatal("Expected session ID in response header")
1105+
}
1106+
1107+
toolCallRequest := map[string]any{
1108+
"jsonrpc": "2.0",
1109+
"id": 2,
1110+
"method": "tools/call",
1111+
"params": map[string]any{
1112+
"name": "time",
1113+
},
1114+
}
1115+
1116+
jsonBody, _ = json.Marshal(toolCallRequest)
1117+
req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1118+
req.Header.Set("Content-Type", "application/json")
1119+
req.Header.Set(HeaderKeySessionID, sessionID)
1120+
1121+
resp, err = server.Client().Do(req)
1122+
if err != nil {
1123+
t.Fatalf("Failed to call tool: %v", err)
1124+
}
1125+
defer resp.Body.Close()
1126+
1127+
if resp.StatusCode != http.StatusOK {
1128+
body, _ := io.ReadAll(resp.Body)
1129+
t.Errorf("Expected status 200, got %d. Body: %s", resp.StatusCode, string(body))
1130+
}
1131+
})
1132+
1133+
t.Run("Reject tool call with terminated session ID", func(t *testing.T) {
1134+
jsonBody, _ := json.Marshal(initRequest)
1135+
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1136+
req.Header.Set("Content-Type", "application/json")
1137+
1138+
resp, err := server.Client().Do(req)
1139+
if err != nil {
1140+
t.Fatalf("Failed to initialize: %v", err)
1141+
}
1142+
resp.Body.Close()
1143+
1144+
sessionID := resp.Header.Get(HeaderKeySessionID)
1145+
if sessionID == "" {
1146+
t.Fatal("Expected session ID in response header")
1147+
}
1148+
1149+
req, _ = http.NewRequest(http.MethodDelete, server.URL, nil)
1150+
req.Header.Set(HeaderKeySessionID, sessionID)
1151+
1152+
resp, err = server.Client().Do(req)
1153+
if err != nil {
1154+
t.Fatalf("Failed to terminate session: %v", err)
1155+
}
1156+
resp.Body.Close()
1157+
1158+
if resp.StatusCode != http.StatusOK {
1159+
t.Errorf("Expected status 200 for termination, got %d", resp.StatusCode)
1160+
}
1161+
1162+
toolCallRequest := map[string]any{
1163+
"jsonrpc": "2.0",
1164+
"id": 2,
1165+
"method": "tools/call",
1166+
"params": map[string]any{
1167+
"name": "time",
1168+
},
1169+
}
1170+
1171+
jsonBody, _ = json.Marshal(toolCallRequest)
1172+
req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
1173+
req.Header.Set("Content-Type", "application/json")
1174+
req.Header.Set(HeaderKeySessionID, sessionID)
1175+
1176+
resp, err = server.Client().Do(req)
1177+
if err != nil {
1178+
t.Fatalf("Failed to send request: %v", err)
1179+
}
1180+
defer resp.Body.Close()
1181+
1182+
if resp.StatusCode != http.StatusNotFound {
1183+
body, _ := io.ReadAll(resp.Body)
1184+
t.Errorf("Expected status 404, got %d. Body: %s", resp.StatusCode, string(body))
1185+
}
1186+
})
1187+
}
1188+
1189+
func TestInsecureStatefulSessionIdManager(t *testing.T) {
1190+
t.Run("Generate creates valid session ID", func(t *testing.T) {
1191+
manager := &InsecureStatefulSessionIdManager{}
1192+
sessionID := manager.Generate()
1193+
1194+
if !strings.HasPrefix(sessionID, idPrefix) {
1195+
t.Errorf("Expected session ID to start with %s, got %s", idPrefix, sessionID)
1196+
}
1197+
1198+
isTerminated, err := manager.Validate(sessionID)
1199+
if err != nil {
1200+
t.Errorf("Expected valid session ID, got error: %v", err)
1201+
}
1202+
if isTerminated {
1203+
t.Error("Expected session to not be terminated")
1204+
}
1205+
})
1206+
1207+
t.Run("Validate rejects non-existent session ID", func(t *testing.T) {
1208+
manager := &InsecureStatefulSessionIdManager{}
1209+
fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"
1210+
1211+
isTerminated, err := manager.Validate(fakeSessionID)
1212+
if err == nil {
1213+
t.Error("Expected error for non-existent session ID")
1214+
}
1215+
if isTerminated {
1216+
t.Error("Expected isTerminated to be false for invalid session")
1217+
}
1218+
if !strings.Contains(err.Error(), "session not found") {
1219+
t.Errorf("Expected 'session not found' error, got: %v", err)
1220+
}
1221+
})
1222+
1223+
t.Run("Validate rejects malformed session ID", func(t *testing.T) {
1224+
manager := &InsecureStatefulSessionIdManager{}
1225+
invalidSessionID := "invalid-session-id"
1226+
1227+
_, err := manager.Validate(invalidSessionID)
1228+
if err == nil {
1229+
t.Error("Expected error for malformed session ID")
1230+
}
1231+
if !strings.Contains(err.Error(), "invalid session id") {
1232+
t.Errorf("Expected 'invalid session id' error, got: %v", err)
1233+
}
1234+
})
1235+
1236+
t.Run("Terminate marks session as terminated", func(t *testing.T) {
1237+
manager := &InsecureStatefulSessionIdManager{}
1238+
sessionID := manager.Generate()
1239+
1240+
isNotAllowed, err := manager.Terminate(sessionID)
1241+
if err != nil {
1242+
t.Errorf("Expected no error on termination, got: %v", err)
1243+
}
1244+
if isNotAllowed {
1245+
t.Error("Expected termination to be allowed")
1246+
}
1247+
1248+
isTerminated, err := manager.Validate(sessionID)
1249+
if !isTerminated {
1250+
t.Error("Expected session to be marked as terminated")
1251+
}
1252+
if err != nil {
1253+
t.Errorf("Expected no error for terminated session, got: %v", err)
1254+
}
1255+
})
1256+
1257+
t.Run("Terminate is idempotent for non-existent session ID", func(t *testing.T) {
1258+
manager := &InsecureStatefulSessionIdManager{}
1259+
fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"
1260+
1261+
isNotAllowed, err := manager.Terminate(fakeSessionID)
1262+
if err != nil {
1263+
t.Errorf("Expected no error when terminating non-existent session, got: %v", err)
1264+
}
1265+
if isNotAllowed {
1266+
t.Error("Expected isNotAllowed to be false")
1267+
}
1268+
})
1269+
1270+
t.Run("Terminate is idempotent for already-terminated session", func(t *testing.T) {
1271+
manager := &InsecureStatefulSessionIdManager{}
1272+
sessionID := manager.Generate()
1273+
1274+
isNotAllowed, err := manager.Terminate(sessionID)
1275+
if err != nil {
1276+
t.Errorf("Expected no error on first termination, got: %v", err)
1277+
}
1278+
if isNotAllowed {
1279+
t.Error("Expected termination to be allowed")
1280+
}
1281+
1282+
isNotAllowed, err = manager.Terminate(sessionID)
1283+
if err != nil {
1284+
t.Errorf("Expected no error on second termination (idempotent), got: %v", err)
1285+
}
1286+
if isNotAllowed {
1287+
t.Error("Expected termination to be allowed on retry")
1288+
}
1289+
})
1290+
1291+
t.Run("Concurrent generate and validate", func(t *testing.T) {
1292+
manager := &InsecureStatefulSessionIdManager{}
1293+
var wg sync.WaitGroup
1294+
sessionIDs := make([]string, 100)
1295+
1296+
for i := 0; i < 100; i++ {
1297+
wg.Add(1)
1298+
go func(index int) {
1299+
defer wg.Done()
1300+
sessionIDs[index] = manager.Generate()
1301+
}(i)
1302+
}
1303+
1304+
wg.Wait()
1305+
1306+
for _, sessionID := range sessionIDs {
1307+
isTerminated, err := manager.Validate(sessionID)
1308+
if err != nil {
1309+
t.Errorf("Expected valid session ID %s, got error: %v", sessionID, err)
1310+
}
1311+
if isTerminated {
1312+
t.Errorf("Expected session %s to not be terminated", sessionID)
1313+
}
1314+
}
1315+
})
1316+
}

0 commit comments

Comments
 (0)