Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 47 additions & 30 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -429,6 +429,37 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
return False
return True

def _validate_init_session(self, request: Request) -> bool:
"""Check if an initialization request has a valid session ID."""
if not self.mcp_session_id:
return True
request_session_id = self._get_session_id(request)
if request_session_id and request_session_id != self.mcp_session_id:
return False
return True

async def _parse_jsonrpc_body(
self, body: bytes, scope: Scope, receive: Receive, send: Send
) -> JSONRPCMessage | None:
"""Parse request body into a JSON-RPC message, sending error responses on failure."""
try:
raw_message = pydantic_core.from_json(body)
except ValueError as e:
response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
await response(scope, receive, send)
return None

try:
return jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
INVALID_PARAMS,
)
await response(scope, receive, send)
return None

async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
"""Handle POST requests containing JSON-RPC messages."""
writer = self._read_stream_writer
Expand All @@ -451,41 +482,21 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
# Parse the body - only read it once
body = await request.body()

try:
raw_message = pydantic_core.from_json(body)
except ValueError as e:
response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
await response(scope, receive, send)
return

try:
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
INVALID_PARAMS,
)
await response(scope, receive, send)
message = await self._parse_jsonrpc_body(body, scope, receive, send)
if message is None:
return

# Check if this is an initialization request
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)

# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
if not self._validate_init_session(request): # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
elif not await self._validate_request_headers(request, send): # pragma: no cover
return

Expand Down Expand Up @@ -626,6 +637,12 @@ async def sse_writer(): # pragma: lax no cover
finally:
await sse_stream_reader.aclose()

except ClientDisconnect: # pragma: no cover
logger.info("Client disconnected during POST request")
if writer:
await writer.send(Exception("Client disconnected"))
return

except Exception as err: # pragma: no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
Expand Down
Loading