Skip to content

Commit 789fa0b

Browse files
committed
fix(mcp): harden notification system against race conditions
- Guard concurrent connect() calls in connection manager with connectingServers Set - Suppress post-disconnect notification handler firing in MCP client - Clean up Redis event listeners in pub/sub dispose() - Add tests for all three hardening fixes (11 new tests)
1 parent 4193007 commit 789fa0b

File tree

7 files changed

+1059
-18
lines changed

7 files changed

+1059
-18
lines changed

apps/sim/lib/mcp/client.test.ts

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/**
2+
* @vitest-environment node
3+
*/
4+
import { loggerMock } from '@sim/testing'
5+
import { beforeEach, describe, expect, it, vi } from 'vitest'
6+
7+
vi.mock('@sim/logger', () => loggerMock)
8+
9+
/**
10+
* Capture the notification handler registered via `client.setNotificationHandler()`.
11+
* This lets us simulate the MCP SDK delivering a `tools/list_changed` notification.
12+
*/
13+
let capturedNotificationHandler: (() => Promise<void>) | null = null
14+
15+
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
16+
Client: vi.fn().mockImplementation(() => ({
17+
connect: vi.fn().mockResolvedValue(undefined),
18+
close: vi.fn().mockResolvedValue(undefined),
19+
getServerVersion: vi.fn().mockReturnValue('2025-06-18'),
20+
getServerCapabilities: vi.fn().mockReturnValue({ tools: { listChanged: true } }),
21+
setNotificationHandler: vi
22+
.fn()
23+
.mockImplementation((_schema: unknown, handler: () => Promise<void>) => {
24+
capturedNotificationHandler = handler
25+
}),
26+
listTools: vi.fn().mockResolvedValue({ tools: [] }),
27+
})),
28+
}))
29+
30+
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
31+
StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({
32+
onclose: null,
33+
sessionId: 'test-session',
34+
})),
35+
}))
36+
37+
vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
38+
ToolListChangedNotificationSchema: { method: 'notifications/tools/list_changed' },
39+
}))
40+
41+
vi.mock('@/lib/core/execution-limits', () => ({
42+
getMaxExecutionTimeout: vi.fn().mockReturnValue(30000),
43+
}))
44+
45+
import { McpClient } from './client'
46+
import type { McpServerConfig } from './types'
47+
48+
function createConfig(): McpServerConfig {
49+
return {
50+
id: 'server-1',
51+
name: 'Test Server',
52+
transport: 'streamable-http',
53+
url: 'https://test.example.com/mcp',
54+
}
55+
}
56+
57+
describe('McpClient notification handler', () => {
58+
beforeEach(() => {
59+
capturedNotificationHandler = null
60+
})
61+
62+
it('fires onToolsChanged when a notification arrives while connected', async () => {
63+
const onToolsChanged = vi.fn()
64+
65+
const client = new McpClient({
66+
config: createConfig(),
67+
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
68+
onToolsChanged,
69+
})
70+
71+
await client.connect()
72+
73+
expect(capturedNotificationHandler).not.toBeNull()
74+
75+
await capturedNotificationHandler!()
76+
77+
expect(onToolsChanged).toHaveBeenCalledTimes(1)
78+
expect(onToolsChanged).toHaveBeenCalledWith('server-1')
79+
})
80+
81+
it('suppresses notifications after disconnect', async () => {
82+
const onToolsChanged = vi.fn()
83+
84+
const client = new McpClient({
85+
config: createConfig(),
86+
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
87+
onToolsChanged,
88+
})
89+
90+
await client.connect()
91+
expect(capturedNotificationHandler).not.toBeNull()
92+
93+
await client.disconnect()
94+
95+
// Simulate a late notification arriving after disconnect
96+
await capturedNotificationHandler!()
97+
98+
expect(onToolsChanged).not.toHaveBeenCalled()
99+
})
100+
101+
it('does not register a notification handler when onToolsChanged is not provided', async () => {
102+
const client = new McpClient({
103+
config: createConfig(),
104+
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
105+
})
106+
107+
await client.connect()
108+
109+
expect(capturedNotificationHandler).toBeNull()
110+
})
111+
})

apps/sim/lib/mcp/client.ts

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010

1111
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
1212
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
13-
import type { ListToolsResult, Tool } from '@modelcontextprotocol/sdk/types.js'
13+
import {
14+
type ListToolsResult,
15+
type Tool,
16+
ToolListChangedNotificationSchema,
17+
} from '@modelcontextprotocol/sdk/types.js'
1418
import { createLogger } from '@sim/logger'
1519
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
1620
import {
21+
type McpClientOptions,
1722
McpConnectionError,
1823
type McpConnectionStatus,
1924
type McpConsentRequest,
@@ -24,6 +29,7 @@ import {
2429
type McpTool,
2530
type McpToolCall,
2631
type McpToolResult,
32+
type McpToolsChangedCallback,
2733
type McpVersionInfo,
2834
} from '@/lib/mcp/types'
2935

@@ -35,6 +41,7 @@ export class McpClient {
3541
private config: McpServerConfig
3642
private connectionStatus: McpConnectionStatus
3743
private securityPolicy: McpSecurityPolicy
44+
private onToolsChanged?: McpToolsChangedCallback
3845
private isConnected = false
3946

4047
private static readonly SUPPORTED_VERSIONS = [
@@ -44,23 +51,36 @@ export class McpClient {
4451
]
4552

4653
/**
47-
* Creates a new MCP client
48-
*
49-
* No session ID parameter (we disconnect after each operation).
50-
* The SDK handles session management automatically via Mcp-Session-Id header.
54+
* Creates a new MCP client.
5155
*
52-
* @param config - Server configuration
53-
* @param securityPolicy - Optional security policy
56+
* Accepts either the legacy (config, securityPolicy?) signature
57+
* or a single McpClientOptions object with an optional onToolsChanged callback.
5458
*/
55-
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) {
56-
this.config = config
57-
this.connectionStatus = { connected: false }
58-
this.securityPolicy = securityPolicy ?? {
59-
requireConsent: true,
60-
auditLevel: 'basic',
61-
maxToolExecutionsPerHour: 1000,
59+
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy)
60+
constructor(options: McpClientOptions)
61+
constructor(
62+
configOrOptions: McpServerConfig | McpClientOptions,
63+
securityPolicy?: McpSecurityPolicy
64+
) {
65+
if ('config' in configOrOptions) {
66+
this.config = configOrOptions.config
67+
this.securityPolicy = configOrOptions.securityPolicy ?? {
68+
requireConsent: true,
69+
auditLevel: 'basic',
70+
maxToolExecutionsPerHour: 1000,
71+
}
72+
this.onToolsChanged = configOrOptions.onToolsChanged
73+
} else {
74+
this.config = configOrOptions
75+
this.securityPolicy = securityPolicy ?? {
76+
requireConsent: true,
77+
auditLevel: 'basic',
78+
maxToolExecutionsPerHour: 1000,
79+
}
6280
}
6381

82+
this.connectionStatus = { connected: false }
83+
6484
if (!this.config.url) {
6585
throw new McpError('URL required for Streamable HTTP transport')
6686
}
@@ -79,16 +99,15 @@ export class McpClient {
7999
{
80100
capabilities: {
81101
tools: {},
82-
// Resources and prompts can be added later
83-
// resources: {},
84-
// prompts: {},
85102
},
86103
}
87104
)
88105
}
89106

90107
/**
91-
* Initialize connection to MCP server
108+
* Initialize connection to MCP server.
109+
* If an `onToolsChanged` callback was provided, registers a notification handler
110+
* for `notifications/tools/list_changed` after connecting.
92111
*/
93112
async connect(): Promise<void> {
94113
logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`)
@@ -100,6 +119,15 @@ export class McpClient {
100119
this.connectionStatus.connected = true
101120
this.connectionStatus.lastConnected = new Date()
102121

122+
if (this.onToolsChanged) {
123+
this.client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
124+
if (!this.isConnected) return
125+
logger.info(`[${this.config.name}] Received tools/list_changed notification`)
126+
this.onToolsChanged?.(this.config.id)
127+
})
128+
logger.info(`[${this.config.name}] Registered tools/list_changed notification handler`)
129+
}
130+
103131
const serverVersion = this.client.getServerVersion()
104132
logger.info(`Successfully connected to MCP server: ${this.config.name}`, {
105133
protocolVersion: serverVersion,
@@ -241,6 +269,23 @@ export class McpClient {
241269
return !!serverCapabilities?.[capability]
242270
}
243271

272+
/**
273+
* Check if the server declared `capabilities.tools.listChanged: true` during initialization.
274+
*/
275+
hasListChangedCapability(): boolean {
276+
const caps = this.client.getServerCapabilities()
277+
const toolsCap = caps?.tools as Record<string, unknown> | undefined
278+
return !!toolsCap?.listChanged
279+
}
280+
281+
/**
282+
* Register a callback to be invoked when the underlying transport closes.
283+
* Used by the connection manager for reconnection logic.
284+
*/
285+
onClose(callback: () => void): void {
286+
this.transport.onclose = callback
287+
}
288+
244289
/**
245290
* Get server configuration
246291
*/

0 commit comments

Comments
 (0)