Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/steady-rivers-reinit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@modelcontextprotocol/client": patch
---

fix: reinitialize expired streamable sessions
91 changes: 53 additions & 38 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ export type ClientOptions = ProtocolOptions & {
listChanged?: ListChangedHandlers;
};

type SessionExpiringTransport = Transport & {
onsessionexpired?: () => void | Promise<void>;
};

/**
* An MCP client on top of a pluggable transport.
*
Expand Down Expand Up @@ -410,6 +414,13 @@ export class Client extends Protocol<ClientContext> {
*/
override async connect(transport: Transport, options?: RequestOptions): Promise<void> {
await super.connect(transport);
(transport as SessionExpiringTransport).onsessionexpired = async () => {
this._serverCapabilities = undefined;
this._serverVersion = undefined;
this._negotiatedProtocolVersion = undefined;
await this._initialize(transport, options);
};

// When transport sessionId is already set this means we are trying to reconnect.
// Restore the protocol version negotiated during the original initialize handshake
// so HTTP transports include the required mcp-protocol-version header, but skip re-init.
Expand All @@ -420,50 +431,54 @@ export class Client extends Protocol<ClientContext> {
return;
}
try {
const result = await this._requestWithSchema(
{
method: 'initialize',
params: {
protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION,
capabilities: this._capabilities,
clientInfo: this._clientInfo
}
},
InitializeResultSchema,
options
);
await this._initialize(transport, options);
} catch (error) {
// Disconnect if initialization fails.
void this.close();
throw error;
}
}

if (result === undefined) {
throw new Error(`Server sent invalid initialize result: ${result}`);
}
private async _initialize(transport: Transport, options?: RequestOptions): Promise<void> {
const result = await this._requestWithSchema(
{
method: 'initialize',
params: {
protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION,
capabilities: this._capabilities,
clientInfo: this._clientInfo
}
},
InitializeResultSchema,
options
);

if (!this._supportedProtocolVersions.includes(result.protocolVersion)) {
throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`);
}
if (result === undefined) {
throw new Error(`Server sent invalid initialize result: ${result}`);
}

this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
this._negotiatedProtocolVersion = result.protocolVersion;
// HTTP transports must set the protocol version in each header after initialization.
if (transport.setProtocolVersion) {
transport.setProtocolVersion(result.protocolVersion);
}
if (!this._supportedProtocolVersions.includes(result.protocolVersion)) {
throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`);
}

this._instructions = result.instructions;
this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
this._negotiatedProtocolVersion = result.protocolVersion;
// HTTP transports must set the protocol version in each header after initialization.
if (transport.setProtocolVersion) {
transport.setProtocolVersion(result.protocolVersion);
}

await this.notification({
method: 'notifications/initialized'
});
this._instructions = result.instructions;

// Set up list changed handlers now that we know server capabilities
if (this._pendingListChangedConfig) {
this._setupListChangedHandlers(this._pendingListChangedConfig);
this._pendingListChangedConfig = undefined;
}
} catch (error) {
// Disconnect if initialization fails.
void this.close();
throw error;
await this.notification({
method: 'notifications/initialized'
});

// Set up list changed handlers now that we know server capabilities
if (this._pendingListChangedConfig) {
this._setupListChangedHandlers(this._pendingListChangedConfig);
this._pendingListChangedConfig = undefined;
}
}

Expand Down
19 changes: 15 additions & 4 deletions packages/client/src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ export class StreamableHTTPClientTransport implements Transport {
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
onsessionexpired?: () => void | Promise<void>;

constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) {
this._url = url;
Expand Down Expand Up @@ -521,13 +522,14 @@ export class StreamableHTTPClientTransport implements Transport {
message: JSONRPCMessage | JSONRPCMessage[],
options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void }
): Promise<void> {
return this._send(message, options, false);
return this._send(message, options, false, false);
}

private async _send(
message: JSONRPCMessage | JSONRPCMessage[],
options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined,
isAuthRetry: boolean
isAuthRetry: boolean,
isSessionRetry: boolean
): Promise<void> {
try {
const { resumptionToken, onresumptiontoken } = options || {};
Expand Down Expand Up @@ -579,7 +581,7 @@ export class StreamableHTTPClientTransport implements Transport {
});
await response.text?.().catch(() => {});
// Purposely _not_ awaited, so we don't call onerror twice
return this._send(message, options, true);
return this._send(message, options, true, isSessionRetry);
}
await response.text?.().catch(() => {});
if (isAuthRetry) {
Expand All @@ -593,6 +595,15 @@ export class StreamableHTTPClientTransport implements Transport {

const text = await response.text?.().catch(() => null);

if (response.status === 404 && this._sessionId && !isSessionRetry) {
this._sessionId = undefined;
await this.onsessionexpired?.();

if (this._sessionId) {
return this._send(message, options, isAuthRetry, true);
}
}

if (response.status === 403 && this._oauthProvider) {
const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response);

Expand Down Expand Up @@ -629,7 +640,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError();
}

return this._send(message, options, isAuthRetry);
return this._send(message, options, isAuthRetry, isSessionRetry);
}
}

Expand Down
93 changes: 92 additions & 1 deletion packages/client/test/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core';
import { OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core';
import { LATEST_PROTOCOL_VERSION, OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core';
import type { Mock, Mocked } from 'vitest';

import type { OAuthClientProvider } from '../../src/client/auth.js';
import { UnauthorizedError } from '../../src/client/auth.js';
import { Client } from '../../src/client/client.js';
import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js';
import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js';

Expand Down Expand Up @@ -249,6 +250,96 @@ describe('StreamableHTTPClientTransport', () => {
expect(errorSpy).toHaveBeenCalled();
});

it('reinitializes and retries once when a persisted session expires', async () => {
const client = new Client({ name: 'test-client', version: '1.0.0' });
const httpTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'));
let initializeCount = 0;
let firstPing = true;

(globalThis.fetch as Mock).mockImplementation(async (_url, init) => {
if (init.method === 'GET') {
return {
ok: false,
status: 405,
statusText: 'Method Not Allowed',
headers: new Headers(),
text: async () => ''
};
}

const body = JSON.parse(init.body as string) as JSONRPCRequest;

if (body.method === 'initialize') {
const sessionId = initializeCount++ === 0 ? 'old-session-id' : 'new-session-id';

return {
ok: true,
status: 200,
headers: new Headers({
'content-type': 'application/json',
'mcp-session-id': sessionId
}),
json: async () => ({
jsonrpc: '2.0',
id: body.id,
result: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: {},
serverInfo: { name: 'test-server', version: '1.0.0' }
}
})
};
}

if (body.method === 'notifications/initialized') {
return {
ok: true,
status: 202,
headers: new Headers(),
text: async () => ''
};
}

if (body.method === 'ping' && firstPing) {
firstPing = false;

return {
ok: false,
status: 404,
statusText: 'Not Found',
headers: new Headers(),
text: async () => 'Session not found'
};
}

return {
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'application/json' }),
json: async () => ({
jsonrpc: '2.0',
id: body.id,
result: {}
})
};
});

try {
await client.connect(httpTransport);
await expect(client.ping()).resolves.toEqual({});

const calls = (globalThis.fetch as Mock).mock.calls;
const postCalls = calls.filter(([, init]) => init.method === 'POST');
expect(postCalls).toHaveLength(6);
expect(postCalls[2]![1].headers.get('mcp-session-id')).toBe('old-session-id');
expect(postCalls[3]![1].headers.get('mcp-session-id')).toBeNull();
expect(postCalls[5]![1].headers.get('mcp-session-id')).toBe('new-session-id');
expect(httpTransport.sessionId).toBe('new-session-id');
} finally {
await client.close().catch(() => {});
}
});

it('should handle non-streaming JSON response', async () => {
const message: JSONRPCMessage = {
jsonrpc: '2.0',
Expand Down
10 changes: 3 additions & 7 deletions test/e2e/requirements.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1641,7 +1641,8 @@ export const REQUIREMENTS: Record<string, Requirement> = {

'client-transport:http:404-surfaces': {
source: 'sdk',
behavior: 'A 404 (session expired) on a request surfaces as an error to the caller.',
behavior:
'A 404 (session expired) on a request surfaces as an error to the caller when the transport has no session recovery hook.',
transports: ['streamableHttp'],
note: 'Session-id continuity testing requires the per-session host (404 is session-not-found).'
},
Expand All @@ -1650,12 +1651,7 @@ export const REQUIREMENTS: Record<string, Requirement> = {
behavior:
'A 404 in response to a request carrying a session ID makes the client start a new session with a fresh InitializeRequest and no session ID attached.',
transports: ['streamableHttp'],
note: 'This exercises the StreamableHTTP client transport directly; the matrix transport arg is ignored, so it runs as a single streamableHttp-labelled cell to avoid duplicate runs.',
knownFailures: [
{
note: 'On a 404 for an existing session the transport throws StreamableHTTPError (streamableHttp.ts:551) and never re-initializes — no session recovery is attempted.'
}
]
note: 'This exercises the StreamableHTTP client transport directly; the matrix transport arg is ignored, so it runs as a single streamableHttp-labelled cell to avoid duplicate runs.'
},
'client-transport:http:accept-header-get': {
source: 'https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#listening-for-messages-from-the-server',
Expand Down
1 change: 1 addition & 0 deletions test/e2e/scenarios/transport-http.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ verifies('client-transport:http:404-surfaces', async (_args: TestArgs) => {

await client.connect(transport);
sessionIdToBreak = transport.sessionId;
transport.onsessionexpired = undefined;

const call = client.ping();
await expect(call).rejects.toThrow();
Expand Down
Loading