Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import logging
from typing import Any, Protocol, overload
from typing import Any, Protocol, TypeVar, overload

import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
from pydantic import AnyUrl, BaseModel, TypeAdapter
from typing_extensions import deprecated

import mcp.types as types
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session import BaseSession, MessageMetadata, ProgressFnT, RequestResponder
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'BaseSession' is not used.

Suggested change
from mcp.shared.session import BaseSession, MessageMetadata, ProgressFnT, RequestResponder
from mcp.shared.session import MessageMetadata, ProgressFnT, RequestResponder

Copilot uses AI. Check for mistakes.
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")

logger = logging.getLogger("client")

# TypeVar for generic result type in send_request (bound to BaseModel like in BaseSession)
_ResultT = TypeVar("_ResultT", bound=BaseModel)


class SamplingFnT(Protocol):
async def __call__(
Expand Down Expand Up @@ -195,6 +199,49 @@ async def initialize(self) -> types.InitializeResult:

return result

async def send_request(
self,
request: types.ClientRequest,
result_type: type[_ResultT],
request_read_timeout_seconds: float | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
*,
_session_recovery_attempted: bool = False,
) -> _ResultT:
"""Send a request with automatic session recovery on expiration.

Per MCP spec, when the server returns 404 indicating the session has
expired, the client MUST re-initialize the session and retry the request.

This override adds that automatic recovery behavior to the base
send_request method.
"""
try:
return await super().send_request(
request,
result_type,
request_read_timeout_seconds,
metadata,
progress_callback,
)
except McpError as e:
# Check if this is a session expired error
if e.error.code == types.SESSION_EXPIRED and not _session_recovery_attempted:
logger.info("Session expired, re-initializing...")
# Re-initialize the session
await self.initialize()
# Retry the original request (with flag to prevent infinite loops)
return await self.send_request(
request,
result_type,
request_read_timeout_seconds,
metadata,
progress_callback,
_session_recovery_attempted=True,
)
raise

def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""Return the server capabilities received during initialization.

Expand Down
42 changes: 36 additions & 6 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
SESSION_EXPIRED,
ErrorData,
InitializeResult,
JSONRPCError,
Expand Down Expand Up @@ -347,13 +348,25 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
logger.debug("Received 202 Accepted")
return

if response.status_code == 404: # pragma: no branch
if response.status_code == 404:
# Clear invalid session per MCP spec
self.session_id = None
self.protocol_version = None

if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover
if is_initialization: # pragma: no cover
# For initialization requests, session truly doesn't exist
await self._send_session_terminated_error(
ctx.read_stream_writer,
message.root.id,
)
else: # pragma: no cover
# For other requests, signal session expired for auto-recovery
await self._send_session_expired_error(
ctx.read_stream_writer,
message.root.id,
)
return

response.raise_for_status()
if is_initialization:
Expand Down Expand Up @@ -521,6 +534,23 @@ async def _send_session_terminated_error(
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
await read_stream_writer.send(session_message)

async def _send_session_expired_error( # pragma: no cover
self,
read_stream_writer: StreamWriter,
request_id: RequestId,
) -> None:
"""Send a session expired error response for auto-recovery."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(
code=SESSION_EXPIRED,
message="Session expired, re-initialization required",
),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
await read_stream_writer.send(session_message)

async def post_writer(
self,
client: httpx.AsyncClient,
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class JSONRPCResponse(BaseModel):
# SDK error codes
CONNECTION_CLOSED = -32000
# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this
SESSION_EXPIRED = -32002
"""Error code indicating the session has expired and needs re-initialization."""

# Standard JSON-RPC error codes
PARSE_ERROR = -32700
Expand Down
Loading