diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index e6fb94118..ad85cf7b6 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -12,40 +12,23 @@ """ import json +import logging from pathlib import Path from typing import Any import jsonschema from jsonschema import ValidationError -# JSON Schema for agent configuration -AGENT_CONFIG_SCHEMA = { - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "Agent Configuration", - "description": "Configuration schema for creating agents", - "type": "object", - "properties": { - "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, - "model": { - "description": "The model ID to use for this agent. If not specified, uses the default model.", - "type": ["string", "null"], - "default": None, - }, - "prompt": { - "description": "The system prompt for the agent. Provides high level context to the agent.", - "type": ["string", "null"], - "default": None, - }, - "tools": { - "description": "List of tools the agent can use. Can be file paths, " - "Python module names, or @tool annotated functions in files.", - "type": "array", - "items": {"type": "string"}, - "default": [], - }, - }, - "additionalProperties": False, -} +from .mcp_config import MCP_SERVER_CONFIG_SCHEMA, load_mcp_clients_from_config + +logger = logging.getLogger(__name__) + +_SCHEMA_PATH = Path(__file__).parent / "agent_config.schema.json" +with open(_SCHEMA_PATH) as _f: + AGENT_CONFIG_SCHEMA: dict[str, Any] = json.load(_f) + +# Resolve the $ref in mcp_servers.additionalProperties to the actual MCP server schema +AGENT_CONFIG_SCHEMA["properties"]["mcp_servers"]["additionalProperties"] = MCP_SERVER_CONFIG_SCHEMA # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) @@ -129,6 +112,15 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A if config_key in config_dict and config_dict[config_key] is not None: agent_kwargs[agent_param] = config_dict[config_key] + # Handle mcp_servers: create MCPClient instances and append to tools + if config_dict.get("mcp_servers"): + mcp_clients = load_mcp_clients_from_config({"mcpServers": config_dict["mcp_servers"]}) + tools_list = agent_kwargs.get("tools", []) + if not isinstance(tools_list, list): + tools_list = list(tools_list) + tools_list.extend(mcp_clients.values()) + agent_kwargs["tools"] = tools_list + # Override with any additional kwargs provided agent_kwargs.update(kwargs) diff --git a/src/strands/experimental/agent_config.schema.json b/src/strands/experimental/agent_config.schema.json new file mode 100644 index 000000000..34a90ffa0 --- /dev/null +++ b/src/strands/experimental/agent_config.schema.json @@ -0,0 +1,35 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Agent Configuration", + "description": "Configuration schema for creating agents.", + "type": "object", + "properties": { + "name": { + "description": "Name of the agent.", + "type": ["string", "null"], + "default": null + }, + "model": { + "description": "The model ID to use for this agent. If not specified, uses the default model.", + "type": ["string", "null"], + "default": null + }, + "prompt": { + "description": "The system prompt for the agent. Provides high level context to the agent.", + "type": ["string", "null"], + "default": null + }, + "tools": { + "description": "List of tools the agent can use. Can be file paths, Python module names, or @tool annotated functions in files.", + "type": "array", + "items": { "type": "string" }, + "default": [] + }, + "mcp_servers": { + "description": "MCP server configurations. Each key is a server name and the value is a server configuration object with transport-specific settings.", + "type": "object", + "additionalProperties": { "$ref": "mcp_server_config.schema.json" } + } + }, + "additionalProperties": false +} diff --git a/src/strands/experimental/mcp_config.py b/src/strands/experimental/mcp_config.py new file mode 100644 index 000000000..ed1742720 --- /dev/null +++ b/src/strands/experimental/mcp_config.py @@ -0,0 +1,218 @@ +"""MCP server configuration parsing and MCPClient factory. + +This module handles parsing MCP server configurations from dictionaries or JSON files +and creating MCPClient instances with the appropriate transport callables. + +Supported transport types: +- stdio: Local subprocess via stdin/stdout (auto-detected when 'command' is present) +- sse: Server-Sent Events over HTTP (auto-detected when 'url' is present without explicit transport) +- streamable-http: Streamable HTTP transport +""" + +import json +import logging +import re +from pathlib import Path +from typing import Any + +import jsonschema +from jsonschema import ValidationError +from mcp import StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamable_http_client + +from ..tools.mcp.mcp_client import MCPClient, ToolFilters + +logger = logging.getLogger(__name__) + +_SCHEMA_PATH = Path(__file__).parent / "mcp_server_config.schema.json" +with open(_SCHEMA_PATH) as _f: + MCP_SERVER_CONFIG_SCHEMA: dict[str, Any] = json.load(_f) + +_SERVER_VALIDATOR = jsonschema.Draft7Validator(MCP_SERVER_CONFIG_SCHEMA) + + +def _parse_tool_filters(config: dict[str, Any] | None) -> ToolFilters | None: + """Parse a tool filter configuration into a ToolFilters instance. + + All filter strings are compiled as regex patterns. Exact-match strings like ``"^echo$"`` + work correctly as regex since they match themselves. + + Args: + config: Tool filter configuration dict with 'allowed' and/or 'rejected' lists, + or None. + + Returns: + A ToolFilters instance, or None if config is None or empty. + + Raises: + ValueError: If a filter string is not a valid regex pattern. + """ + if not config: + return None + + result: ToolFilters = {} + + if "allowed" in config: + allowed: list[re.Pattern[str]] = [] + for pattern_str in config["allowed"]: + try: + allowed.append(re.compile(pattern_str)) + except re.error as e: + raise ValueError(f"invalid regex pattern in tool_filters.allowed: '{pattern_str}': {e}") from e + result["allowed"] = allowed + + if "rejected" in config: + rejected: list[re.Pattern[str]] = [] + for pattern_str in config["rejected"]: + try: + rejected.append(re.compile(pattern_str)) + except re.error as e: + raise ValueError(f"invalid regex pattern in tool_filters.rejected: '{pattern_str}': {e}") from e + result["rejected"] = rejected + + return result if result else None + + +def _create_mcp_client_from_config(server_name: str, config: dict[str, Any]) -> MCPClient: + """Create an MCPClient instance from a server configuration dictionary. + + Transport type is auto-detected based on the presence of 'command' (stdio) or 'url' (sse), + unless explicitly specified via the 'transport' field. + + Args: + server_name: Name of the server (used in error messages). + config: Server configuration dictionary. + + Returns: + A configured MCPClient instance. + + Raises: + ValueError: If the configuration is invalid or missing required fields. + """ + # Validate against schema + try: + _SERVER_VALIDATOR.validate(config) + except ValidationError as e: + error_path = " -> ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root" + raise ValueError(f"server '{server_name}' configuration validation error at {error_path}: {e.message}") from e + + # Determine transport type + transport = config.get("transport") + command = config.get("command") + url = config.get("url") + + if transport is None: + if command: + transport = "stdio" + elif url: + transport = "sse" + else: + raise ValueError( + f"server '{server_name}' must specify either 'command' (for stdio) or 'url' (for sse/http)" + ) + + # Extract common MCPClient parameters + prefix = config.get("prefix") + startup_timeout = config.get("startup_timeout", 30) + tool_filters = _parse_tool_filters(config.get("tool_filters")) + + # Build transport callable based on type + if transport == "stdio": + + def _stdio_transport() -> Any: + params = StdioServerParameters( + command=config["command"], + args=config.get("args", []), + env=config.get("env"), + cwd=config.get("cwd"), + ) + return stdio_client(params) + + transport_callable = _stdio_transport + elif transport == "sse": + if not url: + raise ValueError(f"server '{server_name}': 'url' is required for sse transport") + headers = config.get("headers") + + def _sse_transport() -> Any: + return sse_client(url=url, headers=headers) + + transport_callable = _sse_transport + elif transport == "streamable-http": + if not url: + raise ValueError(f"server '{server_name}': 'url' is required for streamable-http transport") + headers = config.get("headers") + + def _streamable_http_transport() -> Any: + return streamable_http_client(url=url, headers=headers) + + transport_callable = _streamable_http_transport + else: + raise ValueError(f"server '{server_name}': unsupported transport type '{transport}'") + + logger.debug( + "server_name=<%s>, transport=<%s> | creating MCP client from config", + server_name, + transport, + ) + + return MCPClient( + transport_callable, + startup_timeout=startup_timeout, + tool_filters=tool_filters, + prefix=prefix, + ) + + +def load_mcp_clients_from_config(config: str | dict[str, Any]) -> dict[str, MCPClient]: + """Load MCP client instances from a configuration file or dictionary. + + Expects the standard ``mcpServers`` wrapper format used by Claude Desktop, VS Code, etc:: + + { + "mcpServers": { + "server_name": { "command": "...", ... } + } + } + + Args: + config: Either a file path (with optional file:// prefix) to a JSON config file, + or a dictionary with a ``mcpServers`` key mapping server names to configs. + + Returns: + A dictionary mapping server names to MCPClient instances. + + Raises: + FileNotFoundError: If the config file does not exist. + json.JSONDecodeError: If the config file contains invalid JSON. + ValueError: If the config format is invalid or a server config is invalid. + """ + if isinstance(config, str): + file_path = config + if file_path.startswith("file://"): + file_path = file_path[7:] + + config_path = Path(file_path) + if not config_path.exists(): + raise FileNotFoundError(f"MCP configuration file not found: {file_path}") + + with open(config_path) as f: + config_dict: dict[str, Any] = json.load(f) + elif isinstance(config, dict): + config_dict = config + else: + raise ValueError("Config must be a file path string or dictionary") + + if "mcpServers" not in config_dict or not isinstance(config_dict["mcpServers"], dict): + raise ValueError("Config must contain an 'mcpServers' key with a dictionary of server configurations") + + servers = config_dict["mcpServers"] + clients: dict[str, MCPClient] = {} + for server_name, server_config in servers.items(): + clients[server_name] = _create_mcp_client_from_config(server_name, server_config) + + logger.debug("loaded_servers=<%d> | MCP clients created from config", len(clients)) + + return clients diff --git a/src/strands/experimental/mcp_server_config.schema.json b/src/strands/experimental/mcp_server_config.schema.json new file mode 100644 index 000000000..378026b37 --- /dev/null +++ b/src/strands/experimental/mcp_server_config.schema.json @@ -0,0 +1,68 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "MCP Server Configuration", + "description": "Configuration for a single MCP server.", + "type": "object", + "properties": { + "transport": { + "description": "Transport type. Auto-detected from 'command' (stdio) or 'url' (sse) if omitted.", + "type": "string", + "enum": ["stdio", "sse", "streamable-http"] + }, + "command": { + "description": "Command to run for stdio transport.", + "type": "string" + }, + "args": { + "description": "Arguments for the stdio command.", + "type": "array", + "items": { "type": "string" }, + "default": [] + }, + "env": { + "description": "Environment variables for the stdio command.", + "type": "object", + "additionalProperties": { "type": "string" } + }, + "cwd": { + "description": "Working directory for the stdio command.", + "type": "string" + }, + "url": { + "description": "URL for sse or streamable-http transport.", + "type": "string" + }, + "headers": { + "description": "HTTP headers for sse or streamable-http transport.", + "type": "object", + "additionalProperties": { "type": "string" } + }, + "prefix": { + "description": "Prefix to apply to tool names from this server.", + "type": "string" + }, + "startup_timeout": { + "description": "Timeout in seconds for server initialization. Defaults to 30.", + "type": "integer", + "default": 30 + }, + "tool_filters": { + "description": "Filters for controlling which tools are loaded.", + "type": "object", + "properties": { + "allowed": { + "description": "List of regex patterns for tools to include.", + "type": "array", + "items": { "type": "string" } + }, + "rejected": { + "description": "List of regex patterns for tools to exclude.", + "type": "array", + "items": { "type": "string" } + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false +} diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index e6188079b..23db1b33e 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -3,6 +3,7 @@ import json import os import tempfile +from unittest.mock import MagicMock, patch import pytest @@ -170,3 +171,75 @@ def test_config_to_agent_with_tool(): config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} agent = config_to_agent(config) assert "say" in agent.tool_names + + +class TestConfigToAgentWithMcpServers: + """Tests for config_to_agent with mcp_servers field.""" + + @patch("strands.experimental.agent_config.load_mcp_clients_from_config") + def test_mcp_servers_creates_clients(self, mock_load_clients): + """mcp_servers config should create MCPClient instances and add to tools.""" + mock_client = MagicMock() + mock_load_clients.return_value = {"server1": mock_client} + + config = { + "model": "test-model", + "mcp_servers": {"server1": {"command": "echo"}}, + } + config_to_agent(config) + + mock_load_clients.assert_called_once_with({"mcpServers": {"server1": {"command": "echo"}}}) + + @patch("strands.experimental.agent_config.load_mcp_clients_from_config") + def test_mcp_servers_appended_to_existing_tools(self, mock_load_clients): + """MCP clients should be appended to existing tool lists.""" + mock_client = MagicMock() + mock_load_clients.return_value = {"server1": mock_client} + + config = { + "model": "test-model", + "tools": [], + "mcp_servers": {"server1": {"command": "echo"}}, + } + config_to_agent(config) + mock_load_clients.assert_called_once() + + @patch("strands.experimental.agent_config.load_mcp_clients_from_config") + def test_mcp_servers_empty_dict_is_valid(self, mock_load_clients): + """Empty mcp_servers dict should not create any clients.""" + mock_load_clients.return_value = {} + + config = {"model": "test-model", "mcp_servers": {}} + agent = config_to_agent(config) + # Empty dict is falsy, so load_mcp_clients_from_config should not be called + mock_load_clients.assert_not_called() + assert agent.model.config["model_id"] == "test-model" + + def test_mcp_servers_schema_validation_rejects_array(self): + """mcp_servers should only accept an object, not an array.""" + config = {"model": "test-model", "mcp_servers": ["not", "an", "object"]} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_mcp_servers_schema_validation_rejects_string(self): + """mcp_servers should only accept an object, not a string.""" + config = {"model": "test-model", "mcp_servers": "not an object"} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + @patch("strands.experimental.agent_config.load_mcp_clients_from_config") + def test_multiple_mcp_servers(self, mock_load_clients): + """Multiple MCP servers should all be created.""" + mock_client1 = MagicMock() + mock_client2 = MagicMock() + mock_load_clients.return_value = {"server1": mock_client1, "server2": mock_client2} + + config = { + "model": "test-model", + "mcp_servers": { + "server1": {"command": "echo"}, + "server2": {"transport": "sse", "url": "http://localhost:8000/sse"}, + }, + } + config_to_agent(config) + mock_load_clients.assert_called_once() diff --git a/tests/strands/tools/mcp/test_mcp_config.py b/tests/strands/tools/mcp/test_mcp_config.py new file mode 100644 index 000000000..c93616d13 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_config.py @@ -0,0 +1,439 @@ +"""Tests for MCP config parsing and MCPClient factory.""" + +import json +import os +import re +import tempfile +from unittest.mock import MagicMock, patch + +import jsonschema +import pytest + +from strands.experimental.mcp_config import ( + MCP_SERVER_CONFIG_SCHEMA, + _create_mcp_client_from_config, + _parse_tool_filters, + load_mcp_clients_from_config, +) + + +class TestParseToolFilters: + """Tests for _parse_tool_filters function.""" + + def test_returns_none_for_none_input(self): + assert _parse_tool_filters(None) is None + + def test_returns_none_for_empty_dict(self): + assert _parse_tool_filters({}) is None + + def test_compiles_allowed_patterns(self): + config = {"allowed": ["echo", "search"]} + result = _parse_tool_filters(config) + assert result is not None + assert len(result["allowed"]) == 2 + assert isinstance(result["allowed"][0], re.Pattern) + assert isinstance(result["allowed"][1], re.Pattern) + assert result["allowed"][0].match("echo") + assert result["allowed"][1].match("search") + + def test_compiles_rejected_patterns(self): + config = {"rejected": ["dangerous_tool"]} + result = _parse_tool_filters(config) + assert result is not None + assert len(result["rejected"]) == 1 + assert isinstance(result["rejected"][0], re.Pattern) + assert result["rejected"][0].match("dangerous_tool") + + def test_compiles_regex_patterns(self): + config = {"allowed": ["search_.*", "get_\\w+"]} + result = _parse_tool_filters(config) + assert result is not None + assert len(result["allowed"]) == 2 + assert result["allowed"][0].match("search_docs") + assert result["allowed"][1].match("get_data") + + def test_exact_strings_work_as_regex(self): + """Exact strings like 'echo' match themselves when compiled as regex.""" + config = {"allowed": ["echo"]} + result = _parse_tool_filters(config) + assert result is not None + assert result["allowed"][0].match("echo") + assert result["allowed"][0].match("echo_extra") + + def test_mixed_allowed_and_rejected(self): + config = {"allowed": ["search_.*"], "rejected": ["dangerous_tool"]} + result = _parse_tool_filters(config) + assert result is not None + assert "allowed" in result + assert "rejected" in result + + def test_invalid_regex_raises_error(self): + config = {"allowed": ["[invalid"]} + with pytest.raises(ValueError, match="invalid regex pattern"): + _parse_tool_filters(config) + + def test_invalid_regex_in_rejected_raises_error(self): + config = {"rejected": ["(unclosed"]} + with pytest.raises(ValueError, match="invalid regex pattern"): + _parse_tool_filters(config) + + +class TestCreateMcpClientFromConfig: + """Tests for _create_mcp_client_from_config function.""" + + @patch("strands.experimental.mcp_config.stdio_client") + @patch("strands.experimental.mcp_config.StdioServerParameters") + def test_stdio_transport_from_command(self, mock_params_cls, mock_stdio_client): + """Config with 'command' should create a stdio transport.""" + config = {"command": "uvx", "args": ["some-server@latest"]} + mock_params_cls.return_value = MagicMock() + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + @patch("strands.experimental.mcp_config.stdio_client") + @patch("strands.experimental.mcp_config.StdioServerParameters") + def test_stdio_with_env_and_cwd(self, mock_params_cls, mock_stdio_client): + """Stdio config should pass env and cwd to StdioServerParameters.""" + config = { + "command": "node", + "args": ["server.js"], + "env": {"NODE_ENV": "production"}, + "cwd": "/opt/server", + } + mock_params_cls.return_value = MagicMock() + + client = _create_mcp_client_from_config("test_server", config) + + client._transport_callable() + mock_params_cls.assert_called_once_with( + command="node", + args=["server.js"], + env={"NODE_ENV": "production"}, + cwd="/opt/server", + ) + + @patch("strands.experimental.mcp_config.sse_client") + def test_sse_transport_explicit(self, mock_sse_client): + """Config with transport='sse' should create an SSE transport.""" + config = {"transport": "sse", "url": "http://localhost:8000/sse"} + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + @patch("strands.experimental.mcp_config.streamable_http_client") + def test_streamable_http_transport_explicit(self, mock_http_client): + """Config with transport='streamable-http' should create a streamable-http transport.""" + config = {"transport": "streamable-http", "url": "http://localhost:8000/mcp"} + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + @patch("strands.experimental.mcp_config.sse_client") + def test_url_without_transport_defaults_to_sse(self, mock_sse_client): + """Config with url but no transport should default to sse.""" + config = {"url": "http://localhost:8000/sse"} + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + @patch("strands.experimental.mcp_config.sse_client") + def test_sse_with_headers(self, mock_sse_client): + """SSE config should pass headers.""" + config = { + "transport": "sse", + "url": "http://localhost:8000/sse", + "headers": {"Authorization": "Bearer token123"}, + } + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + @patch("strands.experimental.mcp_config.streamable_http_client") + def test_streamable_http_with_headers(self, mock_http_client): + """Streamable HTTP config should pass headers.""" + config = { + "transport": "streamable-http", + "url": "http://localhost:8000/mcp", + "headers": {"Authorization": "Bearer token123"}, + } + + client = _create_mcp_client_from_config("test_server", config) + assert client is not None + + def test_prefix_passed_to_mcp_client(self): + """Prefix should be passed to MCPClient constructor.""" + config = {"command": "echo", "prefix": "myprefix"} + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + _create_mcp_client_from_config("test_server", config) + call_kwargs = mock_mcp_client.call_args + assert call_kwargs[1]["prefix"] == "myprefix" + + def test_startup_timeout_passed_to_mcp_client(self): + """startup_timeout should be passed to MCPClient constructor.""" + config = {"command": "echo", "startup_timeout": 60} + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + _create_mcp_client_from_config("test_server", config) + call_kwargs = mock_mcp_client.call_args + assert call_kwargs[1]["startup_timeout"] == 60 + + def test_tool_filters_passed_to_mcp_client(self): + """tool_filters config should be parsed and passed to MCPClient.""" + config = { + "command": "echo", + "tool_filters": {"allowed": ["echo"], "rejected": ["dangerous_.*"]}, + } + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + _create_mcp_client_from_config("test_server", config) + call_kwargs = mock_mcp_client.call_args + tool_filters = call_kwargs[1]["tool_filters"] + assert "allowed" in tool_filters + assert "rejected" in tool_filters + + def test_default_startup_timeout(self): + """Default startup_timeout should be 30 if not specified.""" + config = {"command": "echo"} + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + _create_mcp_client_from_config("test_server", config) + call_kwargs = mock_mcp_client.call_args + assert call_kwargs[1]["startup_timeout"] == 30 + + def test_invalid_transport_raises_error(self): + """Unknown transport type should raise ValueError.""" + config = {"transport": "websocket", "url": "ws://localhost:8000"} + + with pytest.raises(ValueError, match="configuration validation error"): + _create_mcp_client_from_config("test_server", config) + + def test_missing_command_and_url_raises_error(self): + """Config without command or url should raise ValueError.""" + config = {"prefix": "test"} + + with pytest.raises(ValueError, match="must specify either 'command'.*or 'url'"): + _create_mcp_client_from_config("test_server", config) + + def test_missing_url_for_sse_raises_error(self): + """SSE transport without url should raise ValueError.""" + config = {"transport": "sse"} + + with pytest.raises(ValueError, match="'url' is required"): + _create_mcp_client_from_config("test_server", config) + + def test_missing_url_for_streamable_http_raises_error(self): + """Streamable HTTP transport without url should raise ValueError.""" + config = {"transport": "streamable-http"} + + with pytest.raises(ValueError, match="'url' is required"): + _create_mcp_client_from_config("test_server", config) + + def test_invalid_startup_timeout_type_raises_error(self): + """Non-integer startup_timeout should raise ValueError.""" + config = {"command": "echo", "startup_timeout": "30"} + + with pytest.raises(ValueError, match="configuration validation error"): + _create_mcp_client_from_config("test_server", config) + + +class TestLoadMcpClientsFromConfig: + """Tests for load_mcp_clients_from_config function.""" + + def test_load_from_dict(self): + """Loading from a dict with mcpServers wrapper.""" + config = { + "mcpServers": { + "server1": {"command": "echo", "args": []}, + "server2": {"command": "cat", "args": []}, + } + } + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + mock_mcp_client.return_value = MagicMock() + clients = load_mcp_clients_from_config(config) + + assert len(clients) == 2 + assert "server1" in clients + assert "server2" in clients + + def test_load_from_json_file(self): + """Loading from a JSON file path.""" + config_data = { + "mcpServers": { + "my_server": {"command": "echo", "args": ["hello"]}, + } + } + temp_path = "" + + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(config_data, f) + temp_path = f.name + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + mock_mcp_client.return_value = MagicMock() + clients = load_mcp_clients_from_config(temp_path) + + assert len(clients) == 1 + assert "my_server" in clients + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + def test_load_from_file_with_prefix(self): + """Loading from a file:// prefixed path.""" + config_data = {"mcpServers": {"server1": {"command": "echo"}}} + temp_path = "" + + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(config_data, f) + temp_path = f.name + + with ( + patch("strands.experimental.mcp_config.stdio_client"), + patch("strands.experimental.mcp_config.StdioServerParameters"), + patch("strands.experimental.mcp_config.MCPClient") as mock_mcp_client, + ): + mock_mcp_client.return_value = MagicMock() + clients = load_mcp_clients_from_config(f"file://{temp_path}") + + assert len(clients) == 1 + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + def test_file_not_found_raises_error(self): + """Non-existent file should raise FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + load_mcp_clients_from_config("/nonexistent/path/config.json") + + def test_invalid_json_raises_error(self): + """Invalid JSON content should raise JSONDecodeError.""" + temp_path = "" + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not json") + temp_path = f.name + + with pytest.raises(json.JSONDecodeError): + load_mcp_clients_from_config(temp_path) + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + def test_invalid_config_type_raises_error(self): + """Non-string, non-dict input should raise ValueError.""" + with pytest.raises(ValueError, match="must be a file path string or dictionary"): + load_mcp_clients_from_config(123) # type: ignore[arg-type] + + def test_empty_mcp_servers_returns_empty(self): + """Empty mcpServers dict should return empty dict.""" + clients = load_mcp_clients_from_config({"mcpServers": {}}) + assert clients == {} + + def test_missing_mcp_servers_key_raises_error(self): + """Dict without mcpServers key should raise ValueError.""" + with pytest.raises(ValueError, match="mcpServers"): + load_mcp_clients_from_config({"server1": {"command": "echo"}}) + + def test_server_creation_error_includes_server_name(self): + """Error creating a server should include the server name in the message.""" + config = {"mcpServers": {"bad_server": {"transport": "invalid"}}} + + with pytest.raises(ValueError, match="bad_server"): + load_mcp_clients_from_config(config) + + +class TestMcpServerConfigSchema: + """Tests for MCP_SERVER_CONFIG_SCHEMA validation.""" + + def _validate(self, config): + jsonschema.Draft7Validator(MCP_SERVER_CONFIG_SCHEMA).validate(config) + + def test_valid_stdio_config(self): + self._validate({"command": "uvx", "args": ["server@latest"], "env": {"KEY": "val"}, "cwd": "/tmp"}) + + def test_valid_sse_config(self): + self._validate({"transport": "sse", "url": "http://localhost:8000/sse", "headers": {"Auth": "Bearer tok"}}) + + def test_valid_streamable_http_config(self): + self._validate({"transport": "streamable-http", "url": "http://localhost:8000/mcp"}) + + def test_valid_config_with_all_common_fields(self): + self._validate({ + "command": "echo", + "prefix": "myprefix", + "startup_timeout": 60, + "tool_filters": {"allowed": ["echo"], "rejected": ["dangerous_.*"]}, + }) + + def test_rejects_unknown_property(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "unknown_field": "value"}) + + def test_rejects_invalid_transport_enum(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"transport": "websocket", "url": "ws://localhost"}) + + def test_rejects_non_string_command(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": 123}) + + def test_rejects_non_array_args(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "args": "not-an-array"}) + + def test_rejects_non_string_args_items(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "args": [123]}) + + def test_rejects_non_object_env(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "env": "not-an-object"}) + + def test_rejects_non_string_env_values(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "env": {"KEY": 123}}) + + def test_rejects_non_integer_startup_timeout(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "startup_timeout": "30"}) + + def test_rejects_unknown_tool_filter_property(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"command": "echo", "tool_filters": {"allowed": ["echo"], "extra": True}}) + + def test_rejects_non_string_header_values(self): + with pytest.raises(jsonschema.ValidationError): + self._validate({"transport": "sse", "url": "http://localhost", "headers": {"Key": 123}}) + + def test_schema_is_valid_json_schema(self): + """The schema itself should be a valid JSON Schema draft-07.""" + jsonschema.Draft7Validator.check_schema(MCP_SERVER_CONFIG_SCHEMA) diff --git a/tests_integ/mcp/test_mcp_config.py b/tests_integ/mcp/test_mcp_config.py new file mode 100644 index 000000000..fb5cb9118 --- /dev/null +++ b/tests_integ/mcp/test_mcp_config.py @@ -0,0 +1,104 @@ +"""Integration tests for loading MCP servers from config.""" + +import json +import os +import tempfile + +import pytest + +from strands import Agent +from strands.experimental.mcp_config import load_mcp_clients_from_config + + +def test_load_stdio_server_from_config(): + """Test loading a stdio MCP server from config dict and using it with an agent.""" + config = { + "mcpServers": { + "echo": { + "command": "python", + "args": ["tests_integ/mcp/echo_server.py"], + "prefix": "cfg", + "tool_filters": {"allowed": ["cfg_echo"]}, + } + } + } + + clients = load_mcp_clients_from_config(config) + assert "echo" in clients + + agent = Agent(tools=list(clients.values())) + assert "cfg_echo" in agent.tool_names + + result = agent.tool.cfg_echo(to_echo="Config Test") + assert "Config Test" in str(result) + + agent.cleanup() + + +def test_load_stdio_server_from_json_file(): + """Test loading a stdio MCP server from a JSON config file.""" + config_data = { + "mcpServers": { + "echo": { + "command": "python", + "args": ["tests_integ/mcp/echo_server.py"], + "prefix": "file", + "tool_filters": {"allowed": ["file_echo"]}, + } + } + } + temp_path = "" + + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(config_data, f) + temp_path = f.name + + clients = load_mcp_clients_from_config(temp_path) + assert "echo" in clients + + agent = Agent(tools=list(clients.values())) + assert "file_echo" in agent.tool_names + + result = agent.tool.file_echo(to_echo="File Config Test") + assert "File Config Test" in str(result) + + agent.cleanup() + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_load_multiple_servers_from_config(): + """Test loading multiple MCP servers from a single config.""" + config = { + "mcpServers": { + "server1": { + "command": "python", + "args": ["tests_integ/mcp/echo_server.py"], + "prefix": "s1", + "tool_filters": {"allowed": ["s1_echo"]}, + }, + "server2": { + "command": "python", + "args": ["tests_integ/mcp/echo_server.py"], + "prefix": "s2", + "tool_filters": {"allowed": ["s2_echo"]}, + }, + } + } + + clients = load_mcp_clients_from_config(config) + assert len(clients) == 2 + + agent = Agent(tools=list(clients.values())) + assert "s1_echo" in agent.tool_names + assert "s2_echo" in agent.tool_names + + result1 = agent.tool.s1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.s2_echo(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup()