Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/mcp/server/mcpserver/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic
from urllib.parse import urlparse
from urllib.request import url2pathname

from pydantic import AnyUrl, BaseModel

Expand Down Expand Up @@ -117,6 +120,42 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
assert self._mcp_server is not None, "Context is not available outside of a request"
return await self._mcp_server.read_resource(uri, self)

async def assert_within_roots(self, path: str | Path) -> None:
"""Assert that a filesystem path is within the client's declared roots.

Provides server-side enforcement of the filesystem boundaries declared by
the client via the Roots capability. Call this at the start of any tool
that accepts a user-provided path, to prevent the tool from accessing
files outside the client's declared scope.

Args:
path: The filesystem path to validate. Accepts a string or Path.
Relative paths and symlinks are resolved before comparison.

Raises:
PermissionError: If the path is outside every declared root, or if
the client has declared no roots.

Example:
```python
@server.tool()
async def read_file(path: str, ctx: Context) -> str:
await ctx.assert_within_roots(path)
with open(path) as f:
return f.read()
```
"""
target = Path(path).resolve()

result = await self.request_context.session.list_roots()

for root in result.roots:
root_path = Path(url2pathname(urlparse(str(root.uri)).path)).resolve()
if target.is_relative_to(root_path):
return

raise PermissionError(f"Path {target} is not within any declared root")

async def elicit(
self,
message: str,
Expand Down
141 changes: 141 additions & 0 deletions tests/server/mcpserver/test_roots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from pathlib import Path

import pytest
from pydantic import FileUrl

from mcp import Client
from mcp.client.session import ClientSession
from mcp.server.mcpserver import Context, MCPServer
from mcp.shared._context import RequestContext
from mcp.types import ListRootsResult, Root, TextContent


def _make_callback(roots: list[Root]):
async def list_roots_callback(
context: RequestContext[ClientSession],
) -> ListRootsResult:
return ListRootsResult(roots=roots)

return list_roots_callback


@pytest.mark.anyio
async def test_path_within_root_passes(tmp_path: Path):
"""A path inside a declared root should not raise."""
inside = tmp_path / "file.txt"
inside.touch()

server = MCPServer("test")

@server.tool("check")
async def check(context: Context, path: str) -> bool:
await context.assert_within_roots(path)
return True

callback = _make_callback([Root(uri=FileUrl(f"file://{tmp_path}"))])

async with Client(server, list_roots_callback=callback) as client:
result = await client.call_tool("check", {"path": str(inside)})
assert result.is_error is False


@pytest.mark.anyio
async def test_path_outside_roots_raises(tmp_path: Path):
"""A path outside every declared root should raise PermissionError."""
root_dir = tmp_path / "allowed"
root_dir.mkdir()
outside = tmp_path / "elsewhere.txt"
outside.touch()

server = MCPServer("test")

@server.tool("check")
async def check(context: Context, path: str) -> bool:
await context.assert_within_roots(path)
return True # pragma: no cover

callback = _make_callback([Root(uri=FileUrl(f"file://{root_dir}"))])

async with Client(server, list_roots_callback=callback) as client:
result = await client.call_tool("check", {"path": str(outside)})
assert result.is_error is True
assert isinstance(result.content[0], TextContent)
assert "not within any declared root" in result.content[0].text


@pytest.mark.anyio
async def test_no_roots_declared_raises(tmp_path: Path):
"""An empty roots list should always raise."""
target = tmp_path / "file.txt"
target.touch()

server = MCPServer("test")

@server.tool("check")
async def check(context: Context, path: str) -> bool:
await context.assert_within_roots(path)
return True # pragma: no cover

callback = _make_callback([])

async with Client(server, list_roots_callback=callback) as client:
result = await client.call_tool("check", {"path": str(target)})
assert result.is_error is True
assert isinstance(result.content[0], TextContent)
assert "not within any declared root" in result.content[0].text


@pytest.mark.anyio
async def test_symlink_escaping_root_raises(tmp_path: Path):
"""A symlink inside a root that points outside should raise (resolve follows links)."""
root_dir = tmp_path / "allowed"
root_dir.mkdir()
outside_dir = tmp_path / "forbidden"
outside_dir.mkdir()
outside_target = outside_dir / "secret.txt"
outside_target.touch()

link = root_dir / "escape"
link.symlink_to(outside_target)

server = MCPServer("test")

@server.tool("check")
async def check(context: Context, path: str) -> bool:
await context.assert_within_roots(path)
return True # pragma: no cover

callback = _make_callback([Root(uri=FileUrl(f"file://{root_dir}"))])

async with Client(server, list_roots_callback=callback) as client:
result = await client.call_tool("check", {"path": str(link)})
assert result.is_error is True


@pytest.mark.anyio
async def test_multiple_roots_any_match_passes(tmp_path: Path):
"""A path inside any one of several declared roots should pass."""
root_a = tmp_path / "a"
root_a.mkdir()
root_b = tmp_path / "b"
root_b.mkdir()
target = root_b / "file.txt"
target.touch()

server = MCPServer("test")

@server.tool("check")
async def check(context: Context, path: str) -> bool:
await context.assert_within_roots(path)
return True

callback = _make_callback(
[
Root(uri=FileUrl(f"file://{root_a}")),
Root(uri=FileUrl(f"file://{root_b}")),
]
)

async with Client(server, list_roots_callback=callback) as client:
result = await client.call_tool("check", {"path": str(target)})
assert result.is_error is False
Loading