From d1af78f0d544b85b0bc1c6c75f3952f9d50c49c8 Mon Sep 17 00:00:00 2001 From: Berat Elcelik Date: Mon, 15 Jun 2026 14:09:56 +0200 Subject: [PATCH] fix(client/auth): use stored refresh_token on 401 instead of full re-auth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A reloaded OAuthClientProvider never restores token_expiry_time, so is_token_valid() reports a stored (possibly expired) access token as valid and the proactive-refresh branch in async_auth_flow is skipped. When the server then returns 401, the flow goes straight to a full authorization-code grant — even though a valid refresh_token and client_info are present in storage. Clients that construct a fresh provider per request (loading tokens from persistent storage) therefore force an interactive re-authorization on every access-token expiry instead of refreshing silently. Attempt a refresh_token grant in the 401 branch, after metadata discovery and before client registration / authorization. Discovery in that branch already yields the token endpoint, so the stored refresh_token is used directly; only if the refresh fails does the flow fall back to full re-authorization. Tests: add reload -> 401 -> refresh coverage to tests/client/test_auth.py for both the success path and the fallback-to-reauth path when refresh fails. Reported-by: Berat Elcelik --- src/mcp/client/auth/oauth2.py | 15 ++++ tests/client/test_auth.py | 162 +++++++++++++++++++++++++++++++++- 2 files changed, 175 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc8234..d30a3271f 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -583,6 +583,21 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_metadata.grant_types, ) + # Step 3.5: Use a stored refresh_token before full re-auth. + # A provider reloaded from storage never restores + # token_expiry_time, so is_token_valid() treats the stale token + # as valid and the proactive-refresh branch above is skipped. + # Without this, every access-token expiry forces an interactive + # re-authorization instead of a silent refresh. + if self.context.can_refresh_token(): + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request + if await self._handle_refresh_response(refresh_response): + self._add_auth_header(request) + yield request + return + # refresh failed -> fall through to full re-authorization + # Step 4: Register client or use URL-based client ID (CIMD) if not self.context.client_info: if should_use_client_metadata_url( diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ca7a495e6..7095c14e3 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -45,13 +45,13 @@ def __init__(self): self._client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens # pragma: no cover + return self._tokens async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info # pragma: no cover + return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info @@ -2636,3 +2636,161 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +def _reloaded_client_info() -> OAuthClientInformationFull: + return OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + +@pytest.mark.anyio +async def test_initialize_does_not_restore_token_expiry( + oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken +): + """After _initialize() loads stored tokens, token_expiry_time stays None. + + OAuthToken.expires_in is relative, so a reloaded provider cannot know the real + expiry; is_token_valid() then reports the loaded token as valid and the + proactive-refresh branch in async_auth_flow is skipped. + """ + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(_reloaded_client_info()) + + await oauth_provider._initialize() + + assert oauth_provider.context.token_expiry_time is None + assert oauth_provider.context.is_token_valid() is True + assert oauth_provider.context.can_refresh_token() is True + + +@pytest.mark.anyio +async def test_reloaded_provider_uses_refresh_token_on_401( + oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken +): + """A fresh provider that reloaded a stored refresh_token must perform a + refresh_token grant on 401 instead of a full re-authorization.""" + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(_reloaded_client_info()) + oauth_provider._perform_authorization_code_grant = mock.AsyncMock(return_value=("unused_code", "unused_verifier")) + + request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(request) + + # _initialize() loads the stale token and (wrongly) treats it as valid + first_request = await auth_flow.__anext__() + assert first_request.headers["Authorization"] == "Bearer test_access_token" + + # server rejects the stale token -> discovery -> refresh + resp_401 = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=request, + ) + prm_request = await auth_flow.asend(resp_401) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + asm_request = await auth_flow.asend(prm_response) + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=asm_request, + ) + + refresh_request = await auth_flow.asend(asm_response) + body = refresh_request.content.decode() + assert refresh_request.method == "POST" + assert str(refresh_request.url) == "https://auth.example.com/token" + assert "grant_type=refresh_token" in body + assert "refresh_token=test_refresh_token" in body + oauth_provider._perform_authorization_code_grant.assert_not_called() + + refresh_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", ' + b'"expires_in": 3600, "refresh_token": "new_rt"}' + ), + request=refresh_request, + ) + final_request = await auth_flow.asend(refresh_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + try: + await auth_flow.asend(httpx.Response(200, request=final_request)) + except StopAsyncIteration: + pass + + +@pytest.mark.anyio +async def test_reloaded_provider_falls_back_to_reauth_when_refresh_fails( + oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken +): + """If the refresh_token grant fails, the 401 flow falls back to full + re-authorization rather than erroring out.""" + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(_reloaded_client_info()) + oauth_provider._perform_authorization_code_grant = mock.AsyncMock(return_value=("reauth_code", "reauth_verifier")) + + request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(request) + await auth_flow.__anext__() + + resp_401 = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=request, + ) + prm_request = await auth_flow.asend(resp_401) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + asm_request = await auth_flow.asend(prm_response) + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=asm_request, + ) + + refresh_request = await auth_flow.asend(asm_response) + assert "grant_type=refresh_token" in refresh_request.content.decode() + + # refresh fails -> fall through to full authorization-code grant + refresh_failure = httpx.Response(400, content=b'{"error": "invalid_grant"}', request=refresh_request) + token_request = await auth_flow.asend(refresh_failure) + assert "grant_type=authorization_code" in token_request.content.decode() + oauth_provider._perform_authorization_code_grant.assert_called_once() + + token_response = httpx.Response( + 200, + content=b'{"access_token": "reauth_access_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer reauth_access_token" + + try: + await auth_flow.asend(httpx.Response(200, request=final_request)) + except StopAsyncIteration: + pass