Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
handle_registration_response,
handle_token_response_scopes,
is_valid_client_metadata_url,
merge_scopes,
should_use_client_metadata_url,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
Expand Down Expand Up @@ -570,12 +571,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
else:
logger.debug(f"OAuth metadata discovery failed: {url}")

# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
# Step 3: Apply scope selection strategy, merging with existing scopes
new_scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response),
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)
self.context.client_metadata.scope = merge_scopes(self.context.client_metadata.scope, new_scope)

# Step 4: Register client or use URL-based client ID (CIMD)
if not self.context.client_info:
Expand Down Expand Up @@ -619,10 +621,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope": # pragma: no branch
try:
# Step 2a: Update the required scopes
self.context.client_metadata.scope = get_client_metadata_scopes(
# Step 2a: Update the required scopes, merging with existing
new_scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
)
self.context.client_metadata.scope = merge_scopes(self.context.client_metadata.scope, new_scope)

# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
Expand Down
17 changes: 17 additions & 0 deletions src/mcp/client/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ def get_client_metadata_scopes(
return None


def merge_scopes(existing: str | None, incoming: str | None) -> str | None:
"""Merge OAuth scopes by computing the union of space-delimited scope strings.

Per RFC 6749 §3.3, scopes are space-delimited, case-sensitive strings.
This prevents the infinite re-authorization loop that occurs when a server
uses per-operation scopes and the client overwrites previously-granted scopes.
"""
if not incoming:
return existing
if not existing:
return incoming

existing_set = set(existing.split())
existing_set.update(incoming.split())
return " ".join(sorted(existing_set))


def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""Generate an ordered list of URLs for authorization server metadata discovery.

Expand Down
62 changes: 55 additions & 7 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,7 @@ async def test_403_insufficient_scope_updates_scope_from_header(
mock_storage: MockTokenStorage,
valid_tokens: OAuthToken,
):
"""Test that 403 response correctly updates scope from WWW-Authenticate header."""
"""Test that 403 response correctly accumulates scope from WWW-Authenticate header."""
# Pre-store valid tokens and client info
client_info = OAuthClientInformationFull(
client_id="test_client_id",
Expand All @@ -1350,10 +1350,10 @@ async def test_403_insufficient_scope_updates_scope_from_header(
async def capture_redirect(url: str) -> None:
nonlocal redirect_captured, captured_state
redirect_captured = True
# Verify the new scope is included in authorization URL
assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace(
"%3A", ":"
).replace("+", " ")
# Verify the accumulated scopes are included (original + new)
decoded = url.replace("%3A", ":").replace("+", " ")
for s in ["admin:write", "admin:delete", "read", "write"]:
assert s in decoded, f"Expected scope '{s}' in URL"
# Extract state from redirect URL
parsed = urlparse(url)
params = parse_qs(parsed.query)
Expand Down Expand Up @@ -1383,8 +1383,9 @@ async def mock_callback() -> tuple[str, str | None]:
# Trigger step-up - should get token exchange request
token_exchange_request = await auth_flow.asend(response_403)

# Verify scope was updated
assert oauth_provider.context.client_metadata.scope == "admin:write admin:delete"
# Verify scope was accumulated (original "read write" + new "admin:write admin:delete")
accumulated = set(oauth_provider.context.client_metadata.scope.split())
assert accumulated == {"admin:delete", "admin:write", "read", "write"}
assert redirect_captured

# Complete the flow with successful token response
Expand Down Expand Up @@ -2264,3 +2265,50 @@ async def callback_handler() -> tuple[str, str | None]:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass


class TestMergeScopes:
"""Tests for merge_scopes utility function."""

def test_merge_none_existing_returns_incoming(self):
from mcp.client.auth.utils import merge_scopes

assert merge_scopes(None, "mcp:tools:read") == "mcp:tools:read"

def test_merge_none_incoming_returns_existing(self):
from mcp.client.auth.utils import merge_scopes

assert merge_scopes("init", None) == "init"

def test_merge_both_none_returns_none(self):
from mcp.client.auth.utils import merge_scopes

assert merge_scopes(None, None) is None

def test_merge_disjoint_scopes(self):
from mcp.client.auth.utils import merge_scopes

result = merge_scopes("init", "mcp:tools:read")
assert result is not None
scopes = set(result.split())
assert scopes == {"init", "mcp:tools:read"}

def test_merge_overlapping_scopes_deduplicates(self):
from mcp.client.auth.utils import merge_scopes

result = merge_scopes("init mcp:tools:read", "mcp:tools:read mcp:tools:write")
assert result is not None
scopes = set(result.split())
assert scopes == {"init", "mcp:tools:read", "mcp:tools:write"}

def test_merge_identical_scopes(self):
from mcp.client.auth.utils import merge_scopes

result = merge_scopes("init", "init")
assert result == "init"

def test_merge_empty_strings(self):
from mcp.client.auth.utils import merge_scopes

assert merge_scopes("init", "") == "init"
assert merge_scopes("", "init") == "init"