diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1c84c8610..3a8b2ca01 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -41,7 +41,7 @@ async def main(): from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import Any, Generic, cast import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -52,8 +52,8 @@ async def main(): from typing_extensions import TypeVar from mcp import types -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings @@ -471,7 +471,15 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(ctx, req.params) + request_scope = cast(dict[str, object] | None, getattr(request_data, "scope", None)) + request_user = request_scope.get("user") if request_scope is not None else None + auth_context_token = auth_context_var.set( + request_user if isinstance(request_user, AuthenticatedUser) else None + ) + try: + response = await handler(ctx, req.params) + finally: + auth_context_var.reset(auth_context_token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): @@ -558,7 +566,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -584,7 +592,7 @@ def streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None if auth and auth.resource_server_url: @@ -607,7 +615,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, diff --git a/tests/issues/test_2208_stateful_auth_context.py b/tests/issues/test_2208_stateful_auth_context.py new file mode 100644 index 000000000..25068e79c --- /dev/null +++ b/tests/issues/test_2208_stateful_auth_context.py @@ -0,0 +1,98 @@ +"""Regression test for issue #2208. + +In stateful streamable HTTP sessions, get_access_token() must reflect the +Authorization header from the current request, not the one that created the +session's background receive task. +""" + +import time + +import httpx +import pytest +from pydantic import AnyHttpUrl + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.provider import AccessToken +from mcp.server.auth.settings import AuthSettings +from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent, Tool + + +class EchoTokenVerifier: + """Accept any bearer token and expose it in the authenticated user.""" + + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +class MutableBearerAuth(httpx.Auth): + """Update the bearer token between requests without rebuilding the client.""" + + def __init__(self, token: str) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +async def handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access_token = get_access_token() + token = access_token.token if access_token else "" + return CallToolResult(content=[TextContent(type="text", text=token)]) + + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="whoami", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + +@pytest.mark.anyio +async def test_get_access_token_uses_current_request_in_stateful_streamable_http_session() -> None: + server = Server( + "auth-test-server", + on_call_tool=handle_whoami, + on_list_tools=handle_list_tools, + ) + app = server.streamable_http_app( + host="testserver", + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + resource_server_url=AnyHttpUrl("https://testserver/mcp"), + ), + token_verifier=EchoTokenVerifier(), + ) + auth = MutableBearerAuth("token-A") + + async with ( + app.router.lifespan_context(app), + httpx.ASGITransport(app) as transport, + httpx.AsyncClient( + transport=transport, + base_url="http://testserver", + auth=auth, + follow_redirects=True, + timeout=httpx.Timeout(30.0, read=30.0), + ) as http_client, + streamable_http_client("http://testserver/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + first_response = await session.call_tool("whoami", {}) + assert isinstance(first_response.content[0], TextContent) + assert first_response.content[0].text == "token-A" + + auth.token = "token-B" + + second_response = await session.call_tool("whoami", {}) + assert isinstance(second_response.content[0], TextContent) + assert second_response.content[0].text == "token-B" diff --git a/tests/server/test_lowlevel_streamable_http_auth_app.py b/tests/server/test_lowlevel_streamable_http_auth_app.py new file mode 100644 index 000000000..b8e15439e --- /dev/null +++ b/tests/server/test_lowlevel_streamable_http_auth_app.py @@ -0,0 +1,84 @@ +from typing import cast +from unittest.mock import Mock + +import pytest +from pydantic import AnyHttpUrl +from starlette.applications import Starlette +from starlette.middleware.authentication import AuthenticationMiddleware + +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.provider import AccessToken +from mcp.server.auth.settings import AuthSettings +from mcp.server.lowlevel.server import Server + + +class DummyTokenVerifier: + async def verify_token(self, token: str) -> AccessToken | None: + return None + + +def route_paths(app: Starlette) -> set[str]: + paths: set[str] = set() + for route in app.routes: + path = getattr(route, "path", None) + if isinstance(path, str): + paths.add(path) + return paths + + +@pytest.mark.anyio +async def test_dummy_token_verifier_returns_none(): + verifier = DummyTokenVerifier() + + assert await verifier.verify_token("token") is None + + +def test_route_paths_ignores_non_string_paths(): + app = Starlette() + routes = cast(list[object], app.router.routes) + routes.append(Mock(path="/ok")) + routes.append(object()) + + assert route_paths(app) == {"/ok"} + + +def test_streamable_http_app_adds_auth_routes_without_token_verifier(): + server = Server("test-server") + + app = server.streamable_http_app( + host="testserver", + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + resource_server_url=AnyHttpUrl("https://testserver/mcp"), + ), + auth_server_provider=Mock(), + ) + + assert { + "/mcp", + "/authorize", + "/token", + "/.well-known/oauth-authorization-server", + "/.well-known/oauth-protected-resource/mcp", + }.issubset(route_paths(app)) + + +def test_streamable_http_app_skips_resource_metadata_route_when_resource_server_url_missing(): + server = Server("test-server") + + app = server.streamable_http_app( + host="testserver", + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + resource_server_url=None, + ), + token_verifier=DummyTokenVerifier(), + ) + + paths = route_paths(app) + middleware_classes = [middleware.cls for middleware in app.user_middleware] + + assert "/mcp" in paths + assert "/.well-known/oauth-protected-resource/mcp" not in paths + assert AuthenticationMiddleware in middleware_classes + assert AuthContextMiddleware in middleware_classes