diff --git a/client/src/__tests__/proxyFetchEndpoint.test.ts b/client/src/__tests__/proxyFetchEndpoint.test.ts index 4be75f6c4..0eb2de6fd 100644 --- a/client/src/__tests__/proxyFetchEndpoint.test.ts +++ b/client/src/__tests__/proxyFetchEndpoint.test.ts @@ -1,5 +1,5 @@ /** - * Tests for the proxy server's POST /fetch endpoint. + * Tests for the proxy server's HTTP endpoints. * Spawns the server and hits it like any other HTTP client would. */ import { spawn, type ChildProcess } from "child_process"; @@ -61,7 +61,7 @@ async function withLocalUpstream( } } -describe("POST /fetch endpoint", () => { +describe("Inspector proxy endpoints", () => { let server: ChildProcess; const baseUrl = `http://localhost:${TEST_PORT}`; @@ -231,4 +231,40 @@ describe("POST /fetch endpoint", () => { }, ); }); + + it("forwards upstream SSE 401 challenge details to the browser", async () => { + const challenge = + 'Bearer resource_metadata="http://127.0.0.1/.well-known/oauth-protected-resource/mcp"'; + const upstreamPayload = JSON.stringify({ + error: "unauthorized", + error_description: "Authentication required", + }); + + await withLocalUpstream( + (req, res) => { + res.writeHead(401, { + "Content-Type": "application/json", + "WWW-Authenticate": challenge, + }); + res.end(upstreamPayload); + }, + async (origin) => { + const proxyUrl = new URL(`${baseUrl}/sse`); + proxyUrl.searchParams.set("transportType", "sse"); + proxyUrl.searchParams.set("url", `${origin}/events`); + + const res = await fetch(proxyUrl, { + headers: { + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + }); + + const body = await res.text(); + expect(res.status).toBe(401); + expect(res.headers.get("www-authenticate")).toBe(challenge); + expect(res.headers.get("content-type")).toMatch(/application\/json/i); + expect(body).toBe(upstreamPayload); + }, + ); + }); }); diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index ccfd6d928..bfff5bcc8 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -7,6 +7,7 @@ import { generateOAuthErrorDescription, parseOAuthCallbackParams, } from "@/utils/oauthUtils.ts"; +import { getResourceMetadataUrlFromSessionStorage } from "@/lib/oauth-resource-metadata"; interface OAuthCallbackProps { onConnect: (serverUrl: string) => void; @@ -49,6 +50,8 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { result = await auth(serverAuthProvider, { serverUrl, authorizationCode: params.code, + resourceMetadataUrl: + getResourceMetadataUrlFromSessionStorage(serverUrl), }); } catch (error) { console.error("OAuth callback error:", error); diff --git a/client/src/components/__tests__/AuthDebugger.test.tsx b/client/src/components/__tests__/AuthDebugger.test.tsx index 71eec04aa..5ae037384 100644 --- a/client/src/components/__tests__/AuthDebugger.test.tsx +++ b/client/src/components/__tests__/AuthDebugger.test.tsx @@ -44,6 +44,7 @@ jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ exchangeAuthorization: jest.fn(), discoverOAuthProtectedResourceMetadata: jest.fn(), selectResourceURL: jest.fn(), + extractWWWAuthenticateParams: jest.fn(() => ({})), })); // Import the functions to get their types @@ -64,10 +65,17 @@ jest.mock("../../lib/auth", () => ({ tokens: jest.fn().mockImplementation(() => Promise.resolve(undefined)), clear: jest.fn().mockImplementation(() => { // Mock the real clear() behavior which removes items from sessionStorage - sessionStorage.removeItem("[https://example.com/mcp] mcp_tokens"); - sessionStorage.removeItem("[https://example.com/mcp] mcp_client_info"); sessionStorage.removeItem( - "[https://example.com/mcp] mcp_server_metadata", + `[https://example.com/mcp] ${SESSION_KEYS.CLIENT_INFORMATION}`, + ); + sessionStorage.removeItem( + `[https://example.com/mcp] ${SESSION_KEYS.TOKENS}`, + ); + sessionStorage.removeItem( + `[https://example.com/mcp] ${SESSION_KEYS.CODE_VERIFIER}`, + ); + sessionStorage.removeItem( + `[https://example.com/mcp] ${SESSION_KEYS.RESOURCE_METADATA_URL}`, ); }), redirectUrl: "http://localhost:3000/oauth/callback/debug", @@ -155,6 +163,11 @@ describe("AuthDebugger", () => { beforeEach(() => { jest.clearAllMocks(); sessionStorageMock.getItem.mockReturnValue(null); + global.fetch = jest.fn().mockResolvedValue( + new Response("", { + status: 404, + }), + ); // Suppress console errors in tests to avoid JSDOM navigation noise jest.spyOn(console, "error").mockImplementation(() => {}); @@ -403,6 +416,9 @@ describe("AuthDebugger", () => { // Verify session storage was cleared expect(sessionStorageMock.removeItem).toHaveBeenCalled(); + expect(sessionStorageMock.removeItem).toHaveBeenCalledWith( + `[https://example.com/mcp] ${SESSION_KEYS.RESOURCE_METADATA_URL}`, + ); }); }); diff --git a/client/src/components/__tests__/OAuthCallback.test.tsx b/client/src/components/__tests__/OAuthCallback.test.tsx new file mode 100644 index 000000000..ee56a8cab --- /dev/null +++ b/client/src/components/__tests__/OAuthCallback.test.tsx @@ -0,0 +1,87 @@ +import { render, waitFor } from "@testing-library/react"; +import OAuthCallback from "../OAuthCallback"; +import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; +import { getServerSpecificKey, SESSION_KEYS } from "../../lib/constants"; + +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + auth: jest.fn(), + extractWWWAuthenticateParams: jest.fn(() => ({})), +})); + +jest.mock("../../lib/auth", () => ({ + InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({})), +})); + +jest.mock("@/lib/hooks/useToast", () => ({ + useToast: () => ({ + toast: jest.fn(), + }), +})); + +const mockAuth = auth as jest.MockedFunction; + +describe("OAuthCallback", () => { + beforeEach(() => { + jest.clearAllMocks(); + mockAuth.mockResolvedValue("AUTHORIZED"); + + sessionStorage.clear(); + sessionStorage.setItem( + SESSION_KEYS.SERVER_URL, + "http://localhost:8080/jenkins/mcp-server/mcp", + ); + sessionStorage.setItem( + getServerSpecificKey( + SESSION_KEYS.RESOURCE_METADATA_URL, + "http://localhost:8080/jenkins/mcp-server/mcp", + ), + "http://localhost:8080/jenkins/.well-known/oauth-protected-resource/mcp-server/mcp", + ); + + window.history.pushState({}, "", "/oauth/callback?code=test-code"); + jest.spyOn(window.history, "replaceState").mockImplementation(() => {}); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it("passes persisted resource metadata URL when exchanging authorization code", async () => { + render(); + + await waitFor(() => { + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: "http://localhost:8080/jenkins/mcp-server/mcp", + authorizationCode: "test-code", + resourceMetadataUrl: new URL( + "http://localhost:8080/jenkins/.well-known/oauth-protected-resource/mcp-server/mcp", + ), + }), + ); + }); + }); + + it("continues without resource metadata URL when none was persisted", async () => { + sessionStorage.removeItem( + getServerSpecificKey( + SESSION_KEYS.RESOURCE_METADATA_URL, + "http://localhost:8080/jenkins/mcp-server/mcp", + ), + ); + + render(); + + await waitFor(() => { + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: "http://localhost:8080/jenkins/mcp-server/mcp", + authorizationCode: "test-code", + resourceMetadataUrl: undefined, + }), + ); + }); + }); +}); diff --git a/client/src/lib/__tests__/auth.test.ts b/client/src/lib/__tests__/auth.test.ts index 03c503d81..4d526aa65 100644 --- a/client/src/lib/__tests__/auth.test.ts +++ b/client/src/lib/__tests__/auth.test.ts @@ -1,5 +1,6 @@ -import { discoverScopes } from "../auth"; +import { discoverScopes, InspectorOAuthClientProvider } from "../auth"; import { discoverAuthorizationServerMetadata } from "@modelcontextprotocol/sdk/client/auth.js"; +import { getServerSpecificKey, SESSION_KEYS } from "../constants"; jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ discoverAuthorizationServerMetadata: jest.fn(), @@ -156,3 +157,25 @@ describe("discoverScopes", () => { }, ); }); + +describe("InspectorOAuthClientProvider", () => { + beforeEach(() => { + sessionStorage.clear(); + }); + + it("clears the server-specific resource metadata URL", () => { + const serverUrl = "https://example.com/mcp"; + const resourceMetadataKey = getServerSpecificKey( + SESSION_KEYS.RESOURCE_METADATA_URL, + serverUrl, + ); + sessionStorage.setItem( + resourceMetadataKey, + "https://example.com/.well-known/oauth-protected-resource/mcp", + ); + + new InspectorOAuthClientProvider(serverUrl).clear(); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBeNull(); + }); +}); diff --git a/client/src/lib/__tests__/oauth-resource-metadata.test.ts b/client/src/lib/__tests__/oauth-resource-metadata.test.ts new file mode 100644 index 000000000..0b7b69fc2 --- /dev/null +++ b/client/src/lib/__tests__/oauth-resource-metadata.test.ts @@ -0,0 +1,212 @@ +import { + clearResourceMetadataUrlFromSessionStorage, + discoverResourceMetadataUrlFromServer, + extractResourceMetadataUrlFromAuthError, + extractResourceMetadataUrlFromWWWAuthenticate, + getResourceMetadataUrlFromSessionStorage, + saveResourceMetadataUrlToSessionStorage, +} from "../oauth-resource-metadata"; + +// The SDK auth module imports PKCE generation eagerly, but these tests only +// exercise its WWW-Authenticate parser. +jest.mock("pkce-challenge", () => jest.fn(), { virtual: true }); + +describe("oauth-resource-metadata", () => { + beforeEach(() => { + sessionStorage.clear(); + }); + + it("extracts resource_metadata from WWW-Authenticate", () => { + expect( + extractResourceMetadataUrlFromWWWAuthenticate( + 'Bearer realm="mcp", resource_metadata="https://example.com/.well-known/oauth-protected-resource/mcp"', + ), + ).toEqual( + new URL("https://example.com/.well-known/oauth-protected-resource/mcp"), + ); + }); + + it("extracts resource_metadata alongside other Bearer challenge parameters", () => { + expect( + extractResourceMetadataUrlFromWWWAuthenticate( + 'Bearer realm="mcp", resource_metadata="https://example.com/.well-known/oauth-protected-resource/mcp", scope="read write"', + ), + ).toEqual( + new URL("https://example.com/.well-known/oauth-protected-resource/mcp"), + ); + }); + + it.each([ + ["a missing resource_metadata parameter", 'Bearer realm="mcp"'], + [ + "an invalid resource_metadata URL", + 'Bearer resource_metadata="not a URL"', + ], + [ + "a non-Bearer challenge", + 'Basic resource_metadata="https://example.com/.well-known/oauth-protected-resource/mcp"', + ], + ])("ignores %s", (_description, wwwAuthenticate) => { + expect( + extractResourceMetadataUrlFromWWWAuthenticate(wwwAuthenticate), + ).toBeUndefined(); + }); + + it("extracts resource_metadata from a proxy upstream401 snapshot", () => { + expect( + extractResourceMetadataUrlFromAuthError({ + data: { + upstream401: { + wwwAuthenticate: + 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource/mcp"', + }, + }, + }), + ).toEqual( + new URL("https://example.com/.well-known/oauth-protected-resource/mcp"), + ); + }); + + it.each([ + ["missing data", {}], + ["null data", { data: null }], + ["array data", { data: [] }], + ["missing upstream401", { data: {} }], + ["null upstream401", { data: { upstream401: null } }], + [ + "non-string WWW-Authenticate", + { data: { upstream401: { wwwAuthenticate: 123 } } }, + ], + [ + "invalid resource_metadata URL", + { + data: { + upstream401: { + wwwAuthenticate: 'Bearer resource_metadata="not a URL"', + }, + }, + }, + ], + ])("ignores proxy upstream401 snapshot with %s", (_description, error) => { + expect(extractResourceMetadataUrlFromAuthError(error)).toBeUndefined(); + }); + + it("persists and clears resource metadata URL by server", () => { + const serverUrl = "https://example.com/tenant-a/mcp"; + const resourceMetadataUrl = new URL( + "https://example.com/tenant-a/.well-known/oauth-protected-resource/mcp", + ); + + saveResourceMetadataUrlToSessionStorage(serverUrl, resourceMetadataUrl); + + expect(getResourceMetadataUrlFromSessionStorage(serverUrl)).toEqual( + resourceMetadataUrl, + ); + expect( + getResourceMetadataUrlFromSessionStorage( + "https://example.com/tenant-b/mcp", + ), + ).toBeUndefined(); + + clearResourceMetadataUrlFromSessionStorage(serverUrl); + + expect(getResourceMetadataUrlFromSessionStorage(serverUrl)).toBeUndefined(); + }); + + it("ignores invalid stored URLs", () => { + sessionStorage.setItem( + "[https://example.com/mcp] mcp_resource_metadata_url", + "not a URL", + ); + + expect( + getResourceMetadataUrlFromSessionStorage("https://example.com/mcp"), + ).toBeUndefined(); + }); + + it("discovers resource metadata URL from a 401 challenge", async () => { + const resourceMetadataUrl = new URL( + "https://example.com/.well-known/oauth-protected-resource/mcp", + ); + const fetchFn = jest.fn().mockResolvedValue( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ) as jest.MockedFunction; + + await expect( + discoverResourceMetadataUrlFromServer("https://example.com/mcp", fetchFn), + ).resolves.toEqual(resourceMetadataUrl); + expect(fetchFn).toHaveBeenCalledWith("https://example.com/mcp", { + headers: { Accept: "application/json, text/event-stream" }, + }); + }); + + it("discovers resource metadata URL from a 403 challenge", async () => { + const resourceMetadataUrl = new URL( + "https://example.com/.well-known/oauth-protected-resource/mcp", + ); + const fetchFn = jest.fn().mockResolvedValue( + new Response("{}", { + status: 403, + headers: { + "WWW-Authenticate": `Bearer error="insufficient_scope", resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ) as jest.MockedFunction; + + await expect( + discoverResourceMetadataUrlFromServer("https://example.com/mcp", fetchFn), + ).resolves.toEqual(resourceMetadataUrl); + }); + + it("ignores resource metadata URL on non-401/403 responses", async () => { + const fetchFn = jest.fn().mockResolvedValue( + new Response("{}", { + status: 200, + headers: { + "WWW-Authenticate": + 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource/mcp"', + }, + }), + ) as jest.MockedFunction; + + await expect( + discoverResourceMetadataUrlFromServer("https://example.com/mcp", fetchFn), + ).resolves.toBeUndefined(); + }); + + it("ignores network failures during resource metadata discovery", async () => { + const fetchFn = jest + .fn() + .mockRejectedValue(new Error("network failed")) as jest.MockedFunction< + typeof fetch + >; + + await expect( + discoverResourceMetadataUrlFromServer("https://example.com/mcp", fetchFn), + ).resolves.toBeUndefined(); + }); + + it("keeps discovered resource metadata URL when response body cancellation fails", async () => { + const resourceMetadataUrl = new URL( + "https://example.com/.well-known/oauth-protected-resource/mcp", + ); + const fetchFn = jest.fn().mockResolvedValue({ + status: 401, + headers: new Headers({ + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }), + body: { + cancel: jest.fn().mockRejectedValue(new Error("cancel failed")), + }, + } as unknown as Response); + + await expect( + discoverResourceMetadataUrlFromServer("https://example.com/mcp", fetchFn), + ).resolves.toEqual(resourceMetadataUrl); + }); +}); diff --git a/client/src/lib/__tests__/oauth-state-machine.test.ts b/client/src/lib/__tests__/oauth-state-machine.test.ts new file mode 100644 index 000000000..1c2e479b4 --- /dev/null +++ b/client/src/lib/__tests__/oauth-state-machine.test.ts @@ -0,0 +1,131 @@ +import { OAuthStateMachine } from "../oauth-state-machine"; +import { EMPTY_DEBUGGER_STATE } from "../auth-types"; +import { getServerSpecificKey, SESSION_KEYS } from "../constants"; +import { + discoverAuthorizationServerMetadata, + discoverOAuthProtectedResourceMetadata, + extractWWWAuthenticateParams, + selectResourceURL, +} from "@modelcontextprotocol/sdk/client/auth.js"; + +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + discoverAuthorizationServerMetadata: jest.fn(), + registerClient: jest.fn(), + startAuthorization: jest.fn(), + exchangeAuthorization: jest.fn(), + discoverOAuthProtectedResourceMetadata: jest.fn(), + extractWWWAuthenticateParams: jest.fn(), + selectResourceURL: jest.fn(), +})); + +const mockDiscoverAuthorizationServerMetadata = + discoverAuthorizationServerMetadata as jest.MockedFunction< + typeof discoverAuthorizationServerMetadata + >; +const mockDiscoverOAuthProtectedResourceMetadata = + discoverOAuthProtectedResourceMetadata as jest.MockedFunction< + typeof discoverOAuthProtectedResourceMetadata + >; +const mockSelectResourceURL = selectResourceURL as jest.MockedFunction< + typeof selectResourceURL +>; +const mockExtractWWWAuthenticateParams = + extractWWWAuthenticateParams as jest.MockedFunction< + typeof extractWWWAuthenticateParams + >; + +describe("OAuthStateMachine", () => { + beforeEach(() => { + jest.clearAllMocks(); + sessionStorage.clear(); + + mockDiscoverAuthorizationServerMetadata.mockResolvedValue({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + grant_types_supported: ["authorization_code"], + }); + mockSelectResourceURL.mockResolvedValue( + new URL("http://localhost:8080/jenkins/mcp-server/mcp"), + ); + mockExtractWWWAuthenticateParams.mockReturnValue({}); + }); + + it("uses resource_metadata from the current server challenge", async () => { + const serverUrl = "http://localhost:8080/jenkins/mcp-server/mcp"; + const resourceMetadataUrl = new URL( + "http://localhost:8080/jenkins/.well-known/oauth-protected-resource/mcp-server/mcp", + ); + const resourceMetadata = { + resource: serverUrl, + authorization_servers: ["https://auth.example.com"], + }; + const fetchFn = jest.fn().mockResolvedValue( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ) as jest.MockedFunction; + + mockDiscoverOAuthProtectedResourceMetadata.mockResolvedValue( + resourceMetadata, + ); + mockExtractWWWAuthenticateParams.mockReturnValue({ + resourceMetadataUrl, + }); + + await new OAuthStateMachine(serverUrl, jest.fn(), fetchFn).executeStep({ + ...EMPTY_DEBUGGER_STATE, + oauthStep: "metadata_discovery", + }); + + expect(fetchFn).toHaveBeenCalledWith(serverUrl, expect.any(Object)); + expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( + serverUrl, + { resourceMetadataUrl }, + fetchFn, + ); + expect( + sessionStorage.getItem( + getServerSpecificKey(SESSION_KEYS.RESOURCE_METADATA_URL, serverUrl), + ), + ).toBe(resourceMetadataUrl.href); + }); + + it("does not reuse stored resource_metadata without a current challenge", async () => { + const serverUrl = "http://localhost:8080/current/mcp"; + const storageKey = getServerSpecificKey( + SESSION_KEYS.RESOURCE_METADATA_URL, + serverUrl, + ); + const fetchFn = jest.fn().mockResolvedValue( + new Response("{}", { + status: 401, + }), + ) as jest.MockedFunction; + + sessionStorage.setItem( + storageKey, + "http://localhost:8080/previous/.well-known/oauth-protected-resource", + ); + mockDiscoverOAuthProtectedResourceMetadata.mockResolvedValue({ + resource: serverUrl, + authorization_servers: ["https://auth.example.com"], + }); + + await new OAuthStateMachine(serverUrl, jest.fn(), fetchFn).executeStep({ + ...EMPTY_DEBUGGER_STATE, + oauthStep: "metadata_discovery", + }); + + expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( + serverUrl, + {}, + fetchFn, + ); + expect(sessionStorage.getItem(storageKey)).toBeNull(); + }); +}); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index f0fc2fc4b..1fdbb3672 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -12,6 +12,7 @@ import { discoverAuthorizationServerMetadata } from "@modelcontextprotocol/sdk/c import { SESSION_KEYS, getServerSpecificKey } from "./constants"; import { generateOAuthState } from "@/utils/oauthUtils"; import { validateRedirectUrl } from "@/utils/urlValidation"; +import { clearResourceMetadataUrlFromSessionStorage } from "./oauth-resource-metadata"; /** * Discovers OAuth scopes from server metadata, with preference for resource metadata scopes @@ -262,6 +263,7 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { sessionStorage.removeItem( getServerSpecificKey(SESSION_KEYS.CODE_VERIFIER, this.serverUrl), ); + clearResourceMetadataUrlFromSessionStorage(this.serverUrl); } } diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index d986d3802..9df0137d3 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -17,6 +17,7 @@ export const SESSION_KEYS = { PREREGISTERED_CLIENT_INFORMATION: "mcp_preregistered_client_information", SERVER_METADATA: "mcp_server_metadata", AUTH_DEBUGGER_STATE: "mcp_auth_debugger_state", + RESOURCE_METADATA_URL: "mcp_resource_metadata_url", SCOPE: "mcp_scope", } as const; diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index 875c9e387..66db245d9 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -15,6 +15,8 @@ import { DEFAULT_INSPECTOR_CONFIG, CLIENT_IDENTITY, MCP_PROXY_TRANSPORT_ERROR_CODE, + SESSION_KEYS, + getServerSpecificKey, } from "../../constants"; import { SSEClientTransportOptions, @@ -24,7 +26,10 @@ import { ElicitResult, ElicitRequest, } from "@modelcontextprotocol/sdk/types.js"; -import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; +import { + auth, + extractWWWAuthenticateParams, +} from "@modelcontextprotocol/sdk/client/auth.js"; import { discoverScopes } from "../../auth"; import { CustomHeaders } from "../../types/customHeaders"; @@ -66,7 +71,9 @@ const mockSSETransport: { const mockStreamableHTTPTransport: { start: jest.Mock; url: URL | undefined; - options: SSEClientTransportOptions | undefined; + options: + | import("@modelcontextprotocol/sdk/client/streamableHttp.js").StreamableHTTPClientTransportOptions + | undefined; } = { start: jest.fn(), url: undefined, @@ -129,6 +136,8 @@ jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => { return { UnauthorizedError, auth: jest.fn().mockResolvedValue("AUTHORIZED"), + discoverOAuthProtectedResourceMetadata: jest.fn(), + extractWWWAuthenticateParams: jest.fn(), }; }); @@ -154,6 +163,10 @@ jest.mock("../../auth", () => ({ })); const mockAuth = auth as jest.MockedFunction; +const mockExtractWWWAuthenticateParams = + extractWWWAuthenticateParams as jest.MockedFunction< + typeof extractWWWAuthenticateParams + >; const mockDiscoverScopes = discoverScopes as jest.MockedFunction< typeof discoverScopes >; @@ -1216,6 +1229,66 @@ describe("useConnection", () => { mockStreamableHTTPTransport.options?.requestInit?.headers, ).toHaveProperty("X-MCP-Proxy-Auth", "Bearer test-proxy-token"); }); + + test("preserves streamable-http per-request headers when adding proxy auth", async () => { + const fetchMock = global.fetch as jest.Mock; + fetchMock.mockClear(); + + const propsWithStreamableHttp = { + ...defaultProps, + transportType: "streamable-http" as const, + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_AUTH_TOKEN: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_AUTH_TOKEN, + value: "test-proxy-token", + }, + }, + customHeaders: [ + { + name: "X-Tenant-ID", + value: "acme", + enabled: true, + }, + ], + }; + + const { result } = renderHook(() => + useConnection(propsWithStreamableHttp), + ); + + await act(async () => { + await result.current.connect(); + }); + + await mockStreamableHTTPTransport.options?.fetch?.( + "http://localhost:6277/mcp", + { + method: "POST", + headers: new Headers({ + accept: "application/json, text/event-stream", + "content-type": "application/json", + "mcp-session-id": "session-123", + }), + body: "{}", + }, + ); + + const proxiedCall = fetchMock.mock.calls.find( + (call) => call[0] === "http://localhost:6277/mcp", + ); + expect(proxiedCall).toBeDefined(); + const forwardedHeaders = new Headers(proxiedCall[1].headers); + expect(forwardedHeaders.get("accept")).toBe( + "application/json, text/event-stream", + ); + expect(forwardedHeaders.get("content-type")).toBe("application/json"); + expect(forwardedHeaders.get("mcp-session-id")).toBe("session-123"); + expect(forwardedHeaders.get("X-MCP-Proxy-Auth")).toBe( + "Bearer test-proxy-token", + ); + expect(forwardedHeaders.get("X-Tenant-ID")).toBe("acme"); + }); }); describe("Custom Headers", () => { @@ -1387,6 +1460,338 @@ describe("useConnection", () => { }); }); + describe("OAuth resource metadata persistence", () => { + const serverUrl = "http://localhost:8080/jenkins/mcp-server/mcp"; + const resourceMetadataUrl = new URL( + "http://localhost:8080/jenkins/.well-known/oauth-protected-resource/mcp-server/mcp", + ); + const resourceMetadataKey = getServerSpecificKey( + SESSION_KEYS.RESOURCE_METADATA_URL, + serverUrl, + ); + const healthyProxyResponse = () => ({ + json: () => Promise.resolve({ status: "ok" }), + headers: { + get: jest.fn().mockReturnValue(null), + }, + }); + + beforeEach(() => { + jest.clearAllMocks(); + sessionStorage.clear(); + (global.fetch as jest.Mock) + .mockReset() + .mockResolvedValue(healthyProxyResponse()); + mockClient.connect.mockResolvedValue(undefined); + mockAuth.mockResolvedValue("AUTHORIZED"); + mockDiscoverScopes.mockResolvedValue(undefined); + mockExtractWWWAuthenticateParams.mockReturnValue({}); + mockSSETransport.url = undefined; + mockSSETransport.options = undefined; + mockStreamableHTTPTransport.url = undefined; + mockStreamableHTTPTransport.options = undefined; + }); + + it("persists resource_metadata from direct transport responses", async () => { + (global.fetch as jest.Mock).mockResolvedValueOnce( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + connectionType: "direct", + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + await mockSSETransport.options?.fetch?.(serverUrl); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBe( + resourceMetadataUrl.href, + ); + }); + + it("persists resource_metadata from direct streamable HTTP responses", async () => { + (global.fetch as jest.Mock).mockResolvedValueOnce( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + connectionType: "direct", + transportType: "streamable-http", + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + await mockStreamableHTTPTransport.options?.fetch?.(serverUrl); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBe( + resourceMetadataUrl.href, + ); + }); + + it("persists resource_metadata from proxy transport responses", async () => { + (global.fetch as jest.Mock) + .mockResolvedValueOnce(healthyProxyResponse()) + .mockResolvedValueOnce( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + await mockSSETransport.options?.eventSourceInit?.fetch?.( + "http://localhost:6277/sse", + ); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBe( + resourceMetadataUrl.href, + ); + }); + + it("does not delegate proxy SSE OAuth recovery to the SDK transport URL", async () => { + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + + expect(mockSSETransport.options?.authProvider).toBeUndefined(); + }); + + it("persists resource_metadata from proxy streamable HTTP responses", async () => { + (global.fetch as jest.Mock) + .mockResolvedValueOnce(healthyProxyResponse()) + .mockResolvedValueOnce( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + transportType: "streamable-http", + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + await mockStreamableHTTPTransport.options?.fetch?.( + "http://localhost:6277/mcp", + ); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBe( + resourceMetadataUrl.href, + ); + }); + + it("does not delegate proxy streamable HTTP OAuth recovery to the SDK transport URL", async () => { + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + transportType: "streamable-http", + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + + expect(mockStreamableHTTPTransport.options?.authProvider).toBeUndefined(); + }); + + it("ignores resource_metadata from successful transport responses", async () => { + (global.fetch as jest.Mock).mockResolvedValueOnce( + new Response("{}", { + status: 200, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + connectionType: "direct", + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + await mockSSETransport.options?.fetch?.(serverUrl); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBeNull(); + }); + + it("passes resource_metadata from proxy upstream401 to auth", async () => { + mockClient.connect.mockRejectedValueOnce( + new McpError(MCP_PROXY_TRANSPORT_ERROR_CODE, "proxy transport", { + upstream401: { + wwwAuthenticate: `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + body: "{}", + contentType: "application/json", + }, + }), + ); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ resourceMetadataUrl }), + ); + }); + + it("passes resource_metadata captured from a proxy transport response to auth recovery", async () => { + (global.fetch as jest.Mock).mockImplementation((url) => { + if (String(url).includes("/health")) { + return Promise.resolve(healthyProxyResponse()); + } + return Promise.resolve( + new Response("{}", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer resource_metadata="${resourceMetadataUrl.href}"`, + }, + }), + ); + }); + mockExtractWWWAuthenticateParams.mockReturnValueOnce({ + resourceMetadataUrl, + }); + mockClient.connect.mockImplementationOnce(async () => { + await mockSSETransport.options?.eventSourceInit?.fetch?.( + "http://localhost:6277/sse", + ); + throw new SseError( + 401, + "Unauthorized", + new ErrorEvent("error", { message: "Unauthorized" }), + ); + }); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ resourceMetadataUrl }), + ); + }); + + it("clears stale resource_metadata before a fresh connect attempt", async () => { + sessionStorage.setItem( + resourceMetadataKey, + "http://localhost:8080/stale/.well-known/oauth-protected-resource", + ); + mockClient.connect.mockRejectedValueOnce( + new McpError(MCP_PROXY_TRANSPORT_ERROR_CODE, "proxy transport", { + upstream401: { body: "{}", contentType: "application/json" }, + }), + ); + + const { result } = renderHook(() => + useConnection({ + ...defaultProps, + sseUrl: serverUrl, + }), + ); + + await act(async () => { + await result.current.connect(); + }); + + expect(sessionStorage.getItem(resourceMetadataKey)).toBeNull(); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.not.objectContaining({ + resourceMetadataUrl: expect.anything(), + }), + ); + }); + }); + describe("Connection URL Verification", () => { beforeEach(() => { jest.clearAllMocks(); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 9694b891f..2d5b7af94 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -77,6 +77,13 @@ import { InspectorConfig } from "../configurationTypes"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { CustomHeaders } from "../types/customHeaders"; import { resolveRefsInMessage } from "@/utils/schemaUtils"; +import { + clearResourceMetadataUrlFromSessionStorage, + extractResourceMetadataUrlFromAuthError, + extractResourceMetadataUrlFromWWWAuthenticate, + getResourceMetadataUrlFromSessionStorage, + saveResourceMetadataUrlToSessionStorage, +} from "../oauth-resource-metadata"; interface UseConnectionOptions { transportType: "stdio" | "sse" | "streamable-http"; @@ -392,6 +399,12 @@ export function useConnection({ let scope = oauthScope?.trim(); const fetchFn = connectionType === "proxy" ? createProxyFetch(config) : undefined; + const resourceMetadataUrl = + extractResourceMetadataUrlFromAuthError(error) ?? + getResourceMetadataUrlFromSessionStorage(sseUrl); + if (resourceMetadataUrl) { + saveResourceMetadataUrlToSessionStorage(sseUrl, resourceMetadataUrl); + } if (!scope) { // Only discover resource metadata when we need to discover scopes @@ -399,7 +412,7 @@ export function useConnection({ try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( new URL("/", sseUrl), - {}, + resourceMetadataUrl ? { resourceMetadataUrl } : {}, fetchFn, ); } catch { @@ -415,6 +428,7 @@ export function useConnection({ const result = await auth(serverAuthProvider, { serverUrl: sseUrl, scope, + ...(resourceMetadataUrl && { resourceMetadataUrl }), ...(fetchFn && { fetchFn }), }); return result === "AUTHORIZED"; @@ -436,15 +450,28 @@ export function useConnection({ const captureResponseHeaders = (response: Response): void => { const sessionId = response.headers.get("mcp-session-id"); const protocolVersion = response.headers.get("mcp-protocol-version"); + const resourceMetadataUrl = + response.status === 401 || response.status === 403 + ? extractResourceMetadataUrlFromWWWAuthenticate( + response.headers.get("WWW-Authenticate") ?? undefined, + ) + : undefined; if (sessionId && sessionId !== mcpSessionId) { setMcpSessionId(sessionId); } if (protocolVersion && protocolVersion !== mcpProtocolVersion) { setMcpProtocolVersion(protocolVersion); } + if (resourceMetadataUrl) { + saveResourceMetadataUrlToSessionStorage(sseUrl, resourceMetadataUrl); + } }; const connect = async (_e?: unknown, retryCount: number = 0) => { + if (retryCount === 0) { + clearResourceMetadataUrlFromSessionStorage(sseUrl); + } + const clientCapabilities = { capabilities: { sampling: {}, @@ -663,16 +690,18 @@ export function useConnection({ ); } transportOptions = { - authProvider: serverAuthProvider, eventSourceInit: { - fetch: ( + fetch: async ( url: string | URL | globalThis.Request, init?: RequestInit, - ) => - fetch(url, { + ) => { + const response = await fetch(url, { ...init, headers: { ...headers, ...proxyHeaders }, - }), + }); + captureResponseHeaders(response); + return response; + }, }, requestInit: { headers: { ...headers, ...proxyHeaders }, @@ -694,16 +723,18 @@ export function useConnection({ ); } transportOptions = { - authProvider: serverAuthProvider, eventSourceInit: { - fetch: ( + fetch: async ( url: string | URL | globalThis.Request, init?: RequestInit, - ) => - fetch(url, { + ) => { + const response = await fetch(url, { ...init, headers: { ...headers, ...proxyHeaders }, - }), + }); + captureResponseHeaders(response); + return response; + }, }, requestInit: { headers: { ...headers, ...proxyHeaders }, @@ -716,16 +747,25 @@ export function useConnection({ mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/mcp`); mcpProxyServerUrl.searchParams.append("url", sseUrl); transportOptions = { - authProvider: serverAuthProvider, - eventSourceInit: { - fetch: ( - url: string | URL | globalThis.Request, - init?: RequestInit, - ) => - fetch(url, { - ...init, - headers: { ...headers, ...proxyHeaders }, - }), + fetch: async ( + url: string | URL | globalThis.Request, + init?: RequestInit, + ) => { + const requestHeaders = new Headers(init?.headers); + Object.entries(headers).forEach(([headerName, headerValue]) => { + requestHeaders.set(headerName, headerValue); + }); + Object.entries(proxyHeaders).forEach( + ([headerName, headerValue]) => { + requestHeaders.set(headerName, String(headerValue)); + }, + ); + const response = await fetch(url, { + ...init, + headers: requestHeaders, + }); + captureResponseHeaders(response); + return response; }, requestInit: { headers: { ...headers, ...proxyHeaders }, diff --git a/client/src/lib/oauth-resource-metadata.ts b/client/src/lib/oauth-resource-metadata.ts new file mode 100644 index 000000000..59b156113 --- /dev/null +++ b/client/src/lib/oauth-resource-metadata.ts @@ -0,0 +1,109 @@ +import { extractWWWAuthenticateParams } from "@modelcontextprotocol/sdk/client/auth.js"; +import { getServerSpecificKey, SESSION_KEYS } from "./constants"; + +function parseResourceMetadataUrl(value: string | null): URL | undefined { + if (!value) { + return undefined; + } + + try { + return new URL(value); + } catch { + return undefined; + } +} + +export function extractResourceMetadataUrlFromWWWAuthenticate( + wwwAuthenticate: string | undefined, +): URL | undefined { + if (!wwwAuthenticate) { + return undefined; + } + + const response = new Response(null, { + headers: { "WWW-Authenticate": wwwAuthenticate }, + }); + return extractWWWAuthenticateParams(response).resourceMetadataUrl; +} + +export function extractResourceMetadataUrlFromAuthError( + error: unknown, +): URL | undefined { + const data = (error as { data?: unknown })?.data; + if (typeof data !== "object" || data === null || Array.isArray(data)) { + return undefined; + } + + const upstream401 = (data as { upstream401?: unknown }).upstream401; + if ( + typeof upstream401 !== "object" || + upstream401 === null || + Array.isArray(upstream401) + ) { + return undefined; + } + + const wwwAuthenticate = (upstream401 as { wwwAuthenticate?: unknown }) + .wwwAuthenticate; + return typeof wwwAuthenticate === "string" + ? extractResourceMetadataUrlFromWWWAuthenticate(wwwAuthenticate) + : undefined; +} + +export async function discoverResourceMetadataUrlFromServer( + serverUrl: string, + fetchFn?: typeof fetch, +): Promise { + const effectiveFetch = fetchFn ?? globalThis.fetch; + if (!effectiveFetch) { + return undefined; + } + + try { + const response = await effectiveFetch(serverUrl, { + headers: { Accept: "application/json, text/event-stream" }, + }); + const resourceMetadataUrl = + response.status === 401 || response.status === 403 + ? extractResourceMetadataUrlFromWWWAuthenticate( + response.headers.get("WWW-Authenticate") ?? undefined, + ) + : undefined; + try { + await response.body?.cancel(); + } catch { + // Best-effort cleanup must not discard an already discovered URL. + } + return resourceMetadataUrl; + } catch { + return undefined; + } +} + +export function saveResourceMetadataUrlToSessionStorage( + serverUrl: string, + resourceMetadataUrl: URL, +): void { + sessionStorage.setItem( + getServerSpecificKey(SESSION_KEYS.RESOURCE_METADATA_URL, serverUrl), + resourceMetadataUrl.toString(), + ); +} + +export function getResourceMetadataUrlFromSessionStorage( + serverUrl: string, +): URL | undefined { + return parseResourceMetadataUrl( + sessionStorage.getItem( + getServerSpecificKey(SESSION_KEYS.RESOURCE_METADATA_URL, serverUrl), + ), + ); +} + +export function clearResourceMetadataUrlFromSessionStorage( + serverUrl: string, +): void { + sessionStorage.removeItem( + getServerSpecificKey(SESSION_KEYS.RESOURCE_METADATA_URL, serverUrl), + ); +} diff --git a/client/src/lib/oauth-state-machine.ts b/client/src/lib/oauth-state-machine.ts index 6628b9ad5..3916e1429 100644 --- a/client/src/lib/oauth-state-machine.ts +++ b/client/src/lib/oauth-state-machine.ts @@ -13,6 +13,11 @@ import { OAuthProtectedResourceMetadata, } from "@modelcontextprotocol/sdk/shared/auth.js"; import { generateOAuthState } from "@/utils/oauthUtils"; +import { + clearResourceMetadataUrlFromSessionStorage, + discoverResourceMetadataUrlFromServer, + saveResourceMetadataUrlToSessionStorage, +} from "./oauth-resource-metadata"; export interface StateMachineContext { state: AuthDebuggerState; @@ -36,10 +41,16 @@ export const oauthTransitions: Record = { let authServerUrl = new URL("/", context.serverUrl); let resourceMetadata: OAuthProtectedResourceMetadata | null = null; let resourceMetadataError: Error | null = null; + clearResourceMetadataUrlFromSessionStorage(context.serverUrl); + const resourceMetadataUrl = + (await discoverResourceMetadataUrlFromServer( + context.serverUrl, + context.fetchFn, + )) ?? null; try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( context.serverUrl, - {}, + resourceMetadataUrl ? { resourceMetadataUrl } : {}, context.fetchFn, ); if (resourceMetadata?.authorization_servers?.length) { @@ -69,6 +80,12 @@ export const oauthTransitions: Record = { } const parsedMetadata = await OAuthMetadataSchema.parseAsync(metadata); context.provider.saveServerMetadata(parsedMetadata); + if (resourceMetadataUrl) { + saveResourceMetadataUrlToSessionStorage( + context.serverUrl, + resourceMetadataUrl, + ); + } context.updateState({ resourceMetadata, resource, diff --git a/server/src/index.ts b/server/src/index.ts index bdfe49019..1f9fc3f71 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -425,6 +425,7 @@ const createCustomFetch = (headerHolder: ProxyHeaderHolder) => { const createTransport = async ( req: express.Request, + onHeaderHolder?: (headerHolder: ProxyHeaderHolder) => void, ): Promise<{ transport: Transport; headerHolder?: ProxyHeaderHolder; @@ -472,6 +473,7 @@ const createTransport = async ( headers: headerHolder.headers, }, }); + onHeaderHolder?.(headerHolder); await transport.start(); return { transport, headerHolder }; } else if (transportType === "streamable-http") { @@ -486,6 +488,7 @@ const createTransport = async ( fetch: createCustomFetch(headerHolder), }, ); + onHeaderHolder?.(headerHolder); await transport.start(); return { transport, headerHolder }; } else { @@ -565,8 +568,9 @@ app.post( let streamableHeaderHolder: ProxyHeaderHolder | undefined; try { const { transport: serverTransport, headerHolder } = - await createTransport(req); - streamableHeaderHolder = headerHolder; + await createTransport(req, (holder) => { + streamableHeaderHolder = holder; + }); const webAppTransport = new StreamableHTTPServerTransport({ sessionIdGenerator: randomUUID, @@ -765,8 +769,9 @@ app.get( "New SSE connection request. NOTE: The SSE transport is deprecated and has been replaced by StreamableHttp", ); const { transport: serverTransport, headerHolder } = - await createTransport(req); - sseHeaderHolder = headerHolder; + await createTransport(req, (holder) => { + sseHeaderHolder = holder; + }); const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; const prefix = proxyFullAddress || "";