diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 7f5af5186..5e49c526c 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -34,6 +34,7 @@ handle_registration_response, handle_token_response_scopes, is_valid_client_metadata_url, + merge_scopes, should_use_client_metadata_url, ) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION @@ -570,12 +571,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. else: logger.debug(f"OAuth metadata discovery failed: {url}") - # Step 3: Apply scope selection strategy - self.context.client_metadata.scope = get_client_metadata_scopes( + # Step 3: Apply scope selection strategy, merging with existing scopes + new_scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, ) + self.context.client_metadata.scope = merge_scopes(self.context.client_metadata.scope, new_scope) # Step 4: Register client or use URL-based client ID (CIMD) if not self.context.client_info: @@ -619,10 +621,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2: Check if we need to step-up authorization if error == "insufficient_scope": # pragma: no branch try: - # Step 2a: Update the required scopes - self.context.client_metadata.scope = get_client_metadata_scopes( + # Step 2a: Update the required scopes, merging with existing + new_scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata ) + self.context.client_metadata.scope = merge_scopes(self.context.client_metadata.scope, new_scope) # Step 2b: Perform (re-)authorization and token exchange token_response = yield await self._perform_authorization() diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 0ca36b98d..3d7a517ec 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -119,6 +119,23 @@ def get_client_metadata_scopes( return None +def merge_scopes(existing: str | None, incoming: str | None) -> str | None: + """Merge OAuth scopes by computing the union of space-delimited scope strings. + + Per RFC 6749 ยง3.3, scopes are space-delimited, case-sensitive strings. + This prevents the infinite re-authorization loop that occurs when a server + uses per-operation scopes and the client overwrites previously-granted scopes. + """ + if not incoming: + return existing + if not existing: + return incoming + + existing_set = set(existing.split()) + existing_set.update(incoming.split()) + return " ".join(sorted(existing_set)) + + def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: """Generate an ordered list of URLs for authorization server metadata discovery. diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e36..da8c9435b 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1327,7 +1327,7 @@ async def test_403_insufficient_scope_updates_scope_from_header( mock_storage: MockTokenStorage, valid_tokens: OAuthToken, ): - """Test that 403 response correctly updates scope from WWW-Authenticate header.""" + """Test that 403 response correctly accumulates scope from WWW-Authenticate header.""" # Pre-store valid tokens and client info client_info = OAuthClientInformationFull( client_id="test_client_id", @@ -1350,10 +1350,10 @@ async def test_403_insufficient_scope_updates_scope_from_header( async def capture_redirect(url: str) -> None: nonlocal redirect_captured, captured_state redirect_captured = True - # Verify the new scope is included in authorization URL - assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace( - "%3A", ":" - ).replace("+", " ") + # Verify the accumulated scopes are included (original + new) + decoded = url.replace("%3A", ":").replace("+", " ") + for s in ["admin:write", "admin:delete", "read", "write"]: + assert s in decoded, f"Expected scope '{s}' in URL" # Extract state from redirect URL parsed = urlparse(url) params = parse_qs(parsed.query) @@ -1383,8 +1383,9 @@ async def mock_callback() -> tuple[str, str | None]: # Trigger step-up - should get token exchange request token_exchange_request = await auth_flow.asend(response_403) - # Verify scope was updated - assert oauth_provider.context.client_metadata.scope == "admin:write admin:delete" + # Verify scope was accumulated (original "read write" + new "admin:write admin:delete") + accumulated = set(oauth_provider.context.client_metadata.scope.split()) + assert accumulated == {"admin:delete", "admin:write", "read", "write"} assert redirect_captured # Complete the flow with successful token response @@ -2264,3 +2265,50 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +class TestMergeScopes: + """Tests for merge_scopes utility function.""" + + def test_merge_none_existing_returns_incoming(self): + from mcp.client.auth.utils import merge_scopes + + assert merge_scopes(None, "mcp:tools:read") == "mcp:tools:read" + + def test_merge_none_incoming_returns_existing(self): + from mcp.client.auth.utils import merge_scopes + + assert merge_scopes("init", None) == "init" + + def test_merge_both_none_returns_none(self): + from mcp.client.auth.utils import merge_scopes + + assert merge_scopes(None, None) is None + + def test_merge_disjoint_scopes(self): + from mcp.client.auth.utils import merge_scopes + + result = merge_scopes("init", "mcp:tools:read") + assert result is not None + scopes = set(result.split()) + assert scopes == {"init", "mcp:tools:read"} + + def test_merge_overlapping_scopes_deduplicates(self): + from mcp.client.auth.utils import merge_scopes + + result = merge_scopes("init mcp:tools:read", "mcp:tools:read mcp:tools:write") + assert result is not None + scopes = set(result.split()) + assert scopes == {"init", "mcp:tools:read", "mcp:tools:write"} + + def test_merge_identical_scopes(self): + from mcp.client.auth.utils import merge_scopes + + result = merge_scopes("init", "init") + assert result == "init" + + def test_merge_empty_strings(self): + from mcp.client.auth.utils import merge_scopes + + assert merge_scopes("init", "") == "init" + assert merge_scopes("", "init") == "init"