diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a7d03f87c..aa06134ed 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,23 +1,27 @@ import logging -from typing import Any, Protocol, overload +from typing import Any, Protocol, TypeVar, overload import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import AnyUrl, BaseModel, TypeAdapter from typing_extensions import deprecated import mcp.types as types from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import BaseSession, MessageMetadata, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") +# TypeVar for generic result type in send_request (bound to BaseModel like in BaseSession) +_ResultT = TypeVar("_ResultT", bound=BaseModel) + class SamplingFnT(Protocol): async def __call__( @@ -195,6 +199,49 @@ async def initialize(self) -> types.InitializeResult: return result + async def send_request( + self, + request: types.ClientRequest, + result_type: type[_ResultT], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + *, + _session_recovery_attempted: bool = False, + ) -> _ResultT: + """Send a request with automatic session recovery on expiration. + + Per MCP spec, when the server returns 404 indicating the session has + expired, the client MUST re-initialize the session and retry the request. + + This override adds that automatic recovery behavior to the base + send_request method. + """ + try: + return await super().send_request( + request, + result_type, + request_read_timeout_seconds, + metadata, + progress_callback, + ) + except McpError as e: + # Check if this is a session expired error + if e.error.code == types.SESSION_EXPIRED and not _session_recovery_attempted: + logger.info("Session expired, re-initializing...") + # Re-initialize the session + await self.initialize() + # Retry the original request (with flag to prevent infinite loops) + return await self.send_request( + request, + result_type, + request_read_timeout_seconds, + metadata, + progress_callback, + _session_recovery_attempted=True, + ) + raise + def get_server_capabilities(self) -> types.ServerCapabilities | None: """Return the server capabilities received during initialization. diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 22645d3ba..cf0f75a3e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -28,6 +28,7 @@ ) from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( + SESSION_EXPIRED, ErrorData, InitializeResult, JSONRPCError, @@ -347,13 +348,25 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch + if response.status_code == 404: + # Clear invalid session per MCP spec + self.session_id = None + self.protocol_version = None + if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - message.root.id, # pragma: no cover - ) # pragma: no cover - return # pragma: no cover + if is_initialization: # pragma: no cover + # For initialization requests, session truly doesn't exist + await self._send_session_terminated_error( + ctx.read_stream_writer, + message.root.id, + ) + else: # pragma: no cover + # For other requests, signal session expired for auto-recovery + await self._send_session_expired_error( + ctx.read_stream_writer, + message.root.id, + ) + return response.raise_for_status() if is_initialization: @@ -521,6 +534,23 @@ async def _send_session_terminated_error( session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) await read_stream_writer.send(session_message) + async def _send_session_expired_error( # pragma: no cover + self, + read_stream_writer: StreamWriter, + request_id: RequestId, + ) -> None: + """Send a session expired error response for auto-recovery.""" + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData( + code=SESSION_EXPIRED, + message="Session expired, re-initialization required", + ), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + async def post_writer( self, client: httpx.AsyncClient, diff --git a/src/mcp/types.py b/src/mcp/types.py index 654c00660..19eb98946 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -181,6 +181,8 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 # REQUEST_TIMEOUT = -32001 # the typescript sdk uses this +SESSION_EXPIRED = -32002 +"""Error code indicating the session has expired and needs re-initialization.""" # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/client/test_session_recovery.py b/tests/client/test_session_recovery.py new file mode 100644 index 000000000..53398ff86 --- /dev/null +++ b/tests/client/test_session_recovery.py @@ -0,0 +1,583 @@ +"""Tests for automatic session recovery on 404/SESSION_EXPIRED errors. + +Per MCP spec, when a client receives HTTP 404 in response to a request containing +an MCP-Session-Id, it MUST start a new session by sending a new InitializeRequest +without a session ID attached. +""" + +from typing import Any + +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + LATEST_PROTOCOL_VERSION, + SESSION_EXPIRED, + CallToolResult, + ClientRequest, + ErrorData, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ServerResult, + TextContent, + Tool, +) + + +@pytest.mark.anyio +async def test_session_recovery_on_expired_error(): + """Test that client re-initializes session when receiving SESSION_EXPIRED error.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + + init_count = 0 + tool_call_count = 0 + + async def mock_server(): + nonlocal init_count, tool_call_count + + # First initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + init_count += 1 + + # Send init response + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive tool call request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "tools/call" + tool_call_count += 1 + + # Send SESSION_EXPIRED error (simulating 404 from transport) + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCError( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + error=ErrorData( + code=SESSION_EXPIRED, + message="Session expired, re-initialization required", + ), + ) + ) + ) + ) + + # Should receive second initialization request (automatic recovery) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + init_count += 1 + + # Send second init response + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive retried tool call + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "tools/call" + tool_call_count += 1 + + # Send successful response this time + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + CallToolResult( + content=[TextContent(type="text", text="Success!")], + isError=False, + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # call_tool validates result by calling list_tools + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "tools/list" + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=types.ListToolsResult(tools=[Tool(name="test_tool", inputSchema={})]).model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # This should trigger SESSION_EXPIRED, then auto-reinit, then retry + result = await session.call_tool("test_tool", {"foo": "bar"}) + + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Success!" + + # Verify: 2 initializations (original + recovery), 2 tool calls (failed + retried) + assert init_count == 2 + assert tool_call_count == 2 + + +@pytest.mark.anyio +async def test_no_infinite_retry_loop_on_repeated_session_expired(): + """Test that client doesn't loop infinitely when session keeps expiring.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + + init_count = 0 + + async def mock_server(): + nonlocal init_count + + # First initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + init_count += 1 + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive tool call + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + # Send SESSION_EXPIRED + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCError( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + error=ErrorData( + code=SESSION_EXPIRED, + message="Session expired", + ), + ) + ) + ) + ) + + # Second initialization (automatic recovery) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + init_count += 1 + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive retried tool call + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + # Send SESSION_EXPIRED AGAIN - should NOT trigger another reinit + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCError( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + error=ErrorData( + code=SESSION_EXPIRED, + message="Session expired again", + ), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # Should raise McpError after retry fails (no infinite loop) + with pytest.raises(McpError) as exc_info: + await session.call_tool("test_tool", {}) + + assert exc_info.value.error.code == SESSION_EXPIRED + + # Only 2 initializations: original + one recovery attempt + assert init_count == 2 + + +@pytest.mark.anyio +async def test_non_session_expired_error_not_retried(): + """Test that other MCP errors don't trigger session recovery.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + + init_count = 0 + + async def mock_server(): + nonlocal init_count + + # Initial initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + init_count += 1 + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive tool call + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + # Send a different error (CONNECTION_CLOSED) + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCError( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + error=ErrorData( + code=CONNECTION_CLOSED, + message="Connection closed", + ), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # Should raise McpError directly without recovery attempt + with pytest.raises(McpError) as exc_info: + await session.call_tool("test_tool", {}) + + assert exc_info.value.error.code == CONNECTION_CLOSED + + # Only 1 initialization - no recovery triggered + assert init_count == 1 + + +@pytest.mark.anyio +async def test_session_recovery_preserves_request_data(): + """Test that the original request data is preserved through recovery.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + + tool_params_received: list[dict[str, Any]] = [] + + async def mock_server(): + # First initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive first tool call + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.params is not None + tool_params_received.append(jsonrpc_request.root.params) + + # Send SESSION_EXPIRED + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCError( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + error=ErrorData( + code=SESSION_EXPIRED, + message="Session expired", + ), + ) + ) + ) + ) + + # Second initialization (recovery) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Receive retried tool call - should have same params + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.params is not None + tool_params_received.append(jsonrpc_request.root.params) + + # Send success + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + CallToolResult( + content=[TextContent(type="text", text="Done")], + isError=False, + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # call_tool validates result by calling list_tools + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "tools/list" + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=types.ListToolsResult(tools=[Tool(name="important_tool", inputSchema={})]).model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # Call with specific arguments + await session.call_tool("important_tool", {"key": "sensitive_value", "count": 42}) + + # Both tool calls should have identical parameters + assert len(tool_params_received) == 2 + assert tool_params_received[0] == tool_params_received[1] + assert tool_params_received[0]["name"] == "important_tool" + assert tool_params_received[0]["arguments"] == {"key": "sensitive_value", "count": 42}