diff --git a/.gitignore b/.gitignore index a1b83bc4f..58cff53ff 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,5 @@ dist/ # IDE .idea/ + +# ahammednibras8 \ No newline at end of file diff --git a/docs/server.md b/docs/server.md index 4d5138e84..080b36d4e 100644 --- a/docs/server.md +++ b/docs/server.md @@ -1,6 +1,8 @@ ## Server overview -This SDK lets you build MCP servers in TypeScript and connect them to different transports. For most use cases you will use `McpServer` from `@modelcontextprotocol/server` and choose one of: +This SDK lets you build MCP servers in TypeScript and connect them to different +transports. For most use cases you will use `McpServer` from +`@modelcontextprotocol/server` and choose one of: - **Streamable HTTP** (recommended for remote servers) - **HTTP + SSE** (deprecated, for backwards compatibility only) @@ -8,11 +10,16 @@ This SDK lets you build MCP servers in TypeScript and connect them to different For a complete, runnable example server, see: -- [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) – feature‑rich Streamable HTTP server -- [`jsonResponseStreamableHttp.ts`](../examples/server/src/jsonResponseStreamableHttp.ts) – Streamable HTTP with JSON response mode -- [`simpleStatelessStreamableHttp.ts`](../examples/server/src/simpleStatelessStreamableHttp.ts) – stateless Streamable HTTP server -- [`simpleSseServer.ts`](../examples/server/src/simpleSseServer.ts) – deprecated HTTP+SSE transport -- [`sseAndStreamableHttpCompatibleServer.ts`](../examples/server/src/sseAndStreamableHttpCompatibleServer.ts) – backwards‑compatible server for old and new clients +- [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) – + feature‑rich Streamable HTTP server +- [`jsonResponseStreamableHttp.ts`](../examples/server/src/jsonResponseStreamableHttp.ts) + – Streamable HTTP with JSON response mode +- [`simpleStatelessStreamableHttp.ts`](../examples/server/src/simpleStatelessStreamableHttp.ts) + – stateless Streamable HTTP server +- [`simpleSseServer.ts`](../examples/server/src/simpleSseServer.ts) – deprecated + HTTP+SSE transport +- [`sseAndStreamableHttpCompatibleServer.ts`](../examples/server/src/sseAndStreamableHttpCompatibleServer.ts) + – backwards‑compatible server for old and new clients ## Transports @@ -27,69 +34,122 @@ Streamable HTTP is the modern, fully featured transport. It supports: Key examples: -- [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) – sessions, logging, tasks, elicitation, auth hooks -- [`jsonResponseStreamableHttp.ts`](../examples/server/src/jsonResponseStreamableHttp.ts) – `enableJsonResponse: true`, no SSE -- [`standaloneSseWithGetStreamableHttp.ts`](../examples/server/src/standaloneSseWithGetStreamableHttp.ts) – notifications with Streamable HTTP GET + SSE +- [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) – + sessions, logging, tasks, elicitation, auth hooks +- [`jsonResponseStreamableHttp.ts`](../examples/server/src/jsonResponseStreamableHttp.ts) + – `enableJsonResponse: true`, no SSE +- [`standaloneSseWithGetStreamableHttp.ts`](../examples/server/src/standaloneSseWithGetStreamableHttp.ts) + – notifications with Streamable HTTP GET + SSE -See the MCP spec for full transport details: `https://modelcontextprotocol.io/specification/2025-11-25/basic/transports` +See the MCP spec for full transport details: +`https://modelcontextprotocol.io/specification/2025-11-25/basic/transports` ### Stateless vs stateful sessions Streamable HTTP can run: - **Stateless** – no session tracking, ideal for simple API‑style servers. -- **Stateful** – sessions have IDs, and you can enable resumability and advanced features. +- **Stateful** – sessions have IDs, and you can enable resumability and advanced + features. Examples: -- Stateless Streamable HTTP: [`simpleStatelessStreamableHttp.ts`](../examples/server/src/simpleStatelessStreamableHttp.ts) -- Stateful with resumability: [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) +- Stateless Streamable HTTP: + [`simpleStatelessStreamableHttp.ts`](../examples/server/src/simpleStatelessStreamableHttp.ts) +- Stateful with resumability: + [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) ### Deprecated HTTP + SSE -The older HTTP+SSE transport (protocol version 2024‑11‑05) is supported only for backwards compatibility. New implementations should prefer Streamable HTTP. +The older HTTP+SSE transport (protocol version 2024‑11‑05) is supported only for +backwards compatibility. New implementations should prefer Streamable HTTP. Examples: -- Legacy SSE server: [`simpleSseServer.ts`](../examples/server/src/simpleSseServer.ts) -- Backwards‑compatible server (Streamable HTTP + SSE): +- Legacy SSE server: + [`simpleSseServer.ts`](../examples/server/src/simpleSseServer.ts) +- Backwards‑compatible server (Streamable HTTP + SSE):\ [`sseAndStreamableHttpCompatibleServer.ts`](../examples/server/src/sseAndStreamableHttpCompatibleServer.ts) ## Running your server For a minimal “getting started” experience: -1. Start from [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts). +1. Start from + [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts). 2. Remove features you do not need (tasks, advanced logging, OAuth, etc.). 3. Register your own tools, resources and prompts. -For more detailed patterns (stateless vs stateful, JSON response mode, CORS, DNS rebind protection), see the examples above and the MCP spec sections on transports. +For more detailed patterns (stateless vs stateful, JSON response mode, CORS, DNS +rebind protection), see the examples above and the MCP spec sections on +transports. + +## Middleware + +The `McpServer` supports a middleware system similar to Express or Koa, allowing +you to intercept and modify requests, log activity, or enforce authentication +across all tools, prompts, and resources. + +Register middleware using `server.use()`: + +```typescript +const server = new McpServer( + { name: "my-server", version: "1.0.0" }, + { capabilities: { logging: {} } }, +); + +// Logging middleware +server.use(async (context, next) => { + const start = Date.now(); + try { + await next(); + } finally { + const duration = Date.now() - start; + console.error(`[${context.request.method}] took ${duration}ms`); + } +}); + +// Authentication middleware example +server.use(async (context, next) => { + if (context.request.method === "tools/call") { + // Perform auth checks here + // throw new McpError(ErrorCode.InvalidRequest, "Unauthorized"); + } + await next(); +}); +``` + +Middleware executes in the order registered. Calling `next()` passes control to +the next middleware or the actual handler. Note that middleware must be +registered before connecting the server transport. ## DNS rebinding protection -MCP servers running on localhost are vulnerable to DNS rebinding attacks. Use `createMcpExpressApp()` to create an Express app with DNS rebinding protection enabled by default: +MCP servers running on localhost are vulnerable to DNS rebinding attacks. Use +`createMcpExpressApp()` to create an Express app with DNS rebinding protection +enabled by default: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from "@modelcontextprotocol/server"; // Protection auto-enabled (default host is 127.0.0.1) const app = createMcpExpressApp(); // Protection auto-enabled for localhost -const app = createMcpExpressApp({ host: 'localhost' }); +const app = createMcpExpressApp({ host: "localhost" }); // No auto protection when binding to all interfaces, unless you provide allowedHosts -const app = createMcpExpressApp({ host: '0.0.0.0' }); +const app = createMcpExpressApp({ host: "0.0.0.0" }); ``` When binding to `0.0.0.0` / `::`, provide an allow-list of hosts: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from "@modelcontextprotocol/server"; const app = createMcpExpressApp({ - host: '0.0.0.0', - allowedHosts: ['localhost', '127.0.0.1', 'myhost.local'] + host: "0.0.0.0", + allowedHosts: ["localhost", "127.0.0.1", "myhost.local"], }); ``` @@ -97,29 +157,30 @@ const app = createMcpExpressApp({ ### Tools -Tools let MCP clients ask your server to take actions. They are usually the main way that LLMs call into your application. +Tools let MCP clients ask your server to take actions. They are usually the main +way that LLMs call into your application. A typical registration with `registerTool` looks like this: ```typescript server.registerTool( - 'calculate-bmi', + "calculate-bmi", { - title: 'BMI Calculator', - description: 'Calculate Body Mass Index', + title: "BMI Calculator", + description: "Calculate Body Mass Index", inputSchema: { weightKg: z.number(), - heightM: z.number() + heightM: z.number(), }, - outputSchema: { bmi: z.number() } + outputSchema: { bmi: z.number() }, }, async ({ weightKg, heightM }) => { const output = { bmi: weightKg / (heightM * heightM) }; return { - content: [{ type: 'text', text: JSON.stringify(output) }], - structuredContent: output + content: [{ type: "text", text: JSON.stringify(output) }], + structuredContent: output, }; - } + }, ); ``` @@ -130,60 +191,66 @@ This snippet is illustrative only; for runnable servers that expose tools, see: #### ResourceLink outputs -Tools can return `resource_link` content items to reference large resources without embedding them directly, allowing clients to fetch only what they need. +Tools can return `resource_link` content items to reference large resources +without embedding them directly, allowing clients to fetch only what they need. -The README’s `list-files` example shows the pattern conceptually; for concrete usage, see the Streamable HTTP examples in `examples/server/src`. +The README’s `list-files` example shows the pattern conceptually; for concrete +usage, see the Streamable HTTP examples in `examples/server/src`. ### Resources -Resources expose data to clients, but should not perform heavy computation or side‑effects. They are ideal for configuration, documents, or other reference data. +Resources expose data to clients, but should not perform heavy computation or +side‑effects. They are ideal for configuration, documents, or other reference +data. Conceptually, you might register resources like: ```typescript server.registerResource( - 'config', - 'config://app', + "config", + "config://app", { - title: 'Application Config', - description: 'Application configuration data', - mimeType: 'text/plain' + title: "Application Config", + description: "Application configuration data", + mimeType: "text/plain", }, - async uri => ({ - contents: [{ uri: uri.href, text: 'App configuration here' }] - }) + async (uri) => ({ + contents: [{ uri: uri.href, text: "App configuration here" }], + }), ); ``` -Dynamic resources use `ResourceTemplate` and can support completions on path parameters. For full runnable examples of resources: +Dynamic resources use `ResourceTemplate` and can support completions on path +parameters. For full runnable examples of resources: - [`simpleStreamableHttp.ts`](../examples/server/src/simpleStreamableHttp.ts) ### Prompts -Prompts are reusable templates that help humans (or client UIs) talk to models in a consistent way. They are declared on the server and listed through MCP. +Prompts are reusable templates that help humans (or client UIs) talk to models +in a consistent way. They are declared on the server and listed through MCP. A minimal prompt: ```typescript server.registerPrompt( - 'review-code', + "review-code", { - title: 'Code Review', - description: 'Review code for best practices and potential issues', - argsSchema: { code: z.string() } + title: "Code Review", + description: "Review code for best practices and potential issues", + argsSchema: { code: z.string() }, }, ({ code }) => ({ messages: [ { - role: 'user', + role: "user", content: { - type: 'text', - text: `Please review this code:\n\n${code}` - } - } - ] - }) + type: "text", + text: `Please review this code:\n\n${code}`, + }, + }, + ], + }), ); ``` @@ -193,19 +260,26 @@ For prompts integrated into a full server, see: ### Completions -Both prompts and resources can support argument completions. On the client side, you use `client.complete()` with a reference to the prompt or resource and the partially‑typed argument. +Both prompts and resources can support argument completions. On the client side, +you use `client.complete()` with a reference to the prompt or resource and the +partially‑typed argument. -See the MCP spec sections on prompts and resources for complete details, and [`simpleStreamableHttp.ts`](../examples/client/src/simpleStreamableHttp.ts) for client‑side usage patterns. +See the MCP spec sections on prompts and resources for complete details, and +[`simpleStreamableHttp.ts`](../examples/client/src/simpleStreamableHttp.ts) for +client‑side usage patterns. ### Display names and metadata -Tools, resources and prompts support a `title` field for human‑readable names. Older APIs can also attach `annotations.title`. To compute the correct display name on the client, use: +Tools, resources and prompts support a `title` field for human‑readable names. +Older APIs can also attach `annotations.title`. To compute the correct display +name on the client, use: - `getDisplayName` from `@modelcontextprotocol/client` ## Multi‑node deployment patterns -The SDK supports multi‑node deployments using Streamable HTTP. The high‑level patterns and diagrams live with the runnable server examples: +The SDK supports multi‑node deployments using Streamable HTTP. The high‑level +patterns and diagrams live with the runnable server examples: - [`examples/server/README.md`](../examples/server/README.md#multi-node-deployment-patterns) @@ -214,8 +288,9 @@ The SDK supports multi‑node deployments using Streamable HTTP. The high‑leve To handle both modern and legacy clients: - Run a backwards‑compatible server: - - [`sseAndStreamableHttpCompatibleServer.ts`](../examples/server/src/sseAndStreamableHttpCompatibleServer.ts) + - [`sseAndStreamableHttpCompatibleServer.ts`](../examples/server/src/sseAndStreamableHttpCompatibleServer.ts) - Use a client that falls back from Streamable HTTP to SSE: - - [`streamableHttpWithSseFallbackClient.ts`](../examples/client/src/streamableHttpWithSseFallbackClient.ts) + - [`streamableHttpWithSseFallbackClient.ts`](../examples/client/src/streamableHttpWithSseFallbackClient.ts) -For the detailed protocol rules, see the “Backwards compatibility” section of the MCP spec. +For the detailed protocol rules, see the “Backwards compatibility” section of +the MCP spec. diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 8564212c1..19ff4ceae 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -7,16 +7,22 @@ import type { CompleteRequestPrompt, CompleteRequestResourceTemplate, CompleteResult, + CreateTaskRequestHandlerExtra, CreateTaskResult, + GetPromptRequest, GetPromptResult, Implementation, + ListPromptsRequest, ListPromptsResult, + ListResourcesRequest, ListResourcesResult, + ListToolsRequest, ListToolsResult, LoggingMessageNotification, Prompt, PromptArgument, PromptReference, + ReadResourceRequest, ReadResourceResult, RequestHandlerExtra, Resource, @@ -66,6 +72,35 @@ import { getCompleter, isCompletable } from './completable.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; +/** + * Context passed to MCP middleware functions. + */ +export interface McpMiddlewareContext { + /** + * The incoming JSON-RPC request. + * While technically mutable, middleware should generally treat this as read-only. + * Mutation is permitted only for specific cases like schema normalization or request enrichment. + */ + request: ServerRequest; + + /** + * Additional metadata passed from the transport or SDK. + */ + extra: RequestHandlerExtra; + + /** + * A generic key-value store for cross-middleware communication (e.g., attaching a user object after auth). + */ + state: Record; +} + +/** + * Middleware function for intercepting MCP requests. + * @param context The request context. + * @param next A function that calls the next middleware or the implementation handler. + */ +export type McpMiddleware = (context: McpMiddlewareContext, next: () => Promise) => Promise; + /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. * For advanced usage (like sending notifications or setting custom request handlers), use the underlying @@ -83,6 +118,8 @@ export class McpServer { } = {}; private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + private _middleware: McpMiddleware[] = []; + private _middlewareFrozen = false; private _experimental?: { tasks: ExperimentalMcpServerTasks }; constructor(serverInfo: Implementation, options?: ServerOptions) { @@ -105,12 +142,25 @@ export class McpServer { return this._experimental; } + /** + * Registers a middleware function. + * @param middleware The middleware to register. + */ + public use(middleware: McpMiddleware) { + if (this._middlewareFrozen) { + throw new Error('Cannot register middleware after the server has started or processed requests.'); + } + this._middleware.push(middleware); + return this; + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. */ async connect(transport: Transport): Promise { + this._middlewareFrozen = true; return await this.server.connect(transport); } @@ -139,99 +189,120 @@ export class McpServer { this.server.setRequestHandler( ListToolsRequestSchema, - (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools) - .filter(([, tool]) => tool.enabled) - .map(([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: (() => { - const obj = normalizeObjectSchema(tool.inputSchema); - return obj - ? (toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'input' - }) as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA; - })(), - annotations: tool.annotations, - execution: tool.execution, - _meta: tool._meta - }; - - if (tool.outputSchema) { - const obj = normalizeObjectSchema(tool.outputSchema); - if (obj) { - toolDefinition.outputSchema = toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'output' - }) as Tool['outputSchema']; - } - } - - return toolDefinition; - }) - }) + (request: ListToolsRequest, extra: RequestHandlerExtra) => + this._executeRequest>( + (): Promise => + Promise.resolve({ + tools: Object.entries(this._registeredTools) + .filter(([, tool]) => tool.enabled) + .map(([name, tool]): Tool => { + const toolDefinition: Tool = { + name, + title: tool.title, + description: tool.description, + inputSchema: (() => { + const obj = normalizeObjectSchema(tool.inputSchema); + return obj + ? (toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'input' + }) as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA; + })(), + annotations: tool.annotations, + execution: tool.execution, + _meta: tool._meta + }; + + if (tool.outputSchema) { + const obj = normalizeObjectSchema(tool.outputSchema); + if (obj) { + toolDefinition.outputSchema = toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'output' + }) as Tool['outputSchema']; + } + } + + return toolDefinition; + }) + }), + request, + extra + ) ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { - try { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); - } - if (!tool.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); - } + this.server.setRequestHandler( + CallToolRequestSchema, + (request: CallToolRequest, extra: RequestHandlerExtra) => + this._executeRequest< + CallToolResult | CreateTaskResult, + CallToolRequest, + RequestHandlerExtra + >( + async ( + request: CallToolRequest, + extra: RequestHandlerExtra + ): Promise => { + try { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + } + if (!tool.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); + } - const isTaskRequest = !!request.params.task; - const taskSupport = tool.execution?.taskSupport; - const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); + const isTaskRequest = !!request.params.task; + const taskSupport = tool.execution?.taskSupport; + const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - // Validate task hint configuration - if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new McpError( - ErrorCode.InternalError, - `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` - ); - } + // Validate task hint configuration + if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { + throw new McpError( + ErrorCode.InternalError, + `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` + ); + } - // Handle taskSupport 'required' without task augmentation - if (taskSupport === 'required' && !isTaskRequest) { - throw new McpError( - ErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); - } + // Handle taskSupport 'required' without task augmentation + if (taskSupport === 'required' && !isTaskRequest) { + throw new McpError( + ErrorCode.MethodNotFound, + `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` + ); + } - // Handle taskSupport 'optional' without task augmentation - automatic polling - if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, extra); - } + // Handle taskSupport 'optional' without task augmentation - automatic polling + if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { + return await this.handleAutomaticTaskPolling(tool, request, extra); + } - // Normal execution path - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, extra); + // Normal execution path + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const result = await this.executeToolHandler(tool, args, extra); - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } + // Return CreateTaskResult immediately for task requests + if (isTaskRequest) { + return result; + } - // Validate output schema for non-task requests - await this.validateToolOutput(tool, result, request.params.name); - return result; - } catch (error) { - if (error instanceof McpError) { - if (error.code === ErrorCode.UrlElicitationRequired) { - throw error; // Return the error to the caller without wrapping in CallToolResult - } - } - return this.createToolError(error instanceof Error ? error.message : String(error)); - } - }); + // Validate output schema for non-task requests + await this.validateToolOutput(tool, result, request.params.name); + return result; + } catch (error) { + if (error instanceof McpError) { + if (error.code === ErrorCode.UrlElicitationRequired) { + throw error; + } + } + return this.createToolError(error instanceof Error ? error.message : String(error)); + } + }, + request, + extra + ) + ); this._toolHandlersInitialized = true; } @@ -323,50 +394,56 @@ export class McpServer { /** * Executes a tool handler (either regular or task-based). */ - private async executeToolHandler( + private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + extra: ExtraT ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; if (isTaskHandler) { - if (!extra.taskStore) { + const hasTaskStore = 'taskStore' in (extra as object) && (extra as { taskStore?: unknown }).taskStore; + if (!hasTaskStore) { throw new Error('No task store provided.'); } - const taskExtra = { ...extra, taskStore: extra.taskStore }; + const taskExtra = { + ...extra, + taskStore: (extra as { taskStore: unknown }).taskStore + }; if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve( + typedHandler.createTask(args as ShapeOutput, taskExtra as unknown as CreateTaskRequestHandlerExtra) + ); } else { const typedHandler = handler as ToolTaskHandler; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + return await Promise.resolve(typedHandler.createTask(taskExtra as unknown as CreateTaskRequestHandlerExtra)); } } if (tool.inputSchema) { const typedHandler = handler as ToolCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler(args as any, extra)); + return await Promise.resolve( + typedHandler( + args as ShapeOutput, + extra as unknown as RequestHandlerExtra + ) + ); } else { const typedHandler = handler as ToolCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler as any)(extra)); + return await Promise.resolve(typedHandler(extra as unknown as RequestHandlerExtra)); } } /** * Handles automatic task polling for tools with taskSupport 'optional'. */ - private async handleAutomaticTaskPolling( - tool: RegisteredTool, - request: RequestT, - extra: RequestHandlerExtra - ): Promise { + private async handleAutomaticTaskPolling< + RequestT extends CallToolRequest, + ExtraT extends RequestHandlerExtra + >(tool: RegisteredTool, request: RequestT, extra: ExtraT): Promise { if (!extra.taskStore) { throw new Error('No task store provided for task-capable tool.'); } @@ -377,9 +454,15 @@ export class McpServer { const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) - : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + ? await Promise.resolve( + (handler as ToolTaskHandler).createTask( + args as ShapeOutput, + taskExtra as unknown as CreateTaskRequestHandlerExtra + ) + ) + : await Promise.resolve( + (handler as ToolTaskHandler).createTask(taskExtra as unknown as CreateTaskRequestHandlerExtra) + ); // Poll until completion const taskId = createTaskResult.task.taskId; @@ -499,33 +582,49 @@ export class McpServer { } }); - this.server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { - const resources = Object.entries(this._registeredResources) - .filter(([_, resource]) => resource.enabled) - .map(([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata - })); - - const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { - continue; - } + this.server.setRequestHandler( + ListResourcesRequestSchema, + (request: ListResourcesRequest, extra: RequestHandlerExtra) => + this._executeRequest< + ListResourcesResult, + ListResourcesRequest, + RequestHandlerExtra + >( + async (request: ListResourcesRequest, extra: RequestHandlerExtra) => { + const resources = Object.entries(this._registeredResources) + .filter(([_, resource]) => resource.enabled) + .map(([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata + })); + + const templateResources: Resource[] = []; + for (const template of Object.values(this._registeredResourceTemplates)) { + if (!template.resourceTemplate.listCallback) { + continue; + } - const result = await template.resourceTemplate.listCallback(extra); - for (const resource of result.resources) { - templateResources.push({ - ...template.metadata, - // the defined resource metadata should override the template metadata if present - ...resource - }); - } - } + const result = await template.resourceTemplate.listCallback( + extra as unknown as RequestHandlerExtra + ); + for (const resource of result.resources) { + templateResources.push({ + ...template.metadata, + // the defined resource metadata should override the template metadata if present + ...resource + }); + } + } - return { resources: [...resources, ...templateResources] }; - }); + return { + resources: [...resources, ...templateResources] + }; + }, + request, + extra + ) + ); this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ @@ -537,28 +636,40 @@ export class McpServer { return { resourceTemplates }; }); - this.server.setRequestHandler(ReadResourceRequestSchema, async (request, extra) => { - const uri = new URL(request.params.uri); - - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - if (!resource.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); - } - return resource.readCallback(uri, extra); - } + this.server.setRequestHandler( + ReadResourceRequestSchema, + (request: ReadResourceRequest, extra: RequestHandlerExtra) => + this._executeRequest>( + async (request: ReadResourceRequest, extra: RequestHandlerExtra) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + if (!resource.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); + } + return resource.readCallback(uri, extra as unknown as RequestHandlerExtra); + } - // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); - if (variables) { - return template.readCallback(uri, variables, extra); - } - } + // Then check templates + for (const template of Object.values(this._registeredResourceTemplates)) { + const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); + if (variables) { + return template.readCallback( + uri, + variables, + extra as unknown as RequestHandlerExtra + ); + } + } - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); - }); + throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); + }, + request, + extra + ) + ); this._resourceHandlersInitialized = true; } @@ -581,48 +692,69 @@ export class McpServer { this.server.setRequestHandler( ListPromptsRequestSchema, - (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts) - .filter(([, prompt]) => prompt.enabled) - .map(([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined - }; - }) - }) + (request: ListPromptsRequest, extra: RequestHandlerExtra) => + this._executeRequest>( + (): Promise => + Promise.resolve({ + prompts: Object.entries(this._registeredPrompts) + .filter(([, prompt]) => prompt.enabled) + .map(([name, prompt]): Prompt => { + return { + name, + title: prompt.title, + description: prompt.description, + arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined + }; + }) + }), + request, + extra + ) ); - this.server.setRequestHandler(GetPromptRequestSchema, async (request, extra): Promise => { - const prompt = this._registeredPrompts[request.params.name]; - if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); - } + this.server.setRequestHandler( + GetPromptRequestSchema, + (request: GetPromptRequest, extra: RequestHandlerExtra) => + this._executeRequest>( + async ( + request: GetPromptRequest, + extra: RequestHandlerExtra + ): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); + } - if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); - } + if (!prompt.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); + } - if (prompt.argsSchema) { - const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(argsObj, request.params.arguments); - if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; - const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); - } + if (prompt.argsSchema) { + const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(argsObj, request.params.arguments); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${errorMessage}` + ); + } - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = prompt.callback as PromptCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); - } - }); + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve( + cb(args, extra as unknown as RequestHandlerExtra) + ); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(extra as unknown as RequestHandlerExtra)); + } + }, + request, + extra + ) + ); this._promptHandlersInitialized = true; } @@ -770,13 +902,25 @@ export class McpServer { update: updates => { if (typeof updates.uri !== 'undefined' && updates.uri !== uri) { delete this._registeredResources[uri]; - if (updates.uri) this._registeredResources[updates.uri] = registeredResource; + if (updates.uri) { + this._registeredResources[updates.uri] = registeredResource; + } + } + if (typeof updates.name !== 'undefined') { + registeredResource.name = updates.name; + } + if (typeof updates.title !== 'undefined') { + registeredResource.title = updates.title; + } + if (typeof updates.metadata !== 'undefined') { + registeredResource.metadata = updates.metadata; + } + if (typeof updates.callback !== 'undefined') { + registeredResource.readCallback = updates.callback; + } + if (typeof updates.enabled !== 'undefined') { + registeredResource.enabled = updates.enabled; } - if (typeof updates.name !== 'undefined') registeredResource.name = updates.name; - if (typeof updates.title !== 'undefined') registeredResource.title = updates.title; - if (typeof updates.metadata !== 'undefined') registeredResource.metadata = updates.metadata; - if (typeof updates.callback !== 'undefined') registeredResource.readCallback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredResource.enabled = updates.enabled; this.sendResourceListChanged(); } }; @@ -803,13 +947,25 @@ export class McpServer { update: updates => { if (typeof updates.name !== 'undefined' && updates.name !== name) { delete this._registeredResourceTemplates[name]; - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; + if (updates.name) { + this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; + } + } + if (typeof updates.title !== 'undefined') { + registeredResourceTemplate.title = updates.title; + } + if (typeof updates.template !== 'undefined') { + registeredResourceTemplate.resourceTemplate = updates.template; + } + if (typeof updates.metadata !== 'undefined') { + registeredResourceTemplate.metadata = updates.metadata; + } + if (typeof updates.callback !== 'undefined') { + registeredResourceTemplate.readCallback = updates.callback; + } + if (typeof updates.enabled !== 'undefined') { + registeredResourceTemplate.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredResourceTemplate.title = updates.title; - if (typeof updates.template !== 'undefined') registeredResourceTemplate.resourceTemplate = updates.template; - if (typeof updates.metadata !== 'undefined') registeredResourceTemplate.metadata = updates.metadata; - if (typeof updates.callback !== 'undefined') registeredResourceTemplate.readCallback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredResourceTemplate.enabled = updates.enabled; this.sendResourceListChanged(); } }; @@ -844,13 +1000,25 @@ export class McpServer { update: updates => { if (typeof updates.name !== 'undefined' && updates.name !== name) { delete this._registeredPrompts[name]; - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; + if (updates.name) { + this._registeredPrompts[updates.name] = registeredPrompt; + } + } + if (typeof updates.title !== 'undefined') { + registeredPrompt.title = updates.title; + } + if (typeof updates.description !== 'undefined') { + registeredPrompt.description = updates.description; + } + if (typeof updates.argsSchema !== 'undefined') { + registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); + } + if (typeof updates.callback !== 'undefined') { + registeredPrompt.callback = updates.callback; + } + if (typeof updates.enabled !== 'undefined') { + registeredPrompt.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredPrompt.title = updates.title; - if (typeof updates.description !== 'undefined') registeredPrompt.description = updates.description; - if (typeof updates.argsSchema !== 'undefined') registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); - if (typeof updates.callback !== 'undefined') registeredPrompt.callback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredPrompt.enabled = updates.enabled; this.sendPromptListChanged(); } }; @@ -903,16 +1071,34 @@ export class McpServer { validateAndWarnToolName(updates.name); } delete this._registeredTools[name]; - if (updates.name) this._registeredTools[updates.name] = registeredTool; + if (updates.name) { + this._registeredTools[updates.name] = registeredTool; + } + } + if (typeof updates.title !== 'undefined') { + registeredTool.title = updates.title; + } + if (typeof updates.description !== 'undefined') { + registeredTool.description = updates.description; + } + if (typeof updates.paramsSchema !== 'undefined') { + registeredTool.inputSchema = objectFromShape(updates.paramsSchema); + } + if (typeof updates.outputSchema !== 'undefined') { + registeredTool.outputSchema = objectFromShape(updates.outputSchema); + } + if (typeof updates.callback !== 'undefined') { + registeredTool.handler = updates.callback; + } + if (typeof updates.annotations !== 'undefined') { + registeredTool.annotations = updates.annotations; + } + if (typeof updates._meta !== 'undefined') { + registeredTool._meta = updates._meta; + } + if (typeof updates.enabled !== 'undefined') { + registeredTool.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; - if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; - if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = objectFromShape(updates.paramsSchema); - if (typeof updates.outputSchema !== 'undefined') registeredTool.outputSchema = objectFromShape(updates.outputSchema); - if (typeof updates.callback !== 'undefined') registeredTool.handler = updates.callback; - if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; - if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; - if (typeof updates.enabled !== 'undefined') registeredTool.enabled = updates.enabled; this.sendToolListChanged(); } }; @@ -1210,6 +1396,67 @@ export class McpServer { this.server.sendPromptListChanged(); } } + + private async _executeRequest( + handler: (request: RequestT, extra: ExtraT) => Promise, + request: RequestT, + extra: ExtraT + ): Promise { + this._middlewareFrozen = true; + const middleware = this._middleware; + + // Optimized path: If there are no middleware, just run the handler + if (middleware.length === 0) { + return handler(request, extra); + } + + let result: ResultT | undefined; + let handlerError: unknown; + + // Wrap the handler as the final middleware + const leafMiddleware: McpMiddleware = async (_context, _next) => { + try { + result = await handler(request, extra); + } catch (e) { + handlerError = e; + } + }; + + const chain = [...middleware, leafMiddleware]; + + // Execute the chain + // Protect against creating a context with incorrect types by casting + const context: McpMiddlewareContext = { + request: request as unknown as ServerRequest, + extra: extra as unknown as RequestHandlerExtra, + state: {} + }; + + const executeChain = async (i: number): Promise => { + if (i >= chain.length) { + return; + } + const fn = chain[i] as McpMiddleware; + + let nextCalled = false; + await fn(context, async () => { + if (nextCalled) { + throw new Error('next() called multiple times in middleware'); + } + nextCalled = true; + await executeChain(i + 1); + }); + }; + + await executeChain(0); + + if (handlerError) { + throw handlerError; + } + + // Return result, asserting it exists (handlers should generally return something) + return result as ResultT; + } } /** diff --git a/packages/server/test/server/mcpServer.test.ts b/packages/server/test/server/mcpServer.test.ts new file mode 100644 index 000000000..b4649a29b --- /dev/null +++ b/packages/server/test/server/mcpServer.test.ts @@ -0,0 +1,518 @@ +import { McpServer } from '../../src/server/mcp.js'; +import { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +describe('McpServer Middleware', () => { + let server: McpServer; + + beforeEach(() => { + server = new McpServer({ + name: 'test-server', + version: '1.0.0' + }); + }); + + // Helper to simulate a tool call and capture the response + async function simulateCallTool(toolName: string): Promise { + let serverOnMessage: (message: any) => Promise; + let capturedResponse: JSONRPCMessage | undefined; + let resolveSend: () => void; + const sendPromise = new Promise(resolve => { + resolveSend = resolve; + }); + + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(async msg => { + capturedResponse = msg as JSONRPCMessage; + resolveSend(); + }), + close: vi.fn(), + set onmessage(handler: any) { + serverOnMessage = handler; + } + }; + + await server.connect(transport); + + const request = { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: toolName, + arguments: {} + } + }; + + if (!serverOnMessage!) { + throw new Error('Server did not attach onMessage listener'); + } + + // Trigger request + serverOnMessage(request); + + // Wait for response + await sendPromise; + + return capturedResponse!; + } + + it('should execute middleware in registration order (Onion model)', async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push('mw1 start'); + await next(); + sequence.push('mw1 end'); + }); + + server.use(async (context, next) => { + sequence.push('mw2 start'); + await next(); + sequence.push('mw2 end'); + }); + + server.tool('test-tool', {}, async () => { + sequence.push('handler'); + return { content: [{ type: 'text', text: 'result' }] }; + }); + + await simulateCallTool('test-tool'); + + expect(sequence).toEqual(['mw1 start', 'mw2 start', 'handler', 'mw2 end', 'mw1 end']); + }); + + it('should short-circuit if next() is not called', async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push('mw1 start'); + // next() NOT called + sequence.push('mw1 end'); + }); + + server.use(async (context, next) => { + sequence.push('mw2 start'); + await next(); + }); + + server.tool('test-tool', {}, async () => { + sequence.push('handler'); + return { content: [{ type: 'text', text: 'result' }] }; + }); + + await simulateCallTool('test-tool'); + + // mw2 and handler should NOT run + expect(sequence).toEqual(['mw1 start', 'mw1 end']); + }); + + it('should allow middleware to communicate via ctx.state', async () => { + const server = new McpServer({ name: 'test', version: '1.0' }); + server.use(async (ctx, next) => { + ctx.state.value = 1; + await next(); + }); + server.use(async (ctx, next) => { + ctx.state.value = (ctx.state.value as number) + 1; + await next(); + }); + + // Use a tool list request to trigger the chain + server.tool('test-tool', {}, async () => ({ content: [{ type: 'text', text: 'ok' }] })); + + let capturedState: any; + server.use(async (ctx, next) => { + capturedState = ctx.state; + await next(); + }); + + let resolveSend: () => void; + const sendPromise = new Promise(resolve => { + resolveSend = resolve; + }); + + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(async () => { + resolveSend(); + }), + close: vi.fn() + }; + await server.connect(transport as any); + // @ts-ignore + const onMsg = (server.server.transport as any).onmessage; + onMsg({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }); + + await sendPromise; + + expect(capturedState).toBeDefined(); + expect(capturedState.value).toBe(2); + }); + + it('should execute middleware for other methods (e.g. tools/list)', async () => { + // For this check, we need to simulate tools/list. + // We can adapt our helper or just copy-paste a simplified version here for variety. + const sequence: string[] = []; + server.use(async (context, next) => { + sequence.push('mw'); + await next(); + }); + + // Register a dummy tool to ensure tools/list handler is set up + server.tool('dummy', {}, async () => ({ content: [] })); + + let serverOnMessage: any; + let resolveSend: any; + const p = new Promise(r => (resolveSend = r)); + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(() => resolveSend()), + close: vi.fn(), + set onmessage(h: any) { + serverOnMessage = h; + } + }; + await server.connect(transport); + + serverOnMessage({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: {} + }); + await p; + + expect(sequence).toEqual(['mw']); + }); + + it('should allow middleware to catch errors from downstream', async () => { + server.use(async (context, next) => { + try { + await next(); + } catch (e) { + // Suppress error + } + }); + + server.tool('error-tool', {}, async () => { + throw new Error('Boom'); + }); + + const response = await simulateCallTool('error-tool'); + + // Since middleware swallowed the error, the handler returns undefined (or whatever executed). + // Actually, if handler throws and middleware catches, `result` in `_executeRequest` will be undefined. + // The server transport might expect a result. + // Typescript core SDK might throw if result is missing maybe? + // Or it sends a success response with "undefined"? + + // Let's check what response we got. If error was swallowed, it shouldn't be an error response. + expect((response as any).error).toBeUndefined(); + }); + + it('should propagate errors if middleware throws', async () => { + server.use(async (context, next) => { + throw new Error('Middleware Error'); + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + // Standard JSON-RPC error response + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('Middleware Error'); + }); + + it('should throw an error if next() is called multiple times', async () => { + server.use(async (context, next) => { + await next(); + await next(); // Second call should throw + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + // Expect an error response due to double-call + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('next() called multiple times'); + }); + + it('should respect async timing (middleware can await)', async () => { + const sequence: string[] = []; + const delay = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + server.use(async (context, next) => { + sequence.push('mw1 start'); + await delay(10); // Wait 10ms + sequence.push('mw1 after delay'); + await next(); + sequence.push('mw1 end'); + }); + + server.use(async (context, next) => { + sequence.push('mw2 start'); + await next(); + }); + + server.tool('test-tool', {}, async () => { + sequence.push('handler'); + return { content: [] }; + }); + + await simulateCallTool('test-tool'); + + expect(sequence).toEqual(['mw1 start', 'mw1 after delay', 'mw2 start', 'handler', 'mw1 end']); + }); + + it('should throw an error if use() is called after connect()', async () => { + const transport = { + start: vi.fn(), + send: vi.fn(), + close: vi.fn(), + set onmessage(_handler: any) {} + }; + + await server.connect(transport); + + // Trying to register middleware after connect should throw + expect(() => { + server.use(async (context, next) => { + await next(); + }); + }).toThrow('Cannot register middleware after the server has started'); + }); + + // ============================================================ + // Real World Use Case Integration Tests + // ============================================================ + + describe('Real World Use Cases', () => { + it('Logging: should observe request method and capture response timing', async () => { + const logs: { method: string; durationMs: number }[] = []; + + server.use(async (context, next) => { + const start = Date.now(); + const method = (context.request as any).method || 'unknown'; + + await next(); + + const durationMs = Date.now() - start; + logs.push({ method, durationMs }); + }); + + server.tool('fast-tool', {}, async () => { + return { content: [{ type: 'text', text: 'done' }] }; + }); + + await simulateCallTool('fast-tool'); + + expect(logs).toHaveLength(1); + expect(logs[0]!.method).toBe('tools/call'); + expect(logs[0]!.durationMs).toBeGreaterThanOrEqual(0); + }); + + it('Auth: should short-circuit unauthorized requests', async () => { + const VALID_TOKEN = 'secret-token'; + + server.use(async (context, next) => { + // Simulate checking for an auth token in extra/authInfo + const authInfo = (context.extra as any)?.authInfo; + + // In real usage, authInfo would come from the transport. + // For this test, we simulate by checking a header-like property. + // Since we can't inject authInfo easily, we'll check a custom property. + const token = (context.request as any).params?._authToken; + + if (token !== VALID_TOKEN) { + // Short-circuit: don't call next(), effectively blocking the request + // In a real scenario, you might throw an error or set a response + throw new Error('Unauthorized'); + } + + await next(); + }); + + server.tool('protected-tool', {}, async () => { + return { content: [{ type: 'text', text: 'secret data' }] }; + }); + + // Simulate unauthorized request (no token) + const response = await simulateCallTool('protected-tool'); + + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('Unauthorized'); + }); + + it('Activity Aggregation: should intercept tools/list and count discoveries', async () => { + let toolListCount = 0; + let toolCallCount = 0; + + server.use(async (context, next) => { + const method = (context.request as any).method; + + if (method === 'tools/list') { + toolListCount++; + } else if (method === 'tools/call') { + toolCallCount++; + } + + await next(); + }); + + server.tool('my-tool', {}, async () => ({ content: [] })); + + // Simulate tools/list + let serverOnMessage: any; + let resolveSend: any; + const p = new Promise(r => (resolveSend = r)); + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(() => resolveSend()), + close: vi.fn(), + set onmessage(h: any) { + serverOnMessage = h; + } + }; + await server.connect(transport); + + // First: tools/list + serverOnMessage({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: {} + }); + await p; + + // Second: tools/call (need new promise) + let resolveSend2: any; + const p2 = new Promise(r => (resolveSend2 = r)); + transport.send.mockImplementation(() => resolveSend2()); + + serverOnMessage({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { name: 'my-tool', arguments: {} } + }); + await p2; + + expect(toolListCount).toBe(1); + expect(toolCallCount).toBe(1); + }); + }); + + // ============================================================ + // Failure Mode Verification Tests + // ============================================================ + + describe('Failure Mode Verification', () => { + it('Pre-next: error thrown before next() maps to JSON-RPC error', async () => { + server.use(async (context, next) => { + // Error thrown BEFORE calling next() + throw new Error('Pre-next failure'); + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + // Should be a proper JSON-RPC error response + expect((response as any).jsonrpc).toBe('2.0'); + expect((response as any).id).toBe(1); + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('Pre-next failure'); + // Server should not crash - we got a response + }); + + it('Post-next: error thrown after next() maps to JSON-RPC error', async () => { + server.use(async (context, next) => { + await next(); + // Error thrown AFTER calling next() + throw new Error('Post-next failure'); + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + // Should be a proper JSON-RPC error response + expect((response as any).jsonrpc).toBe('2.0'); + expect((response as any).id).toBe(1); + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('Post-next failure'); + }); + + it('Handler: error thrown in tool handler returns error result (SDK behavior)', async () => { + // No middleware - test pure handler error + server.tool('failing-tool', {}, async () => { + throw new Error('Handler failure'); + }); + + const response = await simulateCallTool('failing-tool'); + + // MCP SDK converts handler errors to result with isError: true + // (not JSON-RPC error - this is intentional SDK behavior) + expect((response as any).jsonrpc).toBe('2.0'); + expect((response as any).id).toBe(1); + expect((response as any).result).toBeDefined(); + expect((response as any).result.isError).toBe(true); + expect((response as any).result.content[0]!.text).toContain('Handler failure'); + }); + + it('Multiple middleware: error in second middleware propagates correctly', async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push('mw1 start'); + try { + await next(); + } catch (e) { + sequence.push('mw1 caught'); + throw e; // Re-throw to propagate + } + sequence.push('mw1 end'); + }); + + server.use(async (context, next) => { + sequence.push('mw2 start'); + throw new Error('mw2 failure'); + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain('mw2 failure'); + // Verify mw1 caught the error + expect(sequence).toContain('mw1 caught'); + // mw1 end should NOT be in sequence since error was re-thrown + expect(sequence).not.toContain('mw1 end'); + }); + + it('Error contains proper JSON-RPC error code', async () => { + server.use(async (context, next) => { + throw new Error('Generic middleware error'); + }); + + server.tool('test-tool', {}, async () => ({ content: [] })); + + const response = await simulateCallTool('test-tool'); + + expect((response as any).error).toBeDefined(); + // JSON-RPC internal error code is -32603 + expect((response as any).error.code).toBeDefined(); + expect(typeof (response as any).error.code).toBe('number'); + }); + }); +});