Skip to content

Commit b8f8ee8

Browse files
authored
chore: improve ability to use custom tools MCP-295 (#732)
1 parent c3044df commit b8f8ee8

File tree

10 files changed

+222
-31
lines changed

10 files changed

+222
-31
lines changed

package-lock.json

Lines changed: 0 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414
"types": "./dist/cjs/lib.d.ts",
1515
"default": "./dist/cjs/lib.js"
1616
}
17+
},
18+
"./tools": {
19+
"import": {
20+
"types": "./dist/esm/tools/index.d.ts",
21+
"default": "./dist/esm/tools/index.js"
22+
},
23+
"require": {
24+
"types": "./dist/cjs/tools/index.d.ts",
25+
"default": "./dist/cjs/tools/index.js"
26+
}
1727
}
1828
},
1929
"main": "./dist/cjs/lib.js",

src/lib.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ export { Session, type SessionOptions } from "./common/session.js";
33
export { type UserConfig } from "./common/config.js";
44
export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js";
55
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
6+
export { StdioRunner } from "./transports/stdio.js";
7+
export { TransportRunnerBase, type TransportRunnerConfig } from "./transports/base.js";
68
export {
79
ConnectionManager,
810
type AnyConnectionState,
@@ -21,3 +23,4 @@ export { ErrorCodes } from "./common/errors.js";
2123
export { Telemetry } from "./telemetry/telemetry.js";
2224
export { Keychain, registerGlobalSecretToRedact } from "./common/keychain.js";
2325
export type { Secret } from "./common/keychain.js";
26+
export { Elicitation } from "./elicitation.js";

src/server.ts

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
22
import type { Session } from "./common/session.js";
33
import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
4-
import { AtlasTools } from "./tools/atlas/tools.js";
5-
import { AtlasLocalTools } from "./tools/atlasLocal/tools.js";
6-
import { MongoDbTools } from "./tools/mongodb/tools.js";
74
import { Resources } from "./resources/resources.js";
85
import type { LogLevel } from "./common/logger.js";
96
import { LogId, McpLogger } from "./common/logger.js";
@@ -24,6 +21,7 @@ import { validateConnectionString } from "./helpers/connectionOptions.js";
2421
import { packageInfo } from "./common/packageInfo.js";
2522
import { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js";
2623
import type { Elicitation } from "./elicitation.js";
24+
import { AllTools } from "./tools/index.js";
2725

2826
export interface ServerOptions {
2927
session: Session;
@@ -32,7 +30,28 @@ export interface ServerOptions {
3230
telemetry: Telemetry;
3331
elicitation: Elicitation;
3432
connectionErrorHandler: ConnectionErrorHandler;
35-
toolConstructors?: (new (params: ToolConstructorParams) => ToolBase)[];
33+
/**
34+
* Custom tool constructors to register with the server.
35+
* This will override any default tools. You can use both existing and custom tools by using the `mongodb-mcp-server/tools` export.
36+
*
37+
* ```ts
38+
* import { AllTools, ToolBase } from "mongodb-mcp-server/tools";
39+
* class CustomTool extends ToolBase {
40+
* name = "custom_tool";
41+
* // ...
42+
* }
43+
* const server = new Server({
44+
* session: mySession,
45+
* userConfig: myUserConfig,
46+
* mcpServer: myMcpServer,
47+
* telemetry: myTelemetry,
48+
* elicitation: myElicitation,
49+
* connectionErrorHandler: myConnectionErrorHandler,
50+
* tools: [...AllTools, CustomTool],
51+
* });
52+
* ```
53+
*/
54+
tools?: (new (params: ToolConstructorParams) => ToolBase)[];
3655
}
3756

3857
export class Server {
@@ -61,7 +80,7 @@ export class Server {
6180
telemetry,
6281
connectionErrorHandler,
6382
elicitation,
64-
toolConstructors,
83+
tools,
6584
}: ServerOptions) {
6685
this.startTime = Date.now();
6786
this.session = session;
@@ -70,7 +89,7 @@ export class Server {
7089
this.userConfig = userConfig;
7190
this.elicitation = elicitation;
7291
this.connectionErrorHandler = connectionErrorHandler;
73-
this.toolConstructors = toolConstructors ?? [...AtlasTools, ...MongoDbTools, ...AtlasLocalTools];
92+
this.toolConstructors = tools ?? AllTools;
7493
}
7594

7695
async connect(transport: Transport): Promise<void> {

src/tools/index.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { AtlasTools } from "./atlas/tools.js";
2+
import { AtlasLocalTools } from "./atlasLocal/tools.js";
3+
import { MongoDbTools } from "./mongodb/tools.js";
4+
5+
const AllTools = [...MongoDbTools, ...AtlasTools, ...AtlasLocalTools];
6+
7+
export { AllTools, MongoDbTools, AtlasTools, AtlasLocalTools };
8+
9+
export {
10+
ToolBase,
11+
type ToolConstructorParams,
12+
type ToolCategory,
13+
type OperationType,
14+
type ToolArgs,
15+
type ToolExecutionContext,
16+
} from "./tool.js";

src/transports/base.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import type { AtlasLocalClientFactoryFn } from "../common/atlasLocal.js";
2020
import { defaultCreateAtlasLocalClient } from "../common/atlasLocal.js";
2121
import type { Client } from "@mongodb-js/atlas-local";
2222
import { VectorSearchEmbeddingsManager } from "../common/search/vectorSearchEmbeddingsManager.js";
23+
import type { ToolBase, ToolConstructorParams } from "../tools/tool.js";
2324

2425
export type TransportRunnerConfig = {
2526
userConfig: UserConfig;
@@ -28,6 +29,7 @@ export type TransportRunnerConfig = {
2829
createAtlasLocalClient?: AtlasLocalClientFactoryFn;
2930
additionalLoggers?: LoggerBase[];
3031
telemetryProperties?: Partial<CommonProperties>;
32+
tools?: (new (params: ToolConstructorParams) => ToolBase)[];
3133
};
3234

3335
export abstract class TransportRunnerBase {
@@ -38,6 +40,7 @@ export abstract class TransportRunnerBase {
3840
private readonly connectionErrorHandler: ConnectionErrorHandler;
3941
private readonly atlasLocalClient: Promise<Client | undefined>;
4042
private readonly telemetryProperties: Partial<CommonProperties>;
43+
private readonly tools?: (new (params: ToolConstructorParams) => ToolBase)[];
4144

4245
protected constructor({
4346
userConfig,
@@ -46,12 +49,14 @@ export abstract class TransportRunnerBase {
4649
createAtlasLocalClient = defaultCreateAtlasLocalClient,
4750
additionalLoggers = [],
4851
telemetryProperties = {},
52+
tools,
4953
}: TransportRunnerConfig) {
5054
this.userConfig = userConfig;
5155
this.createConnectionManager = createConnectionManager;
5256
this.connectionErrorHandler = connectionErrorHandler;
5357
this.atlasLocalClient = createAtlasLocalClient();
5458
this.telemetryProperties = telemetryProperties;
59+
this.tools = tools;
5560
const loggers: LoggerBase[] = [...additionalLoggers];
5661
if (this.userConfig.loggers.includes("stderr")) {
5762
loggers.push(new ConsoleLogger(Keychain.root));
@@ -114,6 +119,7 @@ export abstract class TransportRunnerBase {
114119
userConfig: this.userConfig,
115120
connectionErrorHandler: this.connectionErrorHandler,
116121
elicitation,
122+
tools: this.tools,
117123
});
118124

119125
// We need to create the MCP logger after the server is constructed

tests/integration/build.test.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ const projectRoot = path.resolve(currentDir, "../..");
1111
const esmPath = path.resolve(projectRoot, "dist/esm/lib.js");
1212
const cjsPath = path.resolve(projectRoot, "dist/cjs/lib.js");
1313

14+
const esmToolsPath = path.resolve(projectRoot, "dist/esm/tools/index.js");
15+
const cjsToolsPath = path.resolve(projectRoot, "dist/cjs/tools/index.js");
16+
1417
describe("Build Test", () => {
1518
it("should successfully require CommonJS module", () => {
1619
const require = createRequire(__filename);
@@ -49,7 +52,24 @@ describe("Build Test", () => {
4952
"Session",
5053
"StreamableHttpRunner",
5154
"Telemetry",
55+
"Elicitation",
5256
])
5357
);
5458
});
59+
60+
it("should have matching exports between CommonJS and ESM tools modules", async () => {
61+
// Import CommonJS module
62+
const require = createRequire(__filename);
63+
const cjsModule = require(cjsToolsPath) as Record<string, unknown>;
64+
65+
// Import ESM module
66+
const esmModule = (await import(esmToolsPath)) as Record<string, unknown>;
67+
68+
// Compare exports
69+
const cjsKeys = Object.keys(cjsModule).sort();
70+
const esmKeys = Object.keys(esmModule).sort();
71+
72+
expect(cjsKeys).toEqual(esmKeys);
73+
expect(cjsKeys).toEqual(expect.arrayContaining(["MongoDbTools", "AtlasTools", "AtlasLocalTools", "AllTools"]));
74+
});
5575
});

0 commit comments

Comments
 (0)