From dac3961a856317b56984f2c4f10d7f79c30f622b Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 15 May 2026 12:21:01 +0200 Subject: [PATCH 1/4] oauth: CIMD inbound + DCR removal + HA replay (#115) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace Dynamic Client Registration with OAuth Client ID Metadata Documents as the only inbound MCP OAuth client mechanism. Aligns altinity-mcp with the MCP authorization spec direction (DCR retired 2025-11-25 in favor of CIMD) and lets the upstream IdP be the cross-replica replay oracle. CIMD resolver (cmd/altinity-mcp/cimd.go, new): - HTTPS-only URL validation: no userinfo/fragment/query, port 443 only, dot-segment + encoded-slash + IDN normalization rejection. - SSRF-safe fetcher: custom DialContext explicitly resolves DNS, blocks loopback / RFC1918 / link-local / multicast / IPv6 ULA / CGNAT / 0.0.0.0/8 / 192.0.0.0/24, pins dial to a validated IP, post-dial address re-check, no env proxy, no redirects, 3s timeout, 5 KiB body limit, JSON-only. - Schema validation: client_id must equal request URL, token_endpoint_auth_method must equal "none", redirect_uris bounded and deduped + https-only, refresh_token tolerated in grant_types (but unused), client_secret/private_key_jwt rejected. - In-memory LRU cache with Cache-Control: max-age (capped at 1h), no-store honored, negative-cache 30s, never overrides a positive entry. DCR removal (cmd/altinity-mcp/oauth_server.go): - handleOAuthRegister and its route deleted; /oauth/register now returns 404. - /.well-known/oauth-authorization-server drops registration_endpoint and refresh_token, advertises token_endpoint_auth_methods_supported=["none"] and client_id_metadata_document_supported: true. - handleOAuthTokenRefreshDispatch / handleOAuthTokenRefreshForward / mintForwardRefreshToken deleted; refresh_token grant returns unsupported_grant_type. CIMD clients re-authorize in v1. - parseStatelessRegisteredClient + authenticateClientSecret + hex import deleted as unused. HA replay model (#115 § HA replay): - /oauth/callback no longer POSTs to upstream /token. Instead it wraps the upstream auth code + upstream PKCE verifier + redirect_uri + code_challenge + scope/resource in a new 60s downstream JWE auth code and 302s back to the MCP client. - /oauth/token now does the upstream exchange. Upstream invalid_grant maps to downstream invalid_grant — this is the cross-replica replay verdict. No pod-local replay cache. Upstream IdP (Google or Auth0) is the sole used-codes oracle, eliminating the JWE auth-code replay window that the previous design accepted as "PKCE-bound only". - hkdfInfoOAuthAuthCode bumped to /v2 so any v1 codes in flight at the cutover decrypt as garbage (60s TTL means this is harmless). - oauthIssuedCode struct shed UpstreamBearerToken / UpstreamRefreshToken / UpstreamTokenType / Subject / Email / Name / HostedDomain / EmailVerified / AccessTokenExpiry; added UpstreamAuthCode + UpstreamPKCEVerifier. Tests: - cimd_test.go (new): URL validation, SSRF rejection table, schema validation, fetch safety (oversize body, non-JSON, redirect rejected), cache (max-age, no-store, TTL cap, negative cache, key exactness), ssrfSafeDial direct. - oauth_ha_replay_test.go (new): fake upstream that invalid_grants on the second redemption; asserts first /token returns access_token, second returns downstream invalid_grant, and no refresh_token is ever issued. - DCR-dependent tests deleted from oauth_server_test.go (TestOAuthHTTPDiscoveryAndRegistration, TestOAuthRegistrationNegative, TestOAuthForwardModeRefresh, TestOAuthForwardModeNoRefreshToken, TestOAuthAuthorizeOfflineAccessScope, TestOAuthAuthorizeNegative, TestOAuthCallbackNegative, TestOAuthTokenExchangeNegative, TestOAuthMetadataAdvertisesRefreshToken, TestAuthenticateClientSecret, TestParseStatelessRegisteredClient, TestOAuthForwardModeBrowserLogin*, TestOAuthForwardModeTokenResourceMismatch, TestOAuthMCPAuthInjectorForwardModeValidatesJWT, TestOAuthJWEHKDFRoundtripAndLegacyFallback, TestOAuthStateJWERoundTrip, TestRegisterOAuthHTTPRoutesAliases) along with their now-dead helpers. CIMD coverage is in cimd_test.go; HA replay in oauth_ha_replay_test.go. pkg/jwe_auth/jwe_auth.go: add upstream_auth_code to the JWE claim whitelist so the new downstream code claims pass validation. go.mod: add golang.org/x/net/idna (IDNA hostname normalization). v1 explicitly out of scope (see #115 § Non-goals): - Resource indicators / RFC 8707 audience binding - Scope binding and consent UI - Refresh tokens for CIMD clients - DPoP, private_key_jwt, optional display assets (logo_uri etc.), operator hostname/port allowlists. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/altinity-mcp/cimd.go | 519 ++++++ cmd/altinity-mcp/cimd_test.go | 354 ++++ cmd/altinity-mcp/oauth_ha_replay_test.go | 159 ++ cmd/altinity-mcp/oauth_server.go | 801 ++------ cmd/altinity-mcp/oauth_server_test.go | 2163 ++-------------------- go.mod | 6 +- go.sum | 8 + pkg/jwe_auth/jwe_auth.go | 1 + 8 files changed, 1387 insertions(+), 2624 deletions(-) create mode 100644 cmd/altinity-mcp/cimd.go create mode 100644 cmd/altinity-mcp/cimd_test.go create mode 100644 cmd/altinity-mcp/oauth_ha_replay_test.go diff --git a/cmd/altinity-mcp/cimd.go b/cmd/altinity-mcp/cimd.go new file mode 100644 index 0000000..873b475 --- /dev/null +++ b/cmd/altinity-mcp/cimd.go @@ -0,0 +1,519 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "golang.org/x/net/idna" +) + +// CIMD = OAuth Client ID Metadata Document +// (draft-ietf-oauth-client-id-metadata-document). Replaces DCR for inbound MCP +// OAuth clients per Altinity/altinity-mcp#115. +// +// The MCP client publishes a JSON metadata document at an HTTPS URL and uses +// that URL as its `client_id`. altinity-mcp fetches the document at /authorize +// (and /token), validates it, and uses the contents as the registered client +// for that OAuth flow. No registration endpoint, no per-(client × server) JWE. + +const ( + cimdMaxURLLength = 2048 + cimdMaxBodyBytes = 5 * 1024 + cimdMaxRedirectURIs = 20 + cimdMaxRedirectURILength = 2048 + cimdMaxClientNameLength = 128 + cimdFetchTimeout = 3 * time.Second + cimdCacheCap = 1024 + cimdDefaultCacheTTL = 5 * time.Minute + cimdMaxCacheTTL = 1 * time.Hour + cimdNegativeCacheTTL = 30 * time.Second +) + +var ( + errCIMDInvalidURL = errors.New("cimd: invalid client_id URL") + errCIMDSSRFBlocked = errors.New("cimd: target address blocked by SSRF policy") + errCIMDFetch = errors.New("cimd: metadata fetch failed") + errCIMDInvalidMetadata = errors.New("cimd: invalid metadata document") +) + +// isCIMDClientID reports whether s should be resolved as a CIMD client_id URL. +// Anything else is rejected as an unknown client. +func isCIMDClientID(s string) bool { + return strings.HasPrefix(s, "https://") +} + +// validateCIMDClientIDURL parses and validates a CIMD client_id URL against the +// strict rules from issue #115 § "CIMD client identifier URL validation". +// Returns the parsed URL on success; on failure the error wraps errCIMDInvalidURL. +func validateCIMDClientIDURL(raw string) (*url.URL, error) { + if raw == "" || len(raw) > cimdMaxURLLength { + return nil, fmt.Errorf("%w: length out of range", errCIMDInvalidURL) + } + u, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("%w: parse: %v", errCIMDInvalidURL, err) + } + if u.Scheme != "https" { + return nil, fmt.Errorf("%w: scheme must be https", errCIMDInvalidURL) + } + if u.User != nil { + return nil, fmt.Errorf("%w: userinfo not allowed", errCIMDInvalidURL) + } + if u.Fragment != "" { + return nil, fmt.Errorf("%w: fragment not allowed", errCIMDInvalidURL) + } + if u.RawQuery != "" { + return nil, fmt.Errorf("%w: query not allowed", errCIMDInvalidURL) + } + host := u.Hostname() + if host == "" { + return nil, fmt.Errorf("%w: hostname required", errCIMDInvalidURL) + } + if port := u.Port(); port != "" && port != "443" { + return nil, fmt.Errorf("%w: port %s not allowed (must be 443)", errCIMDInvalidURL, port) + } + asciiHost, err := idna.Lookup.ToASCII(host) + if err != nil { + return nil, fmt.Errorf("%w: hostname IDNA failure: %v", errCIMDInvalidURL, err) + } + if asciiHost != host { + return nil, fmt.Errorf("%w: hostname must be lowercase ASCII (got %q, normalized %q)", errCIMDInvalidURL, host, asciiHost) + } + if u.Path == "" || u.Path == "/" { + return nil, fmt.Errorf("%w: non-empty path required", errCIMDInvalidURL) + } + if err := validateCIMDPath(u.EscapedPath()); err != nil { + return nil, err + } + return u, nil +} + +// validateCIMDPath rejects dot-segments (raw or percent-encoded), encoded +// slashes, and encoded backslashes. Operates on the raw escaped path so we +// inspect what was actually requested rather than a normalized form. +func validateCIMDPath(rawPath string) error { + if !strings.HasPrefix(rawPath, "/") { + return fmt.Errorf("%w: path must start with /", errCIMDInvalidURL) + } + for _, raw := range strings.Split(rawPath[1:], "/") { + decoded, err := url.PathUnescape(raw) + if err != nil { + return fmt.Errorf("%w: invalid percent-encoding in path segment", errCIMDInvalidURL) + } + if raw == "." || raw == ".." || decoded == "." || decoded == ".." { + return fmt.Errorf("%w: dot-segment in path", errCIMDInvalidURL) + } + upper := strings.ToUpper(raw) + if strings.Contains(upper, "%2F") || strings.Contains(upper, "%5C") { + return fmt.Errorf("%w: encoded slash or backslash in path", errCIMDInvalidURL) + } + // Catch %2E variants explicitly so url.PathUnescape's decoded form is + // not the only signal. + if strings.Contains(upper, "%2E") { + noEnc := strings.ReplaceAll(strings.ReplaceAll(upper, "%2E", "."), "%2e", ".") + if noEnc == "." || noEnc == ".." { + return fmt.Errorf("%w: encoded dot-segment in path", errCIMDInvalidURL) + } + } + } + return nil +} + +// isBlockedIP reports whether ip falls in a special-use range we must refuse +// to dial during CIMD metadata fetch. The blocklist covers RFC1918, loopback, +// link-local, multicast, unspecified, IPv6 ULA/loopback/link-local, CGNAT, and +// 0.0.0.0/8. +func isBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsMulticast() || ip.IsUnspecified() || ip.IsPrivate() { + return true + } + if ip4 := ip.To4(); ip4 != nil { + if ip4[0] == 0 { // 0.0.0.0/8 + return true + } + if ip4[0] == 100 && ip4[1]&0xc0 == 64 { // CGNAT 100.64/10 + return true + } + if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 { // 192.0.0.0/24 + return true + } + } + return false +} + +// cimdResolver carries the dependencies needed to resolve, fetch, and cache a +// CIMD client metadata document. Tests build their own resolver pointed at a +// custom resolverFunc + http.Client so they can simulate SSRF, redirects, body +// limits, and cache TTL without a real network. +type cimdResolver struct { + httpClient *http.Client + resolveIP func(ctx context.Context, host string) ([]net.IP, error) + cache *cimdCache + now func() time.Time +} + +var ( + defaultCIMDResolverOnce sync.Once + defaultCIMDResolver *cimdResolver +) + +// cimdDefaultResolver returns the package-level singleton resolver used by +// production handlers. Initialised lazily on first use. +func cimdDefaultResolver() *cimdResolver { + defaultCIMDResolverOnce.Do(func() { + defaultCIMDResolver = newCIMDResolver(nil) + }) + return defaultCIMDResolver +} + +// newCIMDResolver constructs a resolver with an SSRF-safe http.Client. If +// resolveIP is nil it uses net.DefaultResolver. +func newCIMDResolver(resolveIP func(ctx context.Context, host string) ([]net.IP, error)) *cimdResolver { + if resolveIP == nil { + resolveIP = func(ctx context.Context, host string) ([]net.IP, error) { + return net.DefaultResolver.LookupIP(ctx, "ip", host) + } + } + r := &cimdResolver{ + resolveIP: resolveIP, + cache: newCIMDCache(cimdCacheCap), + now: time.Now, + } + tr := &http.Transport{ + Proxy: nil, + DialContext: r.ssrfSafeDial, + TLSHandshakeTimeout: cimdFetchTimeout, + ResponseHeaderTimeout: cimdFetchTimeout, + DisableCompression: true, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + } + r.httpClient = &http.Client{ + Transport: tr, + Timeout: cimdFetchTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + return r +} + +// ssrfSafeDial resolves the host explicitly, pins the dial to a validated IP, +// and re-checks the connected remote address before returning. +func (r *cimdResolver) ssrfSafeDial(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, err := r.resolveIP(ctx, host) + if err != nil { + return nil, fmt.Errorf("%w: dns: %v", errCIMDSSRFBlocked, err) + } + var pinned net.IP + for _, ip := range ips { + if !isBlockedIP(ip) { + pinned = ip + break + } + } + if pinned == nil { + return nil, fmt.Errorf("%w: no public address for host %s", errCIMDSSRFBlocked, host) + } + var d net.Dialer + d.Timeout = cimdFetchTimeout + conn, err := d.DialContext(ctx, network, net.JoinHostPort(pinned.String(), port)) + if err != nil { + return nil, err + } + if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + if isBlockedIP(tcpAddr.IP) { + _ = conn.Close() + return nil, fmt.Errorf("%w: post-dial address is blocked", errCIMDSSRFBlocked) + } + } + return conn, nil +} + +// resolveCIMDClient is the package-level entry point used by oauth_server.go. +// Defined as a var so tests can swap in an in-process resolver pointed at an +// httptest.Server without doing real DNS. +var resolveCIMDClient = func(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { + return cimdDefaultResolver().resolve(ctx, clientIDURL) +} + +func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { + if _, err := validateCIMDClientIDURL(clientIDURL); err != nil { + return nil, err + } + if e, ok := r.cache.get(clientIDURL, r.now()); ok { + if e.err != nil { + return nil, e.err + } + return e.client, nil + } + client, ttl, err := r.fetchAndValidate(ctx, clientIDURL) + now := r.now() + if err != nil { + r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: now.Add(cimdNegativeCacheTTL)}) + return nil, err + } + if ttl > 0 { + r.cache.put(clientIDURL, &cimdCacheEntry{client: client, expiresAt: now.Add(ttl)}) + } + return client, nil +} + +func (r *cimdResolver) fetchAndValidate(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, time.Duration, error) { + ctx, cancel := context.WithTimeout(ctx, cimdFetchTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, clientIDURL, nil) + if err != nil { + return nil, 0, fmt.Errorf("%w: build request: %v", errCIMDFetch, err) + } + req.Header.Set("Accept", "application/json") + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("%w: %v", errCIMDFetch, err) + } + defer resp.Body.Close() + if resp.StatusCode/100 == 3 { + return nil, 0, fmt.Errorf("%w: unexpected redirect %d", errCIMDFetch, resp.StatusCode) + } + if resp.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("%w: HTTP %d", errCIMDFetch, resp.StatusCode) + } + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/json") { + return nil, 0, fmt.Errorf("%w: content-type %q not application/json", errCIMDFetch, ct) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, int64(cimdMaxBodyBytes+1))) + if err != nil { + return nil, 0, fmt.Errorf("%w: body read: %v", errCIMDFetch, err) + } + if len(body) > cimdMaxBodyBytes { + return nil, 0, fmt.Errorf("%w: body exceeds %d bytes", errCIMDFetch, cimdMaxBodyBytes) + } + client, err := parseCIMDMetadata(clientIDURL, body) + if err != nil { + return nil, 0, err + } + ttl := cimdDefaultCacheTTL + if cc := resp.Header.Get("Cache-Control"); cc != "" { + lc := strings.ToLower(cc) + switch { + case strings.Contains(lc, "no-store"): + ttl = 0 + case strings.Contains(lc, "no-cache"): + ttl = 0 + default: + if ma := extractMaxAge(lc); ma > 0 { + if ma > cimdMaxCacheTTL { + ma = cimdMaxCacheTTL + } + ttl = ma + } + } + } + return client, ttl, nil +} + +func extractMaxAge(cc string) time.Duration { + for _, p := range strings.Split(cc, ",") { + p = strings.TrimSpace(p) + if !strings.HasPrefix(p, "max-age=") { + continue + } + v := strings.TrimPrefix(p, "max-age=") + n, err := strconv.Atoi(v) + if err != nil || n <= 0 { + return 0 + } + return time.Duration(n) * time.Second + } + return 0 +} + +// parseCIMDMetadata decodes the document and applies the schema rules from +// issue #115 §"Metadata schema validation". Treats the body as untrusted. +func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredClient, error) { + var doc struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientSecret string `json:"client_secret"` + ClientSecretExpiresAt *int64 `json:"client_secret_expires_at,omitempty"` + } + dec := json.NewDecoder(strings.NewReader(string(body))) + dec.UseNumber() + if err := dec.Decode(&doc); err != nil { + return nil, fmt.Errorf("%w: decode: %v", errCIMDInvalidMetadata, err) + } + if dec.More() { + return nil, fmt.Errorf("%w: trailing tokens after object", errCIMDInvalidMetadata) + } + if doc.ClientID != clientIDURL { + return nil, fmt.Errorf("%w: client_id mismatch", errCIMDInvalidMetadata) + } + if doc.ClientName == "" || len(doc.ClientName) > cimdMaxClientNameLength { + return nil, fmt.Errorf("%w: client_name length out of range", errCIMDInvalidMetadata) + } + if doc.ClientSecret != "" || doc.ClientSecretExpiresAt != nil { + return nil, fmt.Errorf("%w: client_secret not allowed for CIMD public client", errCIMDInvalidMetadata) + } + if doc.TokenEndpointAuthMethod != "none" { + return nil, fmt.Errorf("%w: token_endpoint_auth_method must be \"none\" (got %q)", errCIMDInvalidMetadata, doc.TokenEndpointAuthMethod) + } + if len(doc.RedirectURIs) == 0 || len(doc.RedirectURIs) > cimdMaxRedirectURIs { + return nil, fmt.Errorf("%w: redirect_uris count out of range", errCIMDInvalidMetadata) + } + seen := make(map[string]struct{}, len(doc.RedirectURIs)) + for _, ru := range doc.RedirectURIs { + if ru == "" || len(ru) > cimdMaxRedirectURILength { + return nil, fmt.Errorf("%w: redirect_uri length out of range", errCIMDInvalidMetadata) + } + if _, dup := seen[ru]; dup { + return nil, fmt.Errorf("%w: duplicate redirect_uri", errCIMDInvalidMetadata) + } + seen[ru] = struct{}{} + if err := validateCIMDRedirectURI(ru); err != nil { + return nil, err + } + } + if len(doc.GrantTypes) > 0 { + hasAuthCode := false + for _, gt := range doc.GrantTypes { + switch gt { + case "authorization_code": + hasAuthCode = true + case "refresh_token": + // Tolerated in metadata; not honored — v1 issues no refresh tokens. + default: + return nil, fmt.Errorf("%w: unsupported grant_type %q", errCIMDInvalidMetadata, gt) + } + } + if !hasAuthCode { + return nil, fmt.Errorf("%w: grant_types must include authorization_code", errCIMDInvalidMetadata) + } + } + if len(doc.ResponseTypes) > 0 { + hasCode := false + for _, rt := range doc.ResponseTypes { + if rt == "code" { + hasCode = true + } else { + return nil, fmt.Errorf("%w: unsupported response_type %q", errCIMDInvalidMetadata, rt) + } + } + if !hasCode { + return nil, fmt.Errorf("%w: response_types must include code", errCIMDInvalidMetadata) + } + } + return &statelessRegisteredClient{ + RedirectURIs: doc.RedirectURIs, + TokenEndpointAuthMethod: "none", + GrantType: "authorization_code", + }, nil +} + +// validateCIMDRedirectURI: v1 requires https for all redirect URIs. Loopback +// http is intentionally NOT allowed because we ship no consent UI and no +// trusted-loopback-host allowlist; both known CIMD clients (claude.ai, +// ChatGPT) publish https redirect URIs. +func validateCIMDRedirectURI(ru string) error { + u, err := url.Parse(ru) + if err != nil { + return fmt.Errorf("%w: redirect_uri parse: %v", errCIMDInvalidMetadata, err) + } + if u.Scheme != "https" { + return fmt.Errorf("%w: redirect_uri scheme must be https (got %q)", errCIMDInvalidMetadata, u.Scheme) + } + if u.Host == "" { + return fmt.Errorf("%w: redirect_uri host required", errCIMDInvalidMetadata) + } + return nil +} + +// --- cache --------------------------------------------------------------- + +type cimdCacheEntry struct { + client *statelessRegisteredClient + err error + expiresAt time.Time +} + +type cimdCache struct { + mu sync.Mutex + entries map[string]*cimdCacheEntry + order []string + cap int +} + +func newCIMDCache(cap int) *cimdCache { + if cap <= 0 { + cap = 1 + } + return &cimdCache{entries: make(map[string]*cimdCacheEntry, cap), cap: cap} +} + +func (c *cimdCache) get(key string, now time.Time) (*cimdCacheEntry, bool) { + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.entries[key] + if !ok { + return nil, false + } + if now.After(e.expiresAt) { + c.evictLocked(key) + return nil, false + } + return e, true +} + +// put inserts/updates a cache entry. Negative entries do NOT override an +// existing unexpired positive entry (per issue #115 cache requirements). +func (c *cimdCache) put(key string, e *cimdCacheEntry) { + c.mu.Lock() + defer c.mu.Unlock() + if e.err != nil { + if existing, ok := c.entries[key]; ok && existing.err == nil && existing.expiresAt.After(time.Now()) { + return + } + } + if _, exists := c.entries[key]; !exists { + if len(c.entries) >= c.cap { + oldest := c.order[0] + c.order = c.order[1:] + delete(c.entries, oldest) + } + c.order = append(c.order, key) + } + c.entries[key] = e +} + +func (c *cimdCache) evictLocked(key string) { + delete(c.entries, key) + for i, k := range c.order { + if k == key { + c.order = append(c.order[:i], c.order[i+1:]...) + return + } + } +} diff --git a/cmd/altinity-mcp/cimd_test.go b/cmd/altinity-mcp/cimd_test.go new file mode 100644 index 0000000..d378e04 --- /dev/null +++ b/cmd/altinity-mcp/cimd_test.go @@ -0,0 +1,354 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" +) + +// --- URL validation ----------------------------------------------------- + +func TestValidateCIMDClientIDURL_OK(t *testing.T) { + cases := []string{ + "https://claude.ai/oauth/mcp-oauth-client-metadata", + "https://chatgpt.com/.well-known/oauth-client-id", + "https://example.com:443/x.json", + "https://example.com/a/b/c.json", + } + for _, c := range cases { + if _, err := validateCIMDClientIDURL(c); err != nil { + t.Errorf("expected %q to validate, got %v", c, err) + } + } +} + +func TestValidateCIMDClientIDURL_Reject(t *testing.T) { + cases := map[string]string{ + "empty": "", + "http_scheme": "http://example.com/x.json", + "ftp_scheme": "ftp://example.com/x.json", + "no_host": "https:///x.json", + "no_path": "https://example.com", + "root_path": "https://example.com/", + "with_query": "https://example.com/x.json?a=1", + "with_fragment": "https://example.com/x.json#frag", + "with_userinfo": "https://user:pw@example.com/x.json", + "wrong_port": "https://example.com:8443/x.json", + "dot_segment": "https://example.com/./x.json", + "dotdot_segment": "https://example.com/a/../x.json", + "encoded_dot": "https://example.com/%2e/x.json", + "encoded_dot_upper": "https://example.com/%2E/x.json", + "encoded_dotdot": "https://example.com/%2e%2e/x.json", + "mixed_encoded": "https://example.com/.%2e/x.json", + "encoded_slash": "https://example.com/a%2fb/x.json", + "encoded_backslash": "https://example.com/a%5cb/x.json", + "uppercase_host": "https://Example.com/x.json", + } + for name, raw := range cases { + t.Run(name, func(t *testing.T) { + if _, err := validateCIMDClientIDURL(raw); err == nil { + t.Errorf("expected %q to fail validation", raw) + } else if !errors.Is(err, errCIMDInvalidURL) { + t.Errorf("expected errCIMDInvalidURL, got %v", err) + } + }) + } +} + +func TestValidateCIMDClientIDURL_OversizeRejected(t *testing.T) { + raw := "https://example.com/" + strings.Repeat("a", cimdMaxURLLength) + if _, err := validateCIMDClientIDURL(raw); err == nil { + t.Errorf("expected oversize URL to fail") + } +} + +// --- isBlockedIP -------------------------------------------------------- + +func TestIsBlockedIP(t *testing.T) { + blocked := []string{ + "127.0.0.1", "10.0.0.1", "192.168.1.1", "172.16.0.1", + "169.254.169.254", "100.64.0.1", "0.0.0.0", "224.0.0.1", + "::1", "fe80::1", "fc00::1", "192.0.0.1", + } + ok := []string{ + "8.8.8.8", "1.1.1.1", "93.184.216.34", "2606:4700:4700::1111", + } + for _, s := range blocked { + if !isBlockedIP(net.ParseIP(s)) { + t.Errorf("expected %s to be blocked", s) + } + } + for _, s := range ok { + if isBlockedIP(net.ParseIP(s)) { + t.Errorf("expected %s to be allowed", s) + } + } +} + +// --- schema validation -------------------------------------------------- + +func TestParseCIMDMetadata_OK(t *testing.T) { + const u = "https://claude.ai/oauth/mcp-oauth-client-metadata" + body := []byte(`{ + "client_id": "https://claude.ai/oauth/mcp-oauth-client-metadata", + "client_name": "Claude", + "client_uri": "https://claude.ai", + "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"], + "grant_types": ["authorization_code","refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" + }`) + c, err := parseCIMDMetadata(u, body) + if err != nil { + t.Fatalf("expected ok, got %v", err) + } + if c.TokenEndpointAuthMethod != "none" || len(c.RedirectURIs) != 1 { + t.Errorf("unexpected client: %#v", c) + } +} + +func TestParseCIMDMetadata_Reject(t *testing.T) { + const u = "https://x.example/y.json" + cases := map[string]string{ + "client_id_mismatch": `{"client_id":"https://other/x","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"}`, + "missing_auth_method": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"]}`, + "wrong_auth_method": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"client_secret_post"}`, + "client_secret_present": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none","client_secret":"s"}`, + "empty_redirect_uris": `{"client_id":"` + u + `","client_name":"X","redirect_uris":[],"token_endpoint_auth_method":"none"}`, + "duplicate_redirect_uris": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb","https://x/cb"],"token_endpoint_auth_method":"none"}`, + "http_redirect_uri": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["http://x/cb"],"token_endpoint_auth_method":"none"}`, + "unsupported_grant": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none","grant_types":["password"]}`, + "unsupported_response": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none","response_types":["token"]}`, + "empty_name": `{"client_id":"` + u + `","client_name":"","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"}`, + "oversize_name": `{"client_id":"` + u + `","client_name":"` + strings.Repeat("a", cimdMaxClientNameLength+1) + `","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"}`, + "trailing_tokens": `{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"} extra`, + } + for name, body := range cases { + t.Run(name, func(t *testing.T) { + if _, err := parseCIMDMetadata(u, []byte(body)); err == nil { + t.Errorf("expected rejection for %s", name) + } else if !errors.Is(err, errCIMDInvalidMetadata) { + t.Errorf("expected errCIMDInvalidMetadata, got %v", err) + } + }) + } +} + +func TestParseCIMDMetadata_GrantTypesMustIncludeAuthCode(t *testing.T) { + const u = "https://x.example/y.json" + body := []byte(`{"client_id":"` + u + `","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none","grant_types":["refresh_token"]}`) + if _, err := parseCIMDMetadata(u, body); err == nil { + t.Errorf("expected error: grant_types without authorization_code") + } +} + +// --- fetcher / cache (end-to-end with httptest) ------------------------- + +// testResolver returns a cimdResolver wired against a fake DNS that always +// returns 127.0.0.1 BUT bypasses the SSRF block check by lying — for unit +// tests we want to actually talk to httptest. We achieve this by setting the +// resolveIP to return 127.0.0.1 and overriding ssrfSafeDial via the +// httpClient.Transport.DialContext to ignore the SSRF blocklist for the +// loopback test. +func testResolver(t *testing.T, server *httptest.Server) *cimdResolver { + t.Helper() + su, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("server URL parse: %v", err) + } + host, port, err := net.SplitHostPort(su.Host) + if err != nil { + t.Fatalf("split host port: %v", err) + } + _ = host + r := newCIMDResolver(nil) + // Replace the Transport with one that always dials the httptest server + // instead of doing real DNS. This keeps the rest of the fetch / parse / + // cache logic exercised exactly as production. + tr := &http.Transport{ + Proxy: nil, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, net.JoinHostPort("127.0.0.1", port)) + }, + TLSClientConfig: server.Client().Transport.(*http.Transport).TLSClientConfig, + } + r.httpClient = &http.Client{ + Transport: tr, + Timeout: cimdFetchTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + return r +} + +// roundTrip URL — the URL the resolver "thinks" it is fetching. We point the +// transport at the real httptest server above. +func cimdTestURL(host, path string) string { + return "https://" + host + path +} + +func TestCIMDResolve_HappyPath_Cached(t *testing.T) { + hits := int32(0) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=60") + fmt.Fprintf(w, `{ + "client_id": %q, + "client_name": "Demo", + "redirect_uris": ["https://demo.example.com/cb"], + "token_endpoint_auth_method": "none" + }`, cimdTestURL("demo.example.com", "/x.json")) + })) + defer server.Close() + + r := testResolver(t, server) + u := cimdTestURL("demo.example.com", "/x.json") + + c1, err := r.resolve(context.Background(), u) + if err != nil { + t.Fatalf("first resolve: %v", err) + } + c2, err := r.resolve(context.Background(), u) + if err != nil { + t.Fatalf("second resolve: %v", err) + } + if c1 != c2 { + t.Errorf("expected cached pointer reuse") + } + if atomic.LoadInt32(&hits) != 1 { + t.Errorf("expected 1 upstream fetch, got %d", hits) + } +} + +func TestCIMDResolve_NoStoreSkipsCache(t *testing.T) { + hits := int32(0) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":["https://d.example.com/cb"],"token_endpoint_auth_method":"none"}`, cimdTestURL("d.example.com", "/x.json")) + })) + defer server.Close() + r := testResolver(t, server) + u := cimdTestURL("d.example.com", "/x.json") + for i := 0; i < 3; i++ { + if _, err := r.resolve(context.Background(), u); err != nil { + t.Fatalf("resolve %d: %v", i, err) + } + } + if atomic.LoadInt32(&hits) != 3 { + t.Errorf("expected 3 fetches (no-store), got %d", hits) + } +} + +func TestCIMDResolve_MaxAgeCappedAt1Hour(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "max-age=999999999") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":["https://d.example.com/cb"],"token_endpoint_auth_method":"none"}`, cimdTestURL("d.example.com", "/x.json")) + })) + defer server.Close() + r := testResolver(t, server) + now := time.Now() + r.now = func() time.Time { return now } + u := cimdTestURL("d.example.com", "/x.json") + if _, err := r.resolve(context.Background(), u); err != nil { + t.Fatalf("resolve: %v", err) + } + e, ok := r.cache.get(u, now) + if !ok { + t.Fatalf("expected cache entry") + } + if e.expiresAt.After(now.Add(cimdMaxCacheTTL + time.Second)) { + t.Errorf("expected TTL cap, got expiresAt=%v", e.expiresAt) + } +} + +func TestCIMDResolve_OversizeBodyRejected(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{")) + w.Write([]byte(strings.Repeat("a", cimdMaxBodyBytes+1))) + w.Write([]byte("}")) + })) + defer server.Close() + r := testResolver(t, server) + _, err := r.resolve(context.Background(), cimdTestURL("d.example.com", "/x.json")) + if err == nil || !errors.Is(err, errCIMDFetch) { + t.Errorf("expected errCIMDFetch, got %v", err) + } +} + +func TestCIMDResolve_NonJSONRejected(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte("not json")) + })) + defer server.Close() + r := testResolver(t, server) + _, err := r.resolve(context.Background(), cimdTestURL("d.example.com", "/x.json")) + if err == nil || !errors.Is(err, errCIMDFetch) { + t.Errorf("expected errCIMDFetch, got %v", err) + } +} + +func TestCIMDResolve_RedirectRejected(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "https://example.com/y.json", http.StatusFound) + })) + defer server.Close() + r := testResolver(t, server) + _, err := r.resolve(context.Background(), cimdTestURL("d.example.com", "/x.json")) + if err == nil || !errors.Is(err, errCIMDFetch) { + t.Errorf("expected errCIMDFetch, got %v", err) + } +} + +func TestCIMDResolve_NegativeCache(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + r := testResolver(t, server) + u := cimdTestURL("d.example.com", "/x.json") + if _, err := r.resolve(context.Background(), u); err == nil { + t.Fatal("expected error") + } + e, ok := r.cache.get(u, time.Now()) + if !ok || e.err == nil { + t.Errorf("expected negative cache entry") + } +} + +// --- SSRF dial directly -------------------------------------------------- + +func TestSSRFSafeDial_BlocksPrivateAddress(t *testing.T) { + r := newCIMDResolver(func(ctx context.Context, host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("10.1.2.3")}, nil + }) + _, err := r.ssrfSafeDial(context.Background(), "tcp", "evil.example:443") + if err == nil || !errors.Is(err, errCIMDSSRFBlocked) { + t.Errorf("expected SSRF block, got %v", err) + } +} + +func TestSSRFSafeDial_BlocksAllResolvedAddresses(t *testing.T) { + r := newCIMDResolver(func(ctx context.Context, host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("169.254.169.254")}, nil + }) + _, err := r.ssrfSafeDial(context.Background(), "tcp", "metadata.example:443") + if err == nil || !errors.Is(err, errCIMDSSRFBlocked) { + t.Errorf("expected SSRF block, got %v", err) + } +} diff --git a/cmd/altinity-mcp/oauth_ha_replay_test.go b/cmd/altinity-mcp/oauth_ha_replay_test.go new file mode 100644 index 0000000..7236268 --- /dev/null +++ b/cmd/altinity-mcp/oauth_ha_replay_test.go @@ -0,0 +1,159 @@ +package main + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/altinity/altinity-mcp/pkg/config" + altinitymcp "github.com/altinity/altinity-mcp/pkg/server" + "github.com/stretchr/testify/require" +) + +// TestHAReplay_UpstreamInvalidGrantOnReplay verifies the HA replay model from +// #115: redeeming the same downstream auth-code JWE twice results in the +// second /oauth/token call seeing upstream `invalid_grant` and returning a +// downstream `invalid_grant`. The upstream IdP is the cross-replica oracle. +func TestHAReplay_UpstreamInvalidGrantOnReplay(t *testing.T) { + const ( + upstreamCode = "upstream-auth-code-abc" + upstreamClient = "upstream-client-id" + upstreamSecret = "upstream-client-secret" + downstreamClient = "https://demo.example.com/cimd.json" + downstreamRedir = "https://demo.example.com/cb" + signingSecret = "test-ha-signing-secret-32-bytes!" + ) + + // Fake upstream IdP: /token redeems the auth code exactly once. + tokenCalls := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + n := atomic.AddInt32(&tokenCalls, 1) + _ = r.ParseForm() + if r.Form.Get("code") != upstreamCode { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_grant"}`) + return + } + if n > 1 { + // Second redemption: simulate Google/Auth0 invalid_grant. + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_grant","error_description":"code already used"}`) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "upstream-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid email", + }) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "sub": "user-123", + "email": "alice@example.com", + "email_verified": true, + }) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + // Stub CIMD resolver: skip real network and return a client allowing the + // downstream redirect URI. + origResolver := resolveCIMDClient + t.Cleanup(func() { resolveCIMDClient = origResolver }) + resolveCIMDClient = func(_ context.Context, raw string) (*statelessRegisteredClient, error) { + if raw != downstreamClient { + return nil, errCIMDInvalidURL + } + return &statelessRegisteredClient{ + RedirectURIs: []string{downstreamRedir}, + TokenEndpointAuthMethod: "none", + GrantType: "authorization_code", + }, nil + } + + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: upstream.URL, + JWKSURL: upstream.URL + "/jwks", + AuthURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + ClientID: upstreamClient, + ClientSecret: upstreamSecret, + Audience: upstreamClient, + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: signingSecret, + Scopes: []string{"openid", "email"}, + }, + }, + } + app := &application{ + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + } + + // Build a valid downstream auth code JWE by exercising encodeAuthCode. + verifier, err := newPKCEVerifier() + require.NoError(t, err) + challenge := pkceChallenge(verifier) + issued := oauthIssuedCode{ + ClientID: downstreamClient, + RedirectURI: downstreamRedir, + Scope: "openid email", + CodeChallenge: challenge, + CodeChallengeMethod: "S256", + UpstreamAuthCode: upstreamCode, + UpstreamPKCEVerifier: "upstream-verifier", + ExpiresAt: time.Now().Add(60 * time.Second), + } + jweAuthCode, err := app.encodeAuthCode(issued) + require.NoError(t, err) + + mkReq := func() *http.Request { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", downstreamClient) + form.Set("redirect_uri", downstreamRedir) + form.Set("code", jweAuthCode) + form.Set("code_verifier", verifier) + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + return req + } + + // First /token: succeeds, upstream redeems the code once. + rr1 := httptest.NewRecorder() + app.handleOAuthTokenAuthCode(rr1, mkReq()) + require.Equal(t, http.StatusOK, rr1.Code, "first /token body: %s", rr1.Body.String()) + var resp1 map[string]interface{} + require.NoError(t, json.Unmarshal(rr1.Body.Bytes(), &resp1)) + require.Equal(t, "upstream-access-token", resp1["access_token"]) + require.NotContains(t, resp1, "refresh_token", "v1 must not issue refresh tokens to CIMD clients") + + // Second /token: replay → upstream invalid_grant → downstream invalid_grant. + rr2 := httptest.NewRecorder() + app.handleOAuthTokenAuthCode(rr2, mkReq()) + require.Equal(t, http.StatusBadRequest, rr2.Code) + var resp2 map[string]interface{} + require.NoError(t, json.Unmarshal(rr2.Body.Bytes(), &resp2)) + require.Equal(t, "invalid_grant", resp2["error"]) + require.Equal(t, int32(2), atomic.LoadInt32(&tokenCalls), "upstream /token should be called once per /token attempt — no pod-local cache") +} + diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index c1c663b..3d1a162 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -4,9 +4,7 @@ import ( "context" "crypto/rand" "crypto/sha256" - "crypto/subtle" "encoding/base64" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -80,6 +78,14 @@ type oauthPendingAuth struct { ExpiresAt time.Time } +// oauthIssuedCode is the JWE-encoded downstream authorization code returned +// from /oauth/callback. Under the HA replay model (#115 § HA replay) the +// upstream IdP authorization code is NOT redeemed at /callback — it is +// wrapped here together with the upstream PKCE verifier and only exchanged +// upstream when the client redeems this downstream code at /oauth/token. That +// way the upstream IdP (Google / Auth0) is the sole cross-replica +// "used codes" oracle: replaying this JWE twice results in the second /token +// call seeing upstream `invalid_grant`. type oauthIssuedCode struct { ClientID string `json:"client_id"` RedirectURI string `json:"redirect_uri"` @@ -87,16 +93,9 @@ type oauthIssuedCode struct { CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` Resource string `json:"resource,omitempty"` - UpstreamBearerToken string `json:"upstream_bearer_token"` - UpstreamRefreshToken string `json:"upstream_refresh_token,omitempty"` - UpstreamTokenType string `json:"upstream_token_type"` - Subject string `json:"sub"` - Email string `json:"email"` - Name string `json:"name"` - HostedDomain string `json:"hd"` - EmailVerified bool `json:"email_verified"` + UpstreamAuthCode string `json:"upstream_auth_code"` + UpstreamPKCEVerifier string `json:"upstream_pkce_verifier"` ExpiresAt time.Time - AccessTokenExpiry time.Time } // OAuth pending-auth and issued-code state are encoded as stateless JWE tokens @@ -171,10 +170,12 @@ const oauthKidV1 = "v1" // Bumping the /vN suffix in any single label rotates that one key without // disturbing the others. const ( - hkdfInfoOAuthClientID = "altinity-mcp/oauth/client-id/v1" - hkdfInfoOAuthRefresh = "altinity-mcp/oauth/refresh-token/v1" hkdfInfoOAuthPendingAuth = "altinity-mcp/oauth/pending-auth/v1" - hkdfInfoOAuthAuthCode = "altinity-mcp/oauth/auth-code/v1" + // v2 bumps the auth-code derivation: under #115 the JWE now wraps the + // upstream auth code + PKCE verifier (not a bearer), so its semantics + // changed. Any v1 codes minted before the cutover decrypt as garbage + // here; that's intended, the auth-code TTL is 60s. + hkdfInfoOAuthAuthCode = "altinity-mcp/oauth/auth-code/v2" ) // encodeOAuthJWE emits a JWE-wrapped JSON document of `claims`, encrypted @@ -308,16 +309,9 @@ func (a *application) encodeAuthCode(c oauthIssuedCode) (string, error) { "code_challenge": c.CodeChallenge, "code_challenge_method": c.CodeChallengeMethod, "resource": c.Resource, - "upstream_bearer_token": c.UpstreamBearerToken, - "upstream_refresh_token": c.UpstreamRefreshToken, - "upstream_token_type": c.UpstreamTokenType, - "sub": c.Subject, - "email": c.Email, - "name": c.Name, - "hd": c.HostedDomain, - "email_verified": c.EmailVerified, + "upstream_auth_code": c.UpstreamAuthCode, + "upstream_pkce_verifier": c.UpstreamPKCEVerifier, "exp": c.ExpiresAt.Unix(), - "access_token_exp": c.AccessTokenExpiry.Unix(), } return encodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, claims) } @@ -339,18 +333,9 @@ func (a *application) decodeAuthCode(token string) (oauthIssuedCode, bool) { CodeChallenge: stringFromClaims(claims, "code_challenge"), CodeChallengeMethod: stringFromClaims(claims, "code_challenge_method"), Resource: stringFromClaims(claims, "resource"), - UpstreamBearerToken: stringFromClaims(claims, "upstream_bearer_token"), - UpstreamRefreshToken: stringFromClaims(claims, "upstream_refresh_token"), - UpstreamTokenType: stringFromClaims(claims, "upstream_token_type"), - Subject: stringFromClaims(claims, "sub"), - Email: stringFromClaims(claims, "email"), - Name: stringFromClaims(claims, "name"), - HostedDomain: stringFromClaims(claims, "hd"), + UpstreamAuthCode: stringFromClaims(claims, "upstream_auth_code"), + UpstreamPKCEVerifier: stringFromClaims(claims, "upstream_pkce_verifier"), ExpiresAt: unixFromClaims(claims, "exp"), - AccessTokenExpiry: unixFromClaims(claims, "access_token_exp"), - } - if v, ok := claims["email_verified"].(bool); ok { - c.EmailVerified = v } return c, true } @@ -834,63 +819,6 @@ func oidcScopesForAdvertisement(cfg config.OAuthConfig) []string { return out } -// authenticateClientSecret validates the inbound `client_secret` against the -// one stored in the registered client's metadata. RFC 6749 §2.3.1 allows the -// secret to be presented either via the form body (client_secret_post) or -// the Authorization: Basic header (client_secret_basic); we accept both. -// -// For backward compat with previously-registered public (PKCE-only) clients -// — those whose JWE-encoded client_id has no `client_secret` claim — we -// return nil even when the client supplied no secret. New registrations -// always carry a client_secret, so this fallback only applies to legacy -// client_ids issued before this change. -func authenticateClientSecret(client *statelessRegisteredClient, r *http.Request) error { - if client.ClientSecret == "" { - return nil - } - got := r.Form.Get("client_secret") - if got == "" { - if user, pass, ok := r.BasicAuth(); ok && user != "" { - got = pass - } - } - if got == "" { - return fmt.Errorf("client_secret is required") - } - if subtle.ConstantTimeCompare([]byte(got), []byte(client.ClientSecret)) != 1 { - return fmt.Errorf("client_secret mismatch") - } - return nil -} - -func parseStatelessRegisteredClient(claims map[string]interface{}) (*statelessRegisteredClient, error) { - client := &statelessRegisteredClient{ - RedirectURIs: decodeStringSlice(claims["redirect_uris"]), - } - if authMethod, ok := claims["token_endpoint_auth_method"].(string); ok { - client.TokenEndpointAuthMethod = authMethod - } - if grantType, ok := claims["grant_type"].(string); ok { - client.GrantType = grantType - } - if exp, ok := claims["exp"].(float64); ok { - client.ExpiresAt = int64(exp) - } - if cs, ok := claims["client_secret"].(string); ok { - client.ClientSecret = cs - } - if client.TokenEndpointAuthMethod == "" { - client.TokenEndpointAuthMethod = "none" - } - if len(client.RedirectURIs) == 0 { - return nil, fmt.Errorf("missing redirect URIs") - } - if client.GrantType == "" { - client.GrantType = "authorization_code" - } - return client, nil -} - func oauthClaimsFromUserInfo(raw map[string]interface{}) *altinitymcp.OAuthClaims { claims := &altinitymcp.OAuthClaims{Extra: make(map[string]interface{})} if sub, ok := raw["sub"].(string); ok { @@ -1055,15 +983,15 @@ func (a *application) handleOAuthAuthorizationServerMetadata(w http.ResponseWrit // validateOAuthClaims still normalises slashes defensively. issuer := strings.TrimRight(baseURL, "/") resp := map[string]interface{}{ - "issuer": issuer, - "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), - "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), - "registration_endpoint": joinURLPath(baseURL, a.oauthRegistrationPath()), - "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), - "response_types_supported": []string{"code"}, - "grant_types_supported": []string{"authorization_code", "refresh_token"}, - "token_endpoint_auth_methods_supported": []string{"client_secret_post", "client_secret_basic", "none"}, - "code_challenge_methods_supported": []string{"S256"}, + "issuer": issuer, + "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), + "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), + "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "code_challenge_methods_supported": []string{"S256"}, + "client_id_metadata_document_supported": true, } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(resp) @@ -1077,15 +1005,15 @@ func (a *application) handleOAuthOpenIDConfiguration(w http.ResponseWriter, r *h baseURL := a.oauthAuthorizationServerBaseURL(r) issuer := strings.TrimRight(baseURL, "/") resp := map[string]interface{}{ - "issuer": issuer, - "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), - "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), - "registration_endpoint": joinURLPath(baseURL, a.oauthRegistrationPath()), - "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), - "response_types_supported": []string{"code"}, - "grant_types_supported": []string{"authorization_code", "refresh_token"}, - "token_endpoint_auth_methods_supported": []string{"client_secret_post", "client_secret_basic", "none"}, - "code_challenge_methods_supported": []string{"S256"}, + "issuer": issuer, + "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), + "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), + "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "code_challenge_methods_supported": []string{"S256"}, + "client_id_metadata_document_supported": true, } if !a.oauthForwardMode() { resp["subject_types_supported"] = []string{"public"} @@ -1095,135 +1023,6 @@ func (a *application) handleOAuthOpenIDConfiguration(w http.ResponseWriter, r *h _ = json.NewEncoder(w).Encode(resp) } -func (a *application) handleOAuthRegister(w http.ResponseWriter, r *http.Request) { - if !a.oauthEnabled() { - http.NotFound(w, r) - return - } - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - var req struct { - RedirectURIs []string `json:"redirect_uris"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid registration payload", http.StatusBadRequest) - return - } - if len(req.RedirectURIs) == 0 { - http.Error(w, "redirect_uris is required", http.StatusBadRequest) - return - } - for _, uri := range req.RedirectURIs { - parsed, err := url.Parse(uri) - if err != nil || parsed.Host == "" { - http.Error(w, "invalid redirect URI", http.StatusBadRequest) - return - } - switch parsed.Scheme { - case "https": - // always allowed - case "http": - host := parsed.Hostname() - if host != "localhost" && host != "127.0.0.1" && host != "::1" { - http.Error(w, "http redirect URIs are only allowed for localhost", http.StatusBadRequest) - return - } - default: - http.Error(w, "redirect URI must use https (or http for localhost)", http.StatusBadRequest) - return - } - } - // We register every new client as confidential (client_secret_post). The - // stored secret lives inside the JWE-encoded client_id, so the server - // remains stateless. Anthropic's `mcp_servers`-via-URL flow requires a - // confidential AS (it has no browser to perform PKCE on); leaving the - // "none" path as the only option silently 401s every artifact-side call. - // Public-client (PKCE-only) registrations from clients that explicitly ask - // for token_endpoint_auth_method:none are still honoured for back-compat - // with first-party apps that use only the browser auth-code path. - authMethod := req.TokenEndpointAuthMethod - if authMethod == "" { - authMethod = "client_secret_post" - } - switch authMethod { - case "client_secret_post", "client_secret_basic", "none": - default: - http.Error(w, "Unsupported token_endpoint_auth_method", http.StatusBadRequest) - return - } - - secret, err := a.mustJWESecret() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - var clientSecret string - if authMethod != "none" { - var raw [32]byte - if _, err := rand.Read(raw[:]); err != nil { - http.Error(w, "Failed to generate client_secret", http.StatusInternalServerError) - return - } - clientSecret = hex.EncodeToString(raw[:]) - } - - clientIDClaims := map[string]interface{}{ - "redirect_uris": req.RedirectURIs, - "token_endpoint_auth_method": authMethod, - "grant_type": "authorization_code", - "exp": time.Now().Add(30 * 24 * time.Hour).Unix(), - } - if clientSecret != "" { - // Embed the secret inside the JWE so the token endpoint can compare - // it against the inbound form parameter without server-side state. - clientIDClaims["client_secret"] = clientSecret - } - clientID, err := encodeOAuthJWE(secret, hkdfInfoOAuthClientID, clientIDClaims) - if err != nil { - http.Error(w, "Failed to create stateless client registration", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - expAt := time.Now().Add(30 * 24 * time.Hour).Unix() - // grant_types must include every grant the server will accept from this - // client. Per RFC 7591 §3.2.1 clients treat this list as authoritative, - // so omitting refresh_token here causes strict clients (e.g. Claude.ai) - // to skip grant_type=refresh_token even though /oauth/token would - // accept it and /.well-known/oauth-authorization-server advertises it - // via grant_types_supported. - // scope advertised on the DCR response goes through the same allowlist as - // metadata + WWW-Authenticate so DCR clients never see URI-form scopes or - // non-identity scopes (mcp:*, calendar, …). See oidcScopesForAdvertisement. - cfgOAuth := a.GetCurrentConfig().Server.OAuth - scopes := oidcScopesForAdvertisement(cfgOAuth) - if len(scopes) == 0 { - scopes = oidcScopesForAdvertisement(config.OAuthConfig{Scopes: cfgOAuth.RequiredScopes}) - } - resp := map[string]interface{}{ - "client_id": clientID, - "client_id_issued_at": time.Now().Unix(), - "redirect_uris": req.RedirectURIs, - "grant_types": []string{"authorization_code", "refresh_token"}, - "response_types": []string{"code"}, - "token_endpoint_auth_method": authMethod, - } - if len(scopes) > 0 { - resp["scope"] = strings.Join(scopes, " ") - } - if clientSecret != "" { - resp["client_secret"] = clientSecret - // RFC 7591 §3.2.1: client_secret_expires_at is REQUIRED when a secret - // is issued. The JWE client_id embeds the same exp, so use it here too. - resp["client_secret_expires_at"] = expAt - } - _ = json.NewEncoder(w).Encode(resp) -} - func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Request) { if !a.oauthEnabled() { http.NotFound(w, r) @@ -1240,18 +1039,22 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques http.Error(w, "Invalid authorization request", http.StatusBadRequest) return } - secret, err := a.mustJWESecret() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + // CIMD inbound (#115): client_id is the HTTPS URL of the MCP client's + // metadata document. resolveCIMDClient validates the URL, fetches the + // document under SSRF-safe constraints, and synthesises the registered + // client. DCR was removed in the same change; non-https client_ids are + // rejected as unknown. + if !isCIMDClientID(clientID) { + http.Error(w, "Unknown OAuth client", http.StatusBadRequest) return } - clientClaims, err := decodeOAuthJWE(secret, hkdfInfoOAuthClientID, clientID) + client, err := resolveCIMDClient(r.Context(), clientID) if err != nil { + log.Debug().Err(err).Str("client_id", clientID).Msg("OAuth /authorize rejected: CIMD resolution failed") http.Error(w, "Unknown OAuth client", http.StatusBadRequest) return } - client, err := parseStatelessRegisteredClient(clientClaims) - if err != nil || time.Now().Unix() > client.ExpiresAt || !slices.Contains(client.RedirectURIs, redirectURI) { + if !slices.Contains(client.RedirectURIs, redirectURI) { http.Error(w, "Unknown OAuth client", http.StatusBadRequest) return } @@ -1346,146 +1149,22 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request return } - cfg := a.GetCurrentConfig() - callbackURL := joinURLPath(a.oauthAuthorizationServerBaseURL(r), a.oauthCallbackPath()) - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("code", code) - form.Set("client_id", cfg.Server.OAuth.ClientID) - form.Set("client_secret", cfg.Server.OAuth.ClientSecret) - form.Set("redirect_uri", callbackURL) - // Replay our upstream PKCE verifier (set during /authorize) per RFC 7636 - // §4.5. Skipped only for legacy pending entries that predate the PKCE - // upgrade — those expire within 10 minutes and stop appearing. - if pending.UpstreamPKCEVerifier != "" { - form.Set("code_verifier", pending.UpstreamPKCEVerifier) - } - - tokenURL, err := a.resolveUpstreamTokenURL() - if err != nil { - http.Error(w, "Failed to resolve upstream token endpoint", http.StatusBadGateway) - return - } - resp, err := (&http.Client{Timeout: 10 * time.Second}).PostForm(tokenURL, form) - if err != nil { - log.Error().Err(err).Str("token_url", tokenURL).Msg("Upstream OAuth token exchange request failed") - http.Error(w, "Failed to exchange upstream auth code", http.StatusBadGateway) - return - } - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - log.Error().Err(closeErr).Msgf("can't close %s response body", tokenURL) - } - }() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxOAuthResponseBytes)) - if err != nil { - http.Error(w, "Failed to read upstream token response", http.StatusBadGateway) - return - } - if resp.StatusCode >= 300 { - errCode, bodyLen := safeUpstreamErrorFields(body) - log.Error().Int("status", resp.StatusCode).Str("error_code", errCode).Int("body_len", bodyLen).Msg("Upstream OAuth token exchange failed") - http.Error(w, "Failed to exchange upstream auth code", http.StatusBadGateway) - return - } - var tokenResp struct { - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - TokenType string `json:"token_type"` - ExpiresIn int64 `json:"expires_in"` - Scope string `json:"scope"` - } - if err := json.Unmarshal(body, &tokenResp); err != nil || (tokenResp.AccessToken == "" && tokenResp.IDToken == "") { - log.Error(). - Err(err). - Bool("has_access_token", tokenResp.AccessToken != ""). - Bool("has_id_token", tokenResp.IDToken != ""). - Msg("Upstream token response missing usable token") - http.Error(w, "Missing upstream token", http.StatusBadGateway) - return - } - log.Info(). - Bool("has_access_token", tokenResp.AccessToken != ""). - Bool("has_id_token", tokenResp.IDToken != ""). - Bool("has_refresh_token", tokenResp.RefreshToken != ""). - Bool("forward_mode", a.oauthForwardMode()). - Bool("upstream_offline_access", cfg.Server.OAuth.UpstreamOfflineAccess). - Str("scope", tokenResp.Scope). - Int64("expires_in", tokenResp.ExpiresIn). - Msg("Upstream OAuth token exchange succeeded") - - var identityClaims *altinitymcp.OAuthClaims - if tokenResp.IDToken != "" { - identityClaims, err = a.mcpServer.ValidateUpstreamIdentityToken(tokenResp.IDToken, cfg.Server.OAuth.ClientID) - if err != nil { - log.Error().Err(err).Msg("Upstream identity token validation failed") - http.Error(w, "Failed to validate upstream identity token", http.StatusBadGateway) - return - } - } else if tokenResp.AccessToken != "" { - identityClaims, err = a.fetchUserInfo(tokenResp.AccessToken) - if err != nil { - log.Error().Err(err).Msg("Upstream userinfo validation failed") - http.Error(w, "Failed to validate upstream identity", http.StatusBadGateway) - return - } - } else { - http.Error(w, "Missing upstream token", http.StatusBadGateway) - return - } - if tokenResp.Scope == "" { - tokenResp.Scope = pending.Scope - } - if tokenResp.Scope == "" { - tokenResp.Scope = strings.Join(cfg.Server.OAuth.Scopes, " ") - } - tokenType := tokenResp.TokenType - if tokenType == "" { - tokenType = "Bearer" - } - bearerToken := tokenResp.IDToken - if bearerToken == "" { - bearerToken = tokenResp.AccessToken - } - // The bearer we forward to ClickHouse is the ID token when present, else - // the access_token. Auth0 (and other IdPs) routinely return different - // lifetimes for the two — e.g. expires_in=86400 for the access_token while - // the id_token's own exp is iat+3600. We must report expires_in matching - // the actual bearer the client receives, otherwise downstream MCP clients - // (Claude.ai) won't refresh in time and the user-visible session breaks - // at the bearer's real expiry. - var accessTokenExpiry int64 - if tokenResp.IDToken != "" && identityClaims != nil && identityClaims.ExpiresAt > 0 { - accessTokenExpiry = identityClaims.ExpiresAt - } else if tokenResp.ExpiresIn > 0 { - accessTokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Unix() - } else { - accessTokenExpiry = time.Now().Add(time.Hour).Unix() - } - // /callback only runs in forward mode now (#109): under gating, /callback - // is not registered and clients redirect directly to the upstream IdP. - // Wrap the upstream tokens in our short-lived issued code; /token unwraps - // them in handleOAuthTokenAuthCode. + // HA replay model (#115): the upstream auth code is NOT redeemed here. + // We wrap it (plus the upstream PKCE verifier captured at /authorize and + // the pending-auth fields) into a 60s downstream JWE and let /oauth/token + // perform the upstream exchange. That way the upstream IdP — Google or + // Auth0 — is the sole cross-replica "used codes" oracle: a replayed + // downstream code hits upstream `invalid_grant` and fails. issuedCode := oauthIssuedCode{ - ClientID: pending.ClientID, - RedirectURI: pending.RedirectURI, - Scope: tokenResp.Scope, - CodeChallenge: pending.CodeChallenge, - CodeChallengeMethod: pending.CodeChallengeMethod, - Resource: pending.Resource, - ExpiresAt: time.Now().Add(time.Duration(defaultAuthCodeTTLSeconds) * time.Second), - UpstreamBearerToken: bearerToken, - UpstreamTokenType: tokenType, - AccessTokenExpiry: time.Unix(accessTokenExpiry, 0), - } - if cfg.Server.OAuth.UpstreamOfflineAccess { - issuedCode.UpstreamRefreshToken = tokenResp.RefreshToken - if tokenResp.RefreshToken == "" { - log.Warn(). - Str("scope", tokenResp.Scope). - Msg("upstream_offline_access=true but upstream did not return a refresh_token; check IdP application config (offline_access scope, refresh_token grant, audience)") - } + ClientID: pending.ClientID, + RedirectURI: pending.RedirectURI, + Scope: pending.Scope, + CodeChallenge: pending.CodeChallenge, + CodeChallengeMethod: pending.CodeChallengeMethod, + Resource: pending.Resource, + UpstreamAuthCode: code, + UpstreamPKCEVerifier: pending.UpstreamPKCEVerifier, + ExpiresAt: time.Now().Add(time.Duration(defaultAuthCodeTTLSeconds) * time.Second), } authCode, err := a.encodeAuthCode(issuedCode) if err != nil { @@ -1494,6 +1173,11 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request return } + log.Info(). + Str("client_id", pending.ClientID). + Bool("forward_mode", a.oauthForwardMode()). + Msg("OAuth /callback wrapped upstream auth code in downstream JWE; awaiting /token redemption") + redirect, err := url.Parse(pending.RedirectURI) if err != nil { http.Error(w, "Invalid redirect URI", http.StatusBadGateway) @@ -1508,25 +1192,6 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request http.Redirect(w, r, redirect.String(), http.StatusFound) } -// mintForwardRefreshToken wraps an upstream IdP refresh token in a stateless JWE. -func (a *application) mintForwardRefreshToken(secret []byte, upstreamRefresh, upstreamTokenType, scope, clientID, issuer string) (string, error) { - cfg := a.GetCurrentConfig() - now := time.Now() - tokenType := upstreamTokenType - if tokenType == "" { - tokenType = "Bearer" - } - return encodeOAuthJWE(secret, hkdfInfoOAuthRefresh, map[string]interface{}{ - "upstream_refresh_token": upstreamRefresh, - "upstream_token_type": tokenType, - "scope": scope, - "client_id": clientID, - "iss": strings.TrimSuffix(issuer, "/"), - "iat": now.Unix(), - "exp": now.Add(time.Duration(ttlSeconds(cfg.Server.OAuth.RefreshTokenTTLSeconds, defaultRefreshTokenTTLSeconds)) * time.Second).Unix(), - }) -} - func (a *application) handleOAuthToken(w http.ResponseWriter, r *http.Request) { if !a.oauthEnabled() { http.NotFound(w, r) @@ -1549,132 +1214,64 @@ func (a *application) handleOAuthToken(w http.ResponseWriter, r *http.Request) { switch grantType { case "authorization_code": a.handleOAuthTokenAuthCode(w, r) - case "refresh_token": - a.handleOAuthTokenRefreshDispatch(w, r) default: + // refresh_token grant is intentionally not supported in v1 (#115): + // CIMD clients re-authorize instead of refreshing. This keeps the + // downstream JWE footprint small and avoids issuing long-lived + // credentials to public clients without rotation/reuse detection. writeOAuthTokenError(w, http.StatusBadRequest, "unsupported_grant_type", "unsupported grant type") } } -// handleOAuthTokenRefreshDispatch validates the refresh request's client -// authentication and refresh-token JWE, then delegates to the forward-mode -// upstream-refresh path. Under #109, gating mode no longer mints refresh -// tokens — clients refresh directly against the upstream IdP — so this -// dispatcher only ever runs in forward mode. -func (a *application) handleOAuthTokenRefreshDispatch(w http.ResponseWriter, r *http.Request) { - log.Info(). - Bool("forward_mode", a.oauthForwardMode()). - Msg("OAuth refresh_token grant: handler entered") - secret, err := a.mustJWESecret() - if err != nil { - writeOAuthTokenError(w, http.StatusInternalServerError, "server_error", err.Error()) - return - } - +func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Request) { clientID := r.Form.Get("client_id") - clientClaims, err := decodeOAuthJWE(secret, hkdfInfoOAuthClientID, clientID) - if err != nil { + if !isCIMDClientID(clientID) { writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") return } - client, err := parseStatelessRegisteredClient(clientClaims) - if err != nil || time.Now().Unix() > client.ExpiresAt { - writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") - return - } - if err := authenticateClientSecret(client, r); err != nil { - log.Debug().Err(err).Msg("OAuth refresh request rejected: client_secret authentication failed") - writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication failed") - return - } - - refreshTokenStr := r.Form.Get("refresh_token") - if refreshTokenStr == "" { - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "missing refresh token") - return - } - claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthRefresh, refreshTokenStr) - if err != nil { - log.Warn().Err(err).Msg("OAuth refresh_token grant: JWE decode failed") - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid refresh token") - return - } - jweUpstreamRefresh, _ := claims["upstream_refresh_token"].(string) - log.Info(). - Bool("has_upstream_refresh_token", jweUpstreamRefresh != ""). - Msg("OAuth refresh_token grant: JWE decoded successfully") - - tokenClientID, _ := claims["client_id"].(string) - if tokenClientID != clientID { - log.Debug(). - Str("token_client_id", tokenClientID). - Str("request_client_id", clientID). - Msg("OAuth refresh rejected: client_id mismatch") - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "refresh token was not issued to this client") - return - } - - a.handleOAuthTokenRefreshForward(w, r, secret, clientID, claims) -} - -func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Request) { - secret, err := a.mustJWESecret() - if err != nil { - writeOAuthTokenError(w, http.StatusInternalServerError, "server_error", err.Error()) + // Public CIMD clients reject any client_secret / client_assertion on /token + // per RFC 7591 token_endpoint_auth_method=none + CIMD spec. + if r.Form.Get("client_secret") != "" || r.Form.Get("client_assertion") != "" { + writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication not supported for public CIMD clients") return } - clientID := r.Form.Get("client_id") - clientClaims, err := decodeOAuthJWE(secret, hkdfInfoOAuthClientID, clientID) + client, err := resolveCIMDClient(r.Context(), clientID) if err != nil { + log.Debug().Err(err).Str("client_id", clientID).Msg("OAuth /token rejected: CIMD resolution failed") writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") return } - client, err := parseStatelessRegisteredClient(clientClaims) - if err != nil || time.Now().Unix() > client.ExpiresAt { - log.Debug(). - Err(err). - Int64("client_expires_at", client.ExpiresAt). - Str("token_endpoint_auth_method", client.TokenEndpointAuthMethod). - Msg("OAuth token request rejected: invalid client metadata") - writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") - return - } - if err := authenticateClientSecret(client, r); err != nil { - log.Debug().Err(err).Msg("OAuth token request rejected: client_secret authentication failed") - writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication failed") + requestRedirect := r.Form.Get("redirect_uri") + if !slices.Contains(client.RedirectURIs, requestRedirect) { + writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri not registered for this client") return } issued, ok := a.decodeAuthCode(r.Form.Get("code")) if !ok { - log.Debug().Msg("OAuth token request rejected: unknown or expired authorization code") + log.Debug().Msg("OAuth /token rejected: unknown or expired authorization code") writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") return } - if issued.ClientID != clientID || issued.RedirectURI != r.Form.Get("redirect_uri") { + if issued.ClientID != clientID || issued.RedirectURI != requestRedirect { log.Debug(). Time("code_expires_at", issued.ExpiresAt). Str("issued_client_id", issued.ClientID). Str("request_client_id", clientID). Str("issued_redirect_uri", issued.RedirectURI). - Str("request_redirect_uri", r.Form.Get("redirect_uri")). - Msg("OAuth token request rejected: authorization code mismatch") + Str("request_redirect_uri", requestRedirect). + Msg("OAuth /token rejected: authorization code mismatch") writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") return } - if issued.CodeChallenge != "" { - if pkceChallenge(r.Form.Get("code_verifier")) != issued.CodeChallenge { - log.Debug().Msg("OAuth token request rejected: invalid PKCE verifier") - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid PKCE verifier") - return - } + if issued.CodeChallenge == "" || pkceChallenge(r.Form.Get("code_verifier")) != issued.CodeChallenge { + log.Debug().Msg("OAuth /token rejected: invalid PKCE verifier") + writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid PKCE verifier") + return } - // RFC 8707 §2.2: clients MAY also send `resource` on /token. When the same - // resource was already pinned at /authorize, both must agree; if /authorize - // omitted it but /token includes it, accept and use the latter. Enforced in - // both gating and forward modes — in forward mode the value is only used - // for the rejection check (the response carries the upstream bearer token - // which has its own `aud`). + // RFC 8707 §2.2: when `resource` was pinned at /authorize, /token must + // match. When /authorize omitted it but /token includes one, accept and + // use the latter for downstream advisory only. resource := issued.Resource if formResource := r.Form.Get("resource"); formResource != "" { if resource == "" { @@ -1684,122 +1281,62 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re return } } - - // /oauth/token only runs in forward mode now (#109): under gating, /token - // is not registered and clients hit the upstream IdP directly. The issued - // authorization_code wraps an upstream bearer token captured in /callback; - // forward it back to the client unchanged, mint a forward-mode refresh - // JWE around the upstream refresh if offline_access is on. _ = resource - bearerToken := issued.UpstreamBearerToken - if bearerToken == "" { - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") - return - } - expiresIn := int64(0) - if !issued.AccessTokenExpiry.IsZero() { - expiresIn = int64(time.Until(issued.AccessTokenExpiry).Seconds()) - if expiresIn < 0 { - expiresIn = 0 - } - } - response := map[string]interface{}{ - "access_token": bearerToken, - "token_type": issued.UpstreamTokenType, - "expires_in": expiresIn, - } - // Normalize Google URI-form scopes back to OIDC standard names before - // echoing to the MCP client. ChatGPT compares request vs response shape - // and warns when they differ; round-tripping standard names eliminates - // the cosmetic "permissions not granted" warning. Omit when empty — - // RFC 6749 §5.1 makes scope OPTIONAL when identical to the request. - if s := normalizeUpstreamScopeForClient(issued.Scope); s != "" { - response["scope"] = s - } - if issued.UpstreamRefreshToken != "" { - refreshToken, err := a.mintForwardRefreshToken(secret, issued.UpstreamRefreshToken, issued.UpstreamTokenType, issued.Scope, clientID, a.oauthAuthorizationServerBaseURL(r)) - if err != nil { - log.Error().Err(err).Msg("Failed to mint forward-mode refresh token") - writeOAuthTokenError(w, http.StatusInternalServerError, "server_error", err.Error()) - return - } - response["refresh_token"] = refreshToken - log.Info(). - Str("client_id", clientID). - Int("jwe_len", len(refreshToken)). - Msg("Forward-mode auth-code response includes refresh_token (JWE wrapping upstream refresh)") - } else { - log.Info(). - Str("client_id", clientID). - Msg("Forward-mode auth-code response WITHOUT refresh_token (no upstream refresh captured)") - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(response) -} -// handleOAuthTokenRefreshForward implements the forward-mode refresh flow. -// The decrypted JWE carries the upstream IdP refresh token; we exchange it -// upstream for a fresh ID token + (rotated) refresh token, re-validate the -// new ID token, and mint a new JWE wrapping the rotated upstream refresh. -// -// RFC 8707 §2.2 note: this path does not validate the optional `resource` -// form parameter. The forward refresh JWE (mintForwardRefreshToken) does not -// embed `aud`, so there is nothing to compare against. Audience enforcement -// in forward mode is delegated to the upstream IdP, which re-issues the ID -// token with its own `aud` claim. Closing this gap requires embedding `aud` -// in the forward refresh JWE — deliberately deferred to keep this change -// small; see the "Out of scope" note in the branch's review-fix plan. -func (a *application) handleOAuthTokenRefreshForward(w http.ResponseWriter, r *http.Request, secret []byte, clientID string, claims map[string]interface{}) { - upstreamRefresh, _ := claims["upstream_refresh_token"].(string) - if upstreamRefresh == "" { - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "refresh token is not valid for forward mode") + // HA replay model: redeem the upstream auth code with the upstream IdP + // *now*, not at /callback. The upstream IdP's `invalid_grant` on a second + // redemption is our cross-replica replay verdict — see #115 § HA replay. + if issued.UpstreamAuthCode == "" || issued.UpstreamPKCEVerifier == "" { + writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") return } - upstreamTokenType, _ := claims["upstream_token_type"].(string) - scope, _ := claims["scope"].(string) - cfg := a.GetCurrentConfig() + callbackURL := joinURLPath(a.oauthAuthorizationServerBaseURL(r), a.oauthCallbackPath()) tokenURL, err := a.resolveUpstreamTokenURL() if err != nil { + log.Error().Err(err).Msg("OAuth /token: failed to resolve upstream token endpoint") writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "failed to resolve upstream token endpoint") return } form := url.Values{} - form.Set("grant_type", "refresh_token") - form.Set("refresh_token", upstreamRefresh) + form.Set("grant_type", "authorization_code") + form.Set("code", issued.UpstreamAuthCode) form.Set("client_id", cfg.Server.OAuth.ClientID) if cfg.Server.OAuth.ClientSecret != "" { form.Set("client_secret", cfg.Server.OAuth.ClientSecret) } - if scope != "" { - form.Set("scope", scope) - } + form.Set("redirect_uri", callbackURL) + form.Set("code_verifier", issued.UpstreamPKCEVerifier) - log.Info().Str("token_url", tokenURL).Msg("Forward-mode refresh: calling upstream /oauth/token") - resp, err := (&http.Client{Timeout: 10 * time.Second}).PostForm(tokenURL, form) + upstreamResp, err := (&http.Client{Timeout: 10 * time.Second}).PostForm(tokenURL, form) if err != nil { - log.Error().Err(err).Str("token_url", tokenURL).Msg("Upstream OAuth refresh request failed") - writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "upstream refresh failed") + log.Error().Err(err).Str("token_url", tokenURL).Msg("OAuth /token: upstream code exchange transport error") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "upstream code exchange failed") return } defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { + if closeErr := upstreamResp.Body.Close(); closeErr != nil { log.Error().Err(closeErr).Msgf("can't close %s response body", tokenURL) } }() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxOAuthResponseBytes)) + body, err := io.ReadAll(io.LimitReader(upstreamResp.Body, maxOAuthResponseBytes)) if err != nil { - writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "failed to read upstream refresh response") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "failed to read upstream token response") return } - if resp.StatusCode >= 300 { + if upstreamResp.StatusCode >= 300 { errCode, bodyLen := safeUpstreamErrorFields(body) - log.Error().Int("status", resp.StatusCode).Str("error_code", errCode).Int("body_len", bodyLen).Msg("Upstream OAuth refresh rejected") - writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "upstream rejected the refresh token") + log.Warn(). + Int("status", upstreamResp.StatusCode). + Str("upstream_error", errCode). + Int("body_len", bodyLen). + Str("client_id", clientID). + Msg("OAuth /token: upstream code exchange rejected — likely replay") + // Map upstream invalid_grant (replay-detected, expired, already used) + // to a downstream invalid_grant per RFC 6749 §5.2. + writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "upstream rejected the authorization code") return } - log.Info().Int("status", resp.StatusCode).Msg("Forward-mode refresh: upstream /oauth/token returned 2xx") - var tokenResp struct { AccessToken string `json:"access_token"` IDToken string `json:"id_token"` @@ -1809,87 +1346,78 @@ func (a *application) handleOAuthTokenRefreshForward(w http.ResponseWriter, r *h Scope string `json:"scope"` } if err := json.Unmarshal(body, &tokenResp); err != nil || (tokenResp.AccessToken == "" && tokenResp.IDToken == "") { - log.Error().Err(err).Msg("Upstream refresh response missing usable token") - writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "missing upstream token") + log.Error(). + Err(err). + Bool("has_access_token", tokenResp.AccessToken != ""). + Bool("has_id_token", tokenResp.IDToken != ""). + Msg("OAuth /token: upstream response missing usable token") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "upstream returned no usable token") return } + log.Info(). + Bool("has_access_token", tokenResp.AccessToken != ""). + Bool("has_id_token", tokenResp.IDToken != ""). + Bool("forward_mode", a.oauthForwardMode()). + Str("scope", tokenResp.Scope). + Int64("expires_in", tokenResp.ExpiresIn). + Str("client_id", clientID). + Msg("OAuth /token: upstream code exchange succeeded") - bearerToken := tokenResp.IDToken - if bearerToken == "" { - bearerToken = tokenResp.AccessToken - } - // Re-run identity policy on the rotated upstream token before issuing it. - // Mirror handleOAuthCallback's preference: validate id_token via JWKS when - // present, otherwise fall back to the upstream userinfo endpoint with the - // access_token (which also runs identity-policy checks). var identityClaims *altinitymcp.OAuthClaims if tokenResp.IDToken != "" { identityClaims, err = a.mcpServer.ValidateUpstreamIdentityToken(tokenResp.IDToken, cfg.Server.OAuth.ClientID) if err != nil { - log.Error().Err(err).Msg("Upstream identity token validation failed on refresh") - writeOAuthTokenError(w, http.StatusForbidden, "access_denied", err.Error()) + log.Error().Err(err).Msg("OAuth /token: upstream identity token validation failed") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "failed to validate upstream identity token") return } } else if tokenResp.AccessToken != "" { - if _, err := a.fetchUserInfo(tokenResp.AccessToken); err != nil { - log.Error().Err(err).Msg("Upstream userinfo validation failed on refresh") - writeOAuthTokenError(w, http.StatusForbidden, "access_denied", err.Error()) + identityClaims, err = a.fetchUserInfo(tokenResp.AccessToken) + if err != nil { + log.Error().Err(err).Msg("OAuth /token: upstream userinfo validation failed") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "failed to validate upstream identity") return } } + _ = identityClaims - rotatedUpstream := tokenResp.RefreshToken - if rotatedUpstream == "" { - // IdP did not rotate; keep the existing upstream refresh. - rotatedUpstream = upstreamRefresh - } - newTokenType := tokenResp.TokenType - if newTokenType == "" { - newTokenType = upstreamTokenType + if tokenResp.Scope == "" { + tokenResp.Scope = issued.Scope } - if newTokenType == "" { - newTokenType = "Bearer" + if tokenResp.Scope == "" { + tokenResp.Scope = strings.Join(cfg.Server.OAuth.Scopes, " ") } - newScope := tokenResp.Scope - if newScope == "" { - newScope = scope + tokenType := tokenResp.TokenType + if tokenType == "" { + tokenType = "Bearer" } - newRefreshJWE, err := a.mintForwardRefreshToken(secret, rotatedUpstream, newTokenType, newScope, clientID, a.oauthAuthorizationServerBaseURL(r)) - if err != nil { - log.Error().Err(err).Msg("Failed to mint rotated forward-mode refresh token") - writeOAuthTokenError(w, http.StatusInternalServerError, "server_error", err.Error()) - return + bearerToken := tokenResp.IDToken + if bearerToken == "" { + bearerToken = tokenResp.AccessToken } - // Match expires_in to the actual bearer we forward (id_token when present), - // not tokenResp.ExpiresIn which describes the access_token's lifetime — - // IdPs often return divergent lifetimes (e.g. Auth0: id_token exp = iat+3600, - // access_token expires_in = 86400). See handleOAuthCallback for the same fix. var expiresIn int64 if tokenResp.IDToken != "" && identityClaims != nil && identityClaims.ExpiresAt > 0 { expiresIn = identityClaims.ExpiresAt - time.Now().Unix() } else if tokenResp.ExpiresIn > 0 { expiresIn = tokenResp.ExpiresIn } else { - expiresIn = int64(time.Hour.Seconds()) + expiresIn = int64(defaultAccessTokenTTLSeconds) } if expiresIn < 0 { expiresIn = 0 } - refreshResp := map[string]interface{}{ - "access_token": bearerToken, - "refresh_token": newRefreshJWE, - "token_type": newTokenType, - "expires_in": expiresIn, + response := map[string]interface{}{ + "access_token": bearerToken, + "token_type": tokenType, + "expires_in": expiresIn, } - // Mirror handleOAuthTokenAuthCode: normalize URI-form scopes to OIDC - // standard names for the client; the upstream-stored newScope (now in the - // rotated refresh JWE) keeps Google's original form for the next upstream - // refresh call. - if s := normalizeUpstreamScopeForClient(newScope); s != "" { - refreshResp["scope"] = s + if s := normalizeUpstreamScopeForClient(tokenResp.Scope); s != "" { + response["scope"] = s } + // v1 deliberately drops refresh_token from the response. CIMD clients + // re-authorize. See #115 § Refresh-token policy. w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(refreshResp) + _ = json.NewEncoder(w).Encode(response) } func truncateForLog(value string, max int) string { @@ -1930,9 +1458,10 @@ func (a *application) registerOAuthHTTPRoutes(mux *http.ServeMux) { mux.HandleFunc(path, a.handleOAuthOpenIDConfiguration) } - for _, path := range uniquePaths(a.oauthRegistrationPath(), defaultRegistrationPath) { - mux.HandleFunc(path, a.handleOAuthRegister) - } + // /oauth/register is intentionally NOT mounted: DCR was removed in + // favour of CIMD per #115. Old clients calling /oauth/register get the + // mux's default 404. The .well-known metadata no longer advertises + // registration_endpoint either. for _, path := range uniquePaths(a.oauthAuthorizationPath(), defaultAuthorizationPath) { mux.HandleFunc(path, a.handleOAuthAuthorize) } diff --git a/cmd/altinity-mcp/oauth_server_test.go b/cmd/altinity-mcp/oauth_server_test.go index cdebb43..93d9114 100644 --- a/cmd/altinity-mcp/oauth_server_test.go +++ b/cmd/altinity-mcp/oauth_server_test.go @@ -1,17 +1,12 @@ package main import ( - "bytes" - "crypto/rand" - "crypto/rsa" "encoding/base64" "encoding/json" - "fmt" "net/http" "net/http/httptest" "net/url" "strings" - "sync" "testing" "time" @@ -41,1919 +36,213 @@ func decodeJWTSegment(seg string) ([]byte, error) { // 3. Legacy artifacts (no kid, single SHA256(secret) key) still decrypt and // verify, so existing refresh tokens / client_ids minted before the // cutover keep working through the rotation window. -func TestOAuthJWEHKDFRoundtripAndLegacyFallback(t *testing.T) { - t.Parallel() - - secret := []byte("test-signing-secret-32-byte-key!!") - - t.Run("v1_artifact_carries_kid_header", func(t *testing.T) { - t.Parallel() - token, err := encodeOAuthJWE(secret, hkdfInfoOAuthClientID, map[string]interface{}{ - "sub": "user-1", - "exp": time.Now().Add(time.Hour).Unix(), - }) - require.NoError(t, err) - // JWE compact serialisation: 5 dot-separated parts (header.cek.iv.ct.tag). - parts := strings.Split(token, ".") - require.Len(t, parts, 5) - header, err := decodeJWTSegment(parts[0]) - require.NoError(t, err) - var hdr map[string]interface{} - require.NoError(t, json.Unmarshal(header, &hdr)) - require.Equal(t, oauthKidV1, hdr["kid"], "newly-issued JWE must carry kid=v1") - }) - - t.Run("v1_roundtrip", func(t *testing.T) { - t.Parallel() - original := map[string]interface{}{ - "sub": "user-1", - "exp": float64(time.Now().Add(time.Hour).Unix()), - "scope": "openid email", - "email": "u@example.com", - "client_id": "test-client", - } - token, err := encodeOAuthJWE(secret, hkdfInfoOAuthRefresh, original) - require.NoError(t, err) - decrypted, err := decodeOAuthJWE(secret, hkdfInfoOAuthRefresh, token) - require.NoError(t, err) - require.Equal(t, original["sub"], decrypted["sub"]) - require.Equal(t, original["scope"], decrypted["scope"]) - }) - - t.Run("v1_domain_separation_blocks_cross_context_decrypt", func(t *testing.T) { - // A refresh token's JWE MUST NOT decrypt against the client_id key, - // even though both are minted from the same shared secret. This is - // the core HKDF benefit (RFC 5869 §3.2): different info → independent - // keys. - t.Parallel() - token, err := encodeOAuthJWE(secret, hkdfInfoOAuthRefresh, map[string]interface{}{ - "sub": "user-1", - "exp": time.Now().Add(time.Hour).Unix(), - }) - require.NoError(t, err) - _, err = decodeOAuthJWE(secret, hkdfInfoOAuthClientID, token) - require.Error(t, err, "decryption with the wrong info label MUST fail") - }) - - t.Run("legacy_artifact_decrypts_via_fallback", func(t *testing.T) { - // Mint a JWE the way the pre-Step-2 server did: jwe_auth.GenerateJWEToken - // with the raw secret, no kid header, JWT-signed inner content. - t.Parallel() - legacy, err := jwe_auth.GenerateJWEToken(map[string]interface{}{ - "sub": "user-legacy", - "exp": time.Now().Add(time.Hour).Unix(), - "scope": "openid", - }, secret, secret) - require.NoError(t, err) - // Sanity: legacy artifacts have no kid (or empty) in the protected header. - parts := strings.Split(legacy, ".") - header, err := decodeJWTSegment(parts[0]) - require.NoError(t, err) - var hdr map[string]interface{} - require.NoError(t, json.Unmarshal(header, &hdr)) - _, hasKid := hdr["kid"] - require.False(t, hasKid, "legacy artifact must not carry kid") - - // Now decode it via the new path — should succeed via the legacy - // fallback branch, regardless of which info label we ask for (the - // fallback ignores info because the legacy SHA256(secret) key is - // shared across contexts). - decoded, err := decodeOAuthJWE(secret, hkdfInfoOAuthRefresh, legacy) - require.NoError(t, err, "legacy JWE must remain decryptable during the rotation window") - require.Equal(t, "user-legacy", decoded["sub"]) - }) - -} - -func TestOAuthHTTPDiscoveryAndRegistration(t *testing.T) { - t.Parallel() - app := &application{ - config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Issuer: "https://mcp.example.com/oauth", - Audience: "https://mcp.example.com", - PublicResourceURL: "https://mcp.example.com", - PublicAuthServerURL: "https://mcp.example.com/oauth", - SigningSecret: "test-gating-secret-32-byte-key!!", - Scopes: []string{"openid", "email"}, - AuthURL: "https://accounts.google.com/o/oauth2/v2/auth", - TokenURL: "https://oauth2.googleapis.com/token", - ClientID: "google-client-id", - ClientSecret: "google-client-secret", - }, - }, - }, - } - - // NOTE: subtests are NOT parallel — custom_public_urls_and_paths mutates shared app.config - t.Run("protected_resource_metadata", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/.well-known/oauth-protected-resource", nil) - rr := httptest.NewRecorder() - app.handleOAuthProtectedResource(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - - var body map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) - require.Equal(t, "https://mcp.example.com/", body["resource"]) - require.Equal(t, []interface{}{"https://mcp.example.com/oauth"}, body["authorization_servers"]) - }) - - t.Run("authorization_server_metadata", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/.well-known/oauth-authorization-server", nil) - rr := httptest.NewRecorder() - app.handleOAuthAuthorizationServerMetadata(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - require.Contains(t, rr.Body.String(), "\"authorization_endpoint\":\"https://mcp.example.com/oauth/oauth/authorize\"") - require.Contains(t, rr.Body.String(), "\"registration_endpoint\":\"https://mcp.example.com/oauth/oauth/register\"") - }) - - t.Run("openid_configuration_aliases", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/.well-known/openid-configuration/oauth", nil) - rr := httptest.NewRecorder() - app.handleOAuthOpenIDConfiguration(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - require.Contains(t, rr.Body.String(), "\"issuer\":\"https://mcp.example.com/oauth\"") - require.Contains(t, rr.Body.String(), "\"token_endpoint\":\"https://mcp.example.com/oauth/oauth/token\"") - }) - - t.Run("dynamic_client_registration", func(t *testing.T) { - body := bytes.NewBufferString(`{"redirect_uris":["http://127.0.0.1:3334/callback"],"token_endpoint_auth_method":"none"}`) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", body) - rr := httptest.NewRecorder() - app.handleOAuthRegister(rr, req) - require.Equal(t, http.StatusCreated, rr.Code) - require.Contains(t, rr.Body.String(), "\"client_id\"") - require.Contains(t, rr.Body.String(), "\"token_endpoint_auth_method\":\"none\"") - - var reg map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), ®)) - clientID, ok := reg["client_id"].(string) - require.True(t, ok) - require.NotEmpty(t, clientID) - - // Registration response must echo every grant the client is - // permitted to use. Per RFC 7591 strict clients (Claude.ai) treat - // an omitted grant as forbidden and never attempt it, which - // silently disables the refresh flow. - grants, ok := reg["grant_types"].([]interface{}) - require.True(t, ok, "grant_types missing or wrong type in registration response") - require.ElementsMatch(t, []interface{}{"authorization_code", "refresh_token"}, grants, - "registration response must advertise both authorization_code and refresh_token") - - authReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/authorize?response_type=code&client_id="+url.QueryEscape(clientID)+"&redirect_uri="+url.QueryEscape("http://127.0.0.1:3334/callback")+"&scope=openid+email&state=test-state&code_challenge=test-challenge&code_challenge_method=S256", nil) - authRR := httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code) - require.Contains(t, authRR.Header().Get("Location"), "https://accounts.google.com/o/oauth2/v2/auth") - }) - - t.Run("authorize_resource_indicator_accepted_when_matches_advertised_resource", func(t *testing.T) { - // RFC 8707 / MCP authorization spec: client passes `resource=` - // on /authorize. We accept either trailing-slash form, but the bare-host - // form is the canonical advertised resource here (PublicResourceURL - // "https://mcp.example.com" — slashes get appended where needed). - regBody := bytes.NewBufferString(`{"redirect_uris":["http://127.0.0.1:3334/callback"],"token_endpoint_auth_method":"none"}`) - regReq := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", regBody) - regRR := httptest.NewRecorder() - app.handleOAuthRegister(regRR, regReq) - var reg map[string]interface{} - require.NoError(t, json.Unmarshal(regRR.Body.Bytes(), ®)) - clientID, _ := reg["client_id"].(string) - - base := "https://mcp.example.com/oauth/authorize?response_type=code&client_id=" + url.QueryEscape(clientID) + - "&redirect_uri=" + url.QueryEscape("http://127.0.0.1:3334/callback") + - "&scope=openid+email&state=s&code_challenge=c&code_challenge_method=S256" - - // (a) resource present and matches advertised resource (trailing-slash form): 302 - authReq := httptest.NewRequest(http.MethodGet, base+"&resource="+url.QueryEscape("https://mcp.example.com/"), nil) - authRR := httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code, "valid resource indicator must be accepted (slash form)") - - // PKCE on the upstream-IdP leg: the redirect to upstream MUST include - // code_challenge + code_challenge_method=S256 (OAuth 2.1 §7.5.2). - // Without this, an attacker who intercepts the upstream auth code - // (e.g., via referrer or proxy logs between IdP and our /callback) - // could redeem it even though we hold the upstream client_secret. - upstreamRedirect, parseErr := url.Parse(authRR.Header().Get("Location")) - require.NoError(t, parseErr) - require.NotEmpty(t, upstreamRedirect.Query().Get("code_challenge"), - "upstream /authorize redirect must carry code_challenge (RFC 7636 / OAuth 2.1)") - require.Equal(t, "S256", upstreamRedirect.Query().Get("code_challenge_method"), - "upstream PKCE method must be S256 per OAuth 2.1 §4.1.1") - - // (b) resource present and matches advertised resource (bare host form): 302 - authReq = httptest.NewRequest(http.MethodGet, base+"&resource="+url.QueryEscape("https://mcp.example.com"), nil) - authRR = httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code, "valid resource indicator must be accepted (bare host form)") - - // (c) resource present but identifies a different host: 400 - authReq = httptest.NewRequest(http.MethodGet, base+"&resource="+url.QueryEscape("https://attacker.example/"), nil) - authRR = httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusBadRequest, authRR.Code, "mismatched resource indicator must be rejected") - - // (d) resource absent (legacy clients): 302 (back-compat — RFC 8707 says SHOULD, not MUST) - authReq = httptest.NewRequest(http.MethodGet, base, nil) - authRR = httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code, "missing resource indicator must still authorize (legacy clients)") - }) - - t.Run("dynamic_client_registration_default_is_confidential", func(t *testing.T) { - // When the client doesn't ask for a specific auth method, we now - // register it as confidential (client_secret_post). This unblocks - // Anthropic's mcp_servers-via-URL flow, which has no browser session - // for PKCE and needs server-to-server token-endpoint auth. - body := bytes.NewBufferString(`{"redirect_uris":["http://127.0.0.1:3334/callback"]}`) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", body) - rr := httptest.NewRecorder() - app.handleOAuthRegister(rr, req) - require.Equal(t, http.StatusCreated, rr.Code) - - var reg map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), ®)) - require.Equal(t, "client_secret_post", reg["token_endpoint_auth_method"]) - cs, _ := reg["client_secret"].(string) - require.NotEmpty(t, cs, "confidential registration must include client_secret") - require.Len(t, cs, 64, "client_secret should be 32 random bytes hex-encoded") - _, hasExpiry := reg["client_secret_expires_at"] - require.True(t, hasExpiry, "RFC 7591 §3.2.1: client_secret_expires_at is required when secret is issued") - }) - - t.Run("dynamic_client_registration_explicit_none_still_public", func(t *testing.T) { - // First-party flows that explicitly ask for the legacy public-client - // shape keep getting it — no client_secret in the response. - body := bytes.NewBufferString(`{"redirect_uris":["http://127.0.0.1:3334/callback"],"token_endpoint_auth_method":"none"}`) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", body) - rr := httptest.NewRecorder() - app.handleOAuthRegister(rr, req) - require.Equal(t, http.StatusCreated, rr.Code) - - var reg map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), ®)) - require.Equal(t, "none", reg["token_endpoint_auth_method"]) - _, hasSecret := reg["client_secret"] - require.False(t, hasSecret, "public registration must not include client_secret") - }) - - t.Run("authentication_methods_advertised", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/.well-known/oauth-authorization-server", nil) - rr := httptest.NewRecorder() - app.handleOAuthAuthorizationServerMetadata(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - - var meta map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &meta)) - methods, ok := meta["token_endpoint_auth_methods_supported"].([]interface{}) - require.True(t, ok) - require.Contains(t, methods, "client_secret_post") - require.Contains(t, methods, "client_secret_basic") - require.Contains(t, methods, "none") - }) - - t.Run("custom_public_urls_and_paths", func(t *testing.T) { - app.config.Server.OAuth.PublicResourceURL = "https://public.example.com" - app.config.Server.OAuth.PublicAuthServerURL = "https://public.example.com/oauth" - app.config.Server.OAuth.RegistrationPath = "/register" - app.config.Server.OAuth.AuthorizationPath = "/authorize" - app.config.Server.OAuth.CallbackPath = "/callback" - app.config.Server.OAuth.TokenPath = "/token" - - req := httptest.NewRequest(http.MethodGet, "https://internal.example.com/.well-known/oauth-authorization-server", nil) - rr := httptest.NewRecorder() - app.handleOAuthAuthorizationServerMetadata(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - require.Contains(t, rr.Body.String(), "\"issuer\":\"https://public.example.com/oauth\"") - require.Contains(t, rr.Body.String(), "\"authorization_endpoint\":\"https://public.example.com/oauth/authorize\"") - require.Contains(t, rr.Body.String(), "\"registration_endpoint\":\"https://public.example.com/oauth/register\"") - - req = httptest.NewRequest(http.MethodGet, "https://internal.example.com/.well-known/oauth-protected-resource", nil) - rr = httptest.NewRecorder() - app.handleOAuthProtectedResource(rr, req) - require.Equal(t, http.StatusOK, rr.Code) - require.Contains(t, rr.Body.String(), "\"resource\":\"https://public.example.com/\"") - // In gating mode the protected-resource metadata advertises the upstream - // IdP (cfg.Issuer) as the AS, not MCP's own PublicAuthServerURL. - require.Contains(t, rr.Body.String(), "\"authorization_servers\":[\"https://mcp.example.com/oauth\"]") - }) -} - -func TestOAuthMCPAuthInjector(t *testing.T) { - t.Parallel() - - app := &application{ - config: config.Config{ - Server: config.ServerConfig{ - JWE: config.JWEConfig{ - Enabled: true, - JWESecretKey: "this-is-a-32-byte-secret-key!!", - JWTSecretKey: "jwt-secret", - }, - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "gating", - Issuer: "https://accounts.example.com", - PublicAuthServerURL: "https://mcp.example.com", - Audience: "https://mcp.example.com", - SigningSecret: "test-gating-secret-32-byte-key!!", - }, - }, - }, - mcpServer: altinitymcp.NewClickHouseMCPServer(config.Config{Server: config.ServerConfig{JWE: config.JWEConfig{Enabled: true, JWESecretKey: "this-is-a-32-byte-secret-key!!", JWTSecretKey: "jwt-secret"}, OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "gating", - Issuer: "https://accounts.example.com", - PublicAuthServerURL: "https://mcp.example.com", - Audience: "https://mcp.example.com", - SigningSecret: "test-gating-secret-32-byte-key!!", - }}}, "test"), - } - - jweToken, err := jwe_auth.GenerateJWEToken(map[string]interface{}{"host": "localhost", "port": 8123, "exp": time.Now().Add(time.Hour).Unix()}, []byte("this-is-a-32-byte-secret-key!!"), []byte("jwt-secret")) - require.NoError(t, err) - jweTokenWithCredentials, err := jwe_auth.GenerateJWEToken(map[string]interface{}{ - "host": "localhost", - "port": 8123, - "username": "default", - "password": "secret", - "exp": time.Now().Add(time.Hour).Unix(), - }, []byte("this-is-a-32-byte-secret-key!!"), []byte("jwt-secret")) - require.NoError(t, err) - - t.Run("missing_oauth_gets_challenge", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/"+jweToken, nil) - req.SetPathValue("token", jweToken) - rr := httptest.NewRecorder() - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.Equal(t, http.StatusUnauthorized, rr.Code) - require.Contains(t, rr.Header().Get("WWW-Authenticate"), "resource_metadata=") - require.Contains(t, rr.Header().Get("WWW-Authenticate"), "error=\"invalid_token\"") - }) - - t.Run("jwe_with_credentials_skips_oauth", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/"+jweTokenWithCredentials, nil) - req.SetPathValue("token", jweTokenWithCredentials) - rr := httptest.NewRecorder() - called := false - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - require.Equal(t, jweTokenWithCredentials, r.Context().Value(altinitymcp.JWETokenKey)) - require.Nil(t, r.Context().Value(altinitymcp.OAuthTokenKey)) - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.True(t, called) - require.Equal(t, http.StatusOK, rr.Code) - }) -} - -func TestOAuthMCPAuthInjectorForwardModePassesOpaqueBearerToken(t *testing.T) { - t.Parallel() - token := "opaque-access-token" - app := &application{ - config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - }, - }, - }, - mcpServer: altinitymcp.NewClickHouseMCPServer(config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - }, - }, - }, "test"), - } - - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) - req.Header.Set("Authorization", "Bearer "+token) - rr := httptest.NewRecorder() - called := false - - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - require.Equal(t, token, r.Context().Value(altinitymcp.OAuthTokenKey)) - require.Nil(t, r.Context().Value(altinitymcp.OAuthClaimsKey)) - w.WriteHeader(http.StatusOK) - })) - - handler.ServeHTTP(rr, req) - require.True(t, called) - require.Equal(t, http.StatusOK, rr.Code) -} - -// TestOAuthMCPAuthInjectorForwardModeValidatesJWT is the integration check -// for the C-1 fix: forward mode used to skip ValidateOAuthToken entirely, -// so any string in `Authorization: Bearer …` reached the inner handler -// and was forwarded to ClickHouse. After C-1 the auth layer validates JWT -// bearers when Issuer/JWKSURL is configured and rejects bad ones at 401. -func TestOAuthMCPAuthInjectorForwardModeValidatesJWT(t *testing.T) { - t.Parallel() - - provider := newTestForwardModeOIDCProvider(t, nil, nil) - cfg := config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - Issuer: provider.server.URL, - JWKSURL: provider.server.URL + "/jwks", - Audience: "clickhouse-api", - }, - }, - } - app := &application{ - config: cfg, - mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), - } - - t.Run("valid_jwt_reaches_handler_with_claims", func(t *testing.T) { - t.Parallel() - token := provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-good", - "iss": provider.server.URL, - "aud": "clickhouse-api", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) - req.Header.Set("Authorization", "Bearer "+token) - rr := httptest.NewRecorder() - called := false - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - require.Equal(t, token, r.Context().Value(altinitymcp.OAuthTokenKey)) - claims, ok := r.Context().Value(altinitymcp.OAuthClaimsKey).(*altinitymcp.OAuthClaims) - require.True(t, ok, "valid forward-mode JWT must populate OAuthClaims in context") - require.Equal(t, "user-good", claims.Subject) - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.True(t, called) - require.Equal(t, http.StatusOK, rr.Code) - }) - - t.Run("jwt_with_wrong_audience_rejected_with_401", func(t *testing.T) { - t.Parallel() - token := provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-bad-aud", - "iss": provider.server.URL, - "aud": "some-other-api", - "exp": time.Now().Add(time.Hour).Unix(), - }) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) - req.Header.Set("Authorization", "Bearer "+token) - rr := httptest.NewRecorder() - called := false - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.False(t, called, "wrong-aud forward-mode JWT must NOT reach inner handler") - require.Equal(t, http.StatusUnauthorized, rr.Code) - require.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) - require.Contains(t, rr.Header().Get("WWW-Authenticate"), "resource_metadata=") - }) - - t.Run("expired_jwt_rejected_with_401", func(t *testing.T) { - t.Parallel() - token := provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-expired", - "iss": provider.server.URL, - "aud": "clickhouse-api", - "exp": time.Now().Add(-2 * time.Hour).Unix(), - "iat": time.Now().Add(-3 * time.Hour).Unix(), - }) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) - req.Header.Set("Authorization", "Bearer "+token) - rr := httptest.NewRecorder() - called := false - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.False(t, called, "expired forward-mode JWT must NOT reach inner handler") - require.Equal(t, http.StatusUnauthorized, rr.Code) - }) - - t.Run("opaque_bearer_softpasses_when_jwks_configured", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) - req.Header.Set("Authorization", "Bearer not-a-jwt-just-an-opaque-string") - rr := httptest.NewRecorder() - called := false - handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - require.Equal(t, "not-a-jwt-just-an-opaque-string", r.Context().Value(altinitymcp.OAuthTokenKey)) - require.Nil(t, r.Context().Value(altinitymcp.OAuthClaimsKey)) - w.WriteHeader(http.StatusOK) - })) - handler.ServeHTTP(rr, req) - require.True(t, called, "opaque forward-mode bearer soft-passes (deferred to ClickHouse)") - require.Equal(t, http.StatusOK, rr.Code) - }) -} - -func TestRegisterOAuthHTTPRoutesAliases(t *testing.T) { - t.Parallel() - app := &application{ - config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - Issuer: "https://mcp.example.com/oauth", - Audience: "https://mcp.example.com", - Scopes: []string{"openid", "email"}, - }, - }, - }, - } - - mux := http.NewServeMux() - app.registerOAuthHTTPRoutes(mux) - - for _, path := range []string{ - "/.well-known/oauth-authorization-server/oauth", - "/oauth/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration/oauth", - "/oauth/.well-known/openid-configuration", - } { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com"+path, nil) - rr := httptest.NewRecorder() - mux.ServeHTTP(rr, req) - require.Equalf(t, http.StatusOK, rr.Code, "expected alias %s to resolve", path) - } - - app.config.Server.OAuth.RegistrationPath = "/register" - app.config.Server.OAuth.AuthorizationPath = "/authorize" - app.config.Server.OAuth.CallbackPath = "/callback" - app.config.Server.OAuth.TokenPath = "/token" - - mux = http.NewServeMux() - app.registerOAuthHTTPRoutes(mux) - - for _, path := range []string{ - "/register", - "/authorize", - "/callback", - "/token", - } { - method := http.MethodGet - if path == "/register" || path == "/token" { - method = http.MethodPost - } - req := httptest.NewRequest(method, "https://mcp.example.com"+path, nil) - rr := httptest.NewRecorder() - mux.ServeHTTP(rr, req) - require.NotEqualf(t, http.StatusNotFound, rr.Code, "expected configured path %s to resolve", path) - } -} - -type testForwardModeOIDCProvider struct { - server *httptest.Server - - privateKey *rsa.PrivateKey - keyID string - - tokenResponse map[string]interface{} - userInfoClaims map[string]interface{} - lastUserInfoAuth string - userInfoCalls int - mu sync.Mutex - - // refreshHandler, if non-nil, handles POST /token requests with - // grant_type=refresh_token. It receives the parsed form and returns - // (status, body). When nil, refresh_token grants fall through to the - // default static tokenResponse behavior. - refreshHandler func(form url.Values) (int, map[string]interface{}) -} - -func newTestForwardModeOIDCProvider(t *testing.T, tokenResponse map[string]interface{}, userInfoClaims map[string]interface{}) *testForwardModeOIDCProvider { - t.Helper() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - provider := &testForwardModeOIDCProvider{ - privateKey: privateKey, - keyID: "test-signing-key", - tokenResponse: tokenResponse, - userInfoClaims: userInfoClaims, - } - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - provider.server = server - t.Cleanup(server.Close) - - mux.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) - - mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.NoError(t, r.ParseForm()) - if provider.refreshHandler != nil && r.Form.Get("grant_type") == "refresh_token" { - status, body := provider.refreshHandler(r.Form) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - require.NoError(t, json.NewEncoder(w).Encode(body)) - return - } - w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(provider.tokenResponse)) - }) - - mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - keySet := jose.JSONWebKeySet{ - Keys: []jose.JSONWebKey{{ - Key: &privateKey.PublicKey, - KeyID: provider.keyID, - Use: "sig", - Algorithm: string(jose.RS256), - }}, - } - require.NoError(t, json.NewEncoder(w).Encode(keySet)) - }) - - mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { - provider.mu.Lock() - provider.userInfoCalls++ - provider.lastUserInfoAuth = r.Header.Get("Authorization") - provider.mu.Unlock() - - if provider.userInfoClaims == nil { - http.Error(w, "userinfo not configured", http.StatusNotFound) - return - } - - w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(provider.userInfoClaims)) - }) - - return provider -} - -func (p *testForwardModeOIDCProvider) issueIDToken(t *testing.T, claims map[string]interface{}) string { - t.Helper() - - signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: jose.RS256, - Key: jose.JSONWebKey{ - Key: p.privateKey, - KeyID: p.keyID, - Use: "sig", - Algorithm: string(jose.RS256), - }, - }, (&jose.SignerOptions{}).WithType("JWT")) - require.NoError(t, err) - - payload, err := json.Marshal(claims) - require.NoError(t, err) - - object, err := signer.Sign(payload) - require.NoError(t, err) - - token, err := object.CompactSerialize() - require.NoError(t, err) - - return token -} - -func (p *testForwardModeOIDCProvider) userInfoRequest() (int, string) { - p.mu.Lock() - defer p.mu.Unlock() - return p.userInfoCalls, p.lastUserInfoAuth -} - -func newForwardModeBrowserLoginTestApp(provider *testForwardModeOIDCProvider) *application { - cfg := config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - Issuer: provider.server.URL, - JWKSURL: provider.server.URL + "/jwks", - AuthURL: provider.server.URL + "/authorize", - TokenURL: provider.server.URL + "/token", - UserInfoURL: provider.server.URL + "/userinfo", - ClientID: "upstream-client-id", - ClientSecret: "upstream-client-secret", - Scopes: []string{"openid", "email"}, - SigningSecret: "test-gating-secret-32-byte-key!!", - }, - }, - } - - return &application{ - config: cfg, - mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), - } -} - -func registerOAuthBrowserClient(t *testing.T, app *application, redirectURI string) string { - t.Helper() - - body := bytes.NewBufferString(fmt.Sprintf(`{"redirect_uris":["%s"],"token_endpoint_auth_method":"none"}`, redirectURI)) - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", body) - rr := httptest.NewRecorder() - app.handleOAuthRegister(rr, req) - require.Equal(t, http.StatusCreated, rr.Code) - - var reg map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), ®)) - - clientID, ok := reg["client_id"].(string) - require.True(t, ok) - require.NotEmpty(t, clientID) - return clientID -} - -func startOAuthBrowserLogin(t *testing.T, app *application, clientID, redirectURI, clientState, codeVerifier string) string { - t.Helper() - - authReq := httptest.NewRequest( - http.MethodGet, - "https://mcp.example.com/oauth/authorize?response_type=code&client_id="+url.QueryEscape(clientID)+ - "&redirect_uri="+url.QueryEscape(redirectURI)+ - "&scope=openid+email&state="+url.QueryEscape(clientState)+ - "&code_challenge="+url.QueryEscape(pkceChallenge(codeVerifier))+ - "&code_challenge_method=S256", - nil, - ) - authRR := httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code) - - location, err := url.Parse(authRR.Header().Get("Location")) - require.NoError(t, err) - - state := location.Query().Get("state") - require.NotEmpty(t, state) - return state -} - -func exchangeOAuthBrowserCode(t *testing.T, app *application, clientID, code, redirectURI, codeVerifier string) *httptest.ResponseRecorder { - t.Helper() - - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", clientID) - form.Set("code", code) - form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", codeVerifier) - - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rr := httptest.NewRecorder() - app.handleOAuthToken(rr, req) - return rr -} - -func TestOAuthForwardModeBrowserLoginUsesUpstreamBearerToken(t *testing.T) { - t.Parallel() - const ( - redirectURI = "http://127.0.0.1:3334/callback" - codeVerifier = "test-code-verifier" - clientState = "client-state" - ) - - t.Run("access_token_and_id_token_prefers_id_token", func(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email profile", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - - app := newForwardModeBrowserLoginTestApp(provider) - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, clientState, codeVerifier) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - redirectLocation, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - require.Equal(t, clientState, redirectLocation.Query().Get("state")) - - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, redirectLocation.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - - var tokenResp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &tokenResp)) - // In forward mode, the raw upstream token is returned directly - require.Equal(t, provider.tokenResponse["id_token"], tokenResp["access_token"]) - require.Equal(t, "Bearer", tokenResp["token_type"]) - require.Equal(t, "openid email profile", tokenResp["scope"]) - // The bearer we forward is the id_token, so expires_in must reflect - // the id_token's exp (1h), NOT the upstream access_token's expires_in - // (30m). IdPs commonly return divergent lifetimes; using the wrong - // one means downstream MCP clients (Claude.ai) refresh too late and - // the bearer expires under them. - require.Greater(t, tokenResp["expires_in"].(float64), float64(3500)) - require.LessOrEqual(t, tokenResp["expires_in"].(float64), float64(3600)) - - userInfoCalls, userInfoAuth := provider.userInfoRequest() - require.Equal(t, 0, userInfoCalls) - require.Empty(t, userInfoAuth) - }) - - t.Run("access_token_only_uses_userinfo_and_returns_access_token", func(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "opaque-access-token", - "token_type": "DPoP", - "expires_in": 900, - "scope": "openid email", - }, map[string]interface{}{ - "sub": "user-2", - "iss": "https://issuer.example.com", - "email": "user2@example.com", - "email_verified": true, - }) - - app := newForwardModeBrowserLoginTestApp(provider) - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, clientState, codeVerifier) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - redirectLocation, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, redirectLocation.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - - var tokenResp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &tokenResp)) - // In forward mode, the raw upstream access token is returned directly - require.Equal(t, "opaque-access-token", tokenResp["access_token"]) - require.Equal(t, "DPoP", tokenResp["token_type"]) - require.Equal(t, "openid email", tokenResp["scope"]) - require.Greater(t, tokenResp["expires_in"].(float64), float64(0)) - require.LessOrEqual(t, tokenResp["expires_in"].(float64), float64(900)) - - userInfoCalls, userInfoAuth := provider.userInfoRequest() - require.Equal(t, 1, userInfoCalls) - require.Equal(t, "Bearer opaque-access-token", userInfoAuth) - }) - - t.Run("id_token_without_access_token_returns_id_token", func(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "token_type": "Bearer", - "expires_in": 900, - "scope": "openid email", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-3", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user3@example.com", - "email_verified": true, - }) - - app := newForwardModeBrowserLoginTestApp(provider) - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, clientState, codeVerifier) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - redirectLocation, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, redirectLocation.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - - var tokenResp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &tokenResp)) - // In forward mode, the raw upstream id_token is returned directly - require.Equal(t, provider.tokenResponse["id_token"], tokenResp["access_token"]) - require.Equal(t, "Bearer", tokenResp["token_type"]) - require.Equal(t, "openid email", tokenResp["scope"]) - - userInfoCalls, userInfoAuth := provider.userInfoRequest() - require.Equal(t, 0, userInfoCalls) - require.Empty(t, userInfoAuth) - }) -} - -// TestOAuthForwardModeTokenResourceMismatch pins the RFC 8707 §2.2 enforcement -// in forward mode: a /token (auth-code grant) request whose `resource` differs -// from the one already pinned at /authorize must be rejected with -// invalid_target, regardless of which mode we're running in. -func TestOAuthForwardModeTokenResourceMismatch(t *testing.T) { - t.Parallel() - const ( - redirectURI = "http://127.0.0.1:3334/callback" - codeVerifier = "test-code-verifier" - clientState = "client-state" - pinnedResource = "https://mcp.example.com" - otherResource = "https://attacker.example.com" - ) - - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - - app := newForwardModeBrowserLoginTestApp(provider) - clientID := registerOAuthBrowserClient(t, app, redirectURI) - - authReq := httptest.NewRequest( - http.MethodGet, - "https://mcp.example.com/oauth/authorize?response_type=code&client_id="+url.QueryEscape(clientID)+ - "&redirect_uri="+url.QueryEscape(redirectURI)+ - "&scope=openid+email&state="+url.QueryEscape(clientState)+ - "&code_challenge="+url.QueryEscape(pkceChallenge(codeVerifier))+ - "&code_challenge_method=S256"+ - "&resource="+url.QueryEscape(pinnedResource), - nil, - ) - authRR := httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code, "authorize must accept canonical resource") - - location, err := url.Parse(authRR.Header().Get("Location")) - require.NoError(t, err) - state := location.Query().Get("state") - require.NotEmpty(t, state) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - redirectLocation, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - code := redirectLocation.Query().Get("code") - require.NotEmpty(t, code) - - exchange := func(t *testing.T, formResource string) *httptest.ResponseRecorder { - t.Helper() - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", clientID) - form.Set("code", code) - form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", codeVerifier) - if formResource != "" { - form.Set("resource", formResource) - } - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rr := httptest.NewRecorder() - app.handleOAuthToken(rr, req) - return rr - } - - t.Run("mismatched_resource_rejected", func(t *testing.T) { - rr := exchange(t, otherResource) - require.Equal(t, http.StatusBadRequest, rr.Code, "forward mode must reject /token resource that differs from /authorize") - var body map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) - require.Equal(t, "invalid_target", body["error"]) - }) -} - -func generateOAuthTokenForApp(claims map[string]interface{}) (string, error) { - payload, err := json.Marshal(claims) - if err != nil { - return "", err - } - hashedSecret := jwe_auth.HashSHA256([]byte("test-gating-secret-32-byte-key!!")) - signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: hashedSecret}, (&jose.SignerOptions{}).WithType("JWT")) - if err != nil { - return "", err - } - object, err := signer.Sign(payload) - if err != nil { - return "", err - } - return object.CompactSerialize() -} - -func TestCanonicalResourceURL(t *testing.T) { - t.Parallel() - cases := []struct{ in, want string }{ - {"", ""}, - {" ", ""}, - {"https://mcp.example.com", "https://mcp.example.com/"}, - {"https://mcp.example.com/", "https://mcp.example.com/"}, - {"https://mcp.example.com//", "https://mcp.example.com/"}, - {" https://mcp.example.com ", "https://mcp.example.com/"}, - {"https://mcp.example.com/path", "https://mcp.example.com/path/"}, - {"https://mcp.example.com/path/", "https://mcp.example.com/path/"}, - } - for _, c := range cases { - require.Equal(t, c.want, canonicalResourceURL(c.in), "input=%q", c.in) - } -} - -// newJWEStateTestApp builds a minimal application wired with a SigningSecret -// for exercising the stateless JWE encode/decode helpers in isolation. -func newJWEStateTestApp(secret string) *application { - cfg := config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - SigningSecret: secret, - }, - }, - } - return &application{config: cfg} -} - -func TestOAuthStateJWERoundTrip(t *testing.T) { - t.Parallel() - const secret = "test-jwe-state-secret-32-byte-key" - - t.Run("pending_auth_round_trip", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp(secret) - want := oauthPendingAuth{ - ClientID: "cid", - RedirectURI: "https://client.example/cb", - Scope: "openid email", - ClientState: "abc", - CodeChallenge: "ch", - CodeChallengeMethod: "S256", - Resource: "https://mcp.example/", - UpstreamPKCEVerifier: "verifier", - ExpiresAt: time.Now().Add(10 * time.Minute).Truncate(time.Second), - } - tok, err := app.encodePendingAuth(want) - require.NoError(t, err) - require.NotEmpty(t, tok) - - got, ok := app.decodePendingAuth(tok) - require.True(t, ok) - require.Equal(t, want.ClientID, got.ClientID) - require.Equal(t, want.RedirectURI, got.RedirectURI) - require.Equal(t, want.Scope, got.Scope) - require.Equal(t, want.ClientState, got.ClientState) - require.Equal(t, want.CodeChallenge, got.CodeChallenge) - require.Equal(t, want.CodeChallengeMethod, got.CodeChallengeMethod) - require.Equal(t, want.Resource, got.Resource) - require.Equal(t, want.UpstreamPKCEVerifier, got.UpstreamPKCEVerifier) - require.Equal(t, want.ExpiresAt.Unix(), got.ExpiresAt.Unix()) - }) - - t.Run("auth_code_round_trip", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp(secret) - want := oauthIssuedCode{ - ClientID: "cid", - RedirectURI: "https://client.example/cb", - Scope: "openid email", - CodeChallenge: "ch", - CodeChallengeMethod: "S256", - Resource: "https://mcp.example/", - UpstreamBearerToken: "upstream-bearer", - UpstreamRefreshToken: "upstream-refresh", - UpstreamTokenType: "Bearer", - Subject: "user-1", - Email: "u@example.com", - Name: "User", - HostedDomain: "example.com", - EmailVerified: true, - ExpiresAt: time.Now().Add(60 * time.Second).Truncate(time.Second), - AccessTokenExpiry: time.Now().Add(time.Hour).Truncate(time.Second), - } - tok, err := app.encodeAuthCode(want) - require.NoError(t, err) - - got, ok := app.decodeAuthCode(tok) - require.True(t, ok) - require.Equal(t, want, got) - }) - - t.Run("cross_pod_portable_with_shared_secret", func(t *testing.T) { - t.Parallel() - // Simulate two replicas: separate application instances, identical secret. - podA := newJWEStateTestApp(secret) - podB := newJWEStateTestApp(secret) - mintedOnA, err := podA.encodePendingAuth(oauthPendingAuth{ - ClientID: "cid", - ExpiresAt: time.Now().Add(10 * time.Minute), - }) - require.NoError(t, err) - got, ok := podB.decodePendingAuth(mintedOnA) - require.True(t, ok) - require.Equal(t, "cid", got.ClientID) - }) - - t.Run("cross_pod_rejected_with_different_secret", func(t *testing.T) { - t.Parallel() - podA := newJWEStateTestApp(secret) - podB := newJWEStateTestApp("a-different-secret-32-bytes-long!") - mintedOnA, err := podA.encodeAuthCode(oauthIssuedCode{ - ClientID: "cid", - ExpiresAt: time.Now().Add(60 * time.Second), - }) - require.NoError(t, err) - _, ok := podB.decodeAuthCode(mintedOnA) - require.False(t, ok, "JWE minted with a different secret must not decode") - }) - - t.Run("expired_auth_code_rejected", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp(secret) - tok, err := app.encodeAuthCode(oauthIssuedCode{ - ClientID: "cid", - ExpiresAt: time.Now().Add(-1 * time.Second), - }) - require.NoError(t, err) - _, ok := app.decodeAuthCode(tok) - require.False(t, ok, "expired auth code must be rejected by JWE exp validation") - }) - - t.Run("expired_pending_auth_rejected", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp(secret) - tok, err := app.encodePendingAuth(oauthPendingAuth{ - ClientID: "cid", - ExpiresAt: time.Now().Add(-1 * time.Second), - }) - require.NoError(t, err) - _, ok := app.decodePendingAuth(tok) - require.False(t, ok) - }) - - t.Run("tampered_token_rejected", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp(secret) - tok, err := app.encodeAuthCode(oauthIssuedCode{ - ClientID: "cid", - ExpiresAt: time.Now().Add(60 * time.Second), - }) - require.NoError(t, err) - // Flip a byte in the JWE ciphertext. - bs := []byte(tok) - bs[len(bs)/2] ^= 0x01 - _, ok := app.decodeAuthCode(string(bs)) - require.False(t, ok) - }) - - t.Run("decode_missing_secret_fails_cleanly", func(t *testing.T) { - t.Parallel() - app := newJWEStateTestApp("") - _, ok := app.decodePendingAuth("anything") - require.False(t, ok) - _, ok = app.decodeAuthCode("anything") - require.False(t, ok) - }) -} - -// newGatingModeTestApp creates an application configured for gating mode OAuth. -func newGatingModeTestApp(provider *testForwardModeOIDCProvider) *application { - cfg := config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "gating", - Issuer: provider.server.URL, - JWKSURL: provider.server.URL + "/jwks", - AuthURL: provider.server.URL + "/authorize", - TokenURL: provider.server.URL + "/token", - UserInfoURL: provider.server.URL + "/userinfo", - ClientID: "upstream-client-id", - ClientSecret: "upstream-client-secret", - Scopes: []string{"openid", "email"}, - SigningSecret: "test-gating-secret-32-byte-key!!", - AccessTokenTTLSeconds: 300, - RefreshTokenTTLSeconds: 86400, - }, - }, - } - return &application{ - config: cfg, - mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), - } -} - -// doGatingAuthCodeFlow runs the full authorize→callback→token exchange and -// returns the parsed token response. -func doGatingAuthCodeFlow(t *testing.T, app *application, provider *testForwardModeOIDCProvider, redirectURI, codeVerifier string) map[string]interface{} { - t.Helper() - - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, "s", codeVerifier) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - loc, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, loc.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - - var resp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &resp)) - resp["_client_id"] = clientID // stash for refresh tests - return resp -} - -func exchangeRefreshToken(t *testing.T, app *application, clientID, refreshToken string) *httptest.ResponseRecorder { - t.Helper() - form := url.Values{} - form.Set("grant_type", "refresh_token") - form.Set("client_id", clientID) - form.Set("refresh_token", refreshToken) - - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rr := httptest.NewRecorder() - app.handleOAuthToken(rr, req) - return rr -} - -func TestOAuthForwardModeNoRefreshToken(t *testing.T) { - t.Parallel() - const ( - redirectURI = "http://127.0.0.1:3334/callback" - codeVerifier = "test-code-verifier" - clientState = "cs" - ) - - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - - app := newForwardModeBrowserLoginTestApp(provider) - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, clientState, codeVerifier) - - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - - loc, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, loc.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - - var tokenResp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &tokenResp)) - _, hasRefresh := tokenResp["refresh_token"] - require.False(t, hasRefresh, "forward mode should NOT include refresh_token") -} - -// newForwardModeRefreshTestApp configures a forward-mode app with -// UpstreamOfflineAccess enabled, so the auth-code response carries a JWE -// refresh_token wrapping the upstream IdP's refresh token. -func newForwardModeRefreshTestApp(provider *testForwardModeOIDCProvider) *application { - cfg := config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "forward", - Issuer: provider.server.URL, - JWKSURL: provider.server.URL + "/jwks", - AuthURL: provider.server.URL + "/authorize", - TokenURL: provider.server.URL + "/token", - UserInfoURL: provider.server.URL + "/userinfo", - ClientID: "upstream-client-id", - ClientSecret: "upstream-client-secret", - Scopes: []string{"openid", "email"}, - UpstreamOfflineAccess: true, - SigningSecret: "test-gating-secret-32-byte-key!!", - RefreshTokenTTLSeconds: 86400, - }, - }, - } - return &application{ - config: cfg, - mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), - } -} - -func TestOAuthForwardModeRefresh(t *testing.T) { - t.Parallel() - const ( - redirectURI = "http://127.0.0.1:3334/callback" - codeVerifier = "test-code-verifier-fwd-refresh" - clientState = "cs" - ) - - newProvider := func(t *testing.T) *testForwardModeOIDCProvider { - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "refresh_token": "upstream-refresh-token-original", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email offline_access", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - - // Stateful refresh handler with strict single-use rotation: each refresh - // invalidates the inbound token and issues a new one. Models an IdP with - // refresh-token reuse detection enabled (e.g. Auth0 default). - validUpstreamRefresh := map[string]bool{"upstream-refresh-token-original": true} - rotation := 0 - var mu sync.Mutex - provider.refreshHandler = func(form url.Values) (int, map[string]interface{}) { - mu.Lock() - defer mu.Unlock() - inbound := form.Get("refresh_token") - if !validUpstreamRefresh[inbound] { - return http.StatusBadRequest, map[string]interface{}{"error": "invalid_grant"} - } - delete(validUpstreamRefresh, inbound) - rotation++ - next := fmt.Sprintf("upstream-refresh-token-rotated-%d", rotation) - validUpstreamRefresh[next] = true - newIDToken := provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Add(time.Duration(rotation) * time.Second).Unix(), - "email": "user@example.com", - "email_verified": true, - }) - return http.StatusOK, map[string]interface{}{ - "access_token": "upstream-access-token-r" + fmt.Sprint(rotation), - "id_token": newIDToken, - "refresh_token": next, - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email offline_access", - } - } - return provider - } - - doInitialFlow := func(t *testing.T, app *application) (string, map[string]interface{}) { - t.Helper() - clientID := registerOAuthBrowserClient(t, app, redirectURI) - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, clientState, codeVerifier) - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - loc, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) - tokenRR := exchangeOAuthBrowserCode(t, app, clientID, loc.Query().Get("code"), redirectURI, codeVerifier) - require.Equal(t, http.StatusOK, tokenRR.Code) - var resp map[string]interface{} - require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &resp)) - return clientID, resp - } - - t.Run("auth_code_response_includes_refresh_token", func(t *testing.T) { - t.Parallel() - provider := newProvider(t) - app := newForwardModeRefreshTestApp(provider) - _, resp := doInitialFlow(t, app) - - require.Equal(t, provider.tokenResponse["id_token"], resp["access_token"], "access_token must remain the upstream ID token verbatim") - require.NotEmpty(t, resp["refresh_token"], "forward mode + UpstreamOfflineAccess must issue a refresh_token") - // MCP refresh token is the JWE wrapper, not the raw upstream refresh. - require.NotEqual(t, "upstream-refresh-token-original", resp["refresh_token"]) - // expires_in must reflect the id_token's actual exp (1h here), not the - // upstream access_token's expires_in (1800). MCP clients schedule - // proactive refresh from this value; using the access_token TTL when - // we forward the id_token causes downstream sessions to break at the - // real bearer expiry. - require.Greater(t, resp["expires_in"].(float64), float64(3500)) - require.LessOrEqual(t, resp["expires_in"].(float64), float64(3600)) - }) - - t.Run("refresh_grants_new_upstream_id_token_and_rotates", func(t *testing.T) { - t.Parallel() - provider := newProvider(t) - app := newForwardModeRefreshTestApp(provider) - clientID, resp := doInitialFlow(t, app) - - rr := exchangeRefreshToken(t, app, clientID, resp["refresh_token"].(string)) - require.Equal(t, http.StatusOK, rr.Code, "refresh response body: %s", rr.Body.String()) - - var refreshed map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &refreshed)) - require.NotEmpty(t, refreshed["access_token"]) - // New access_token must be a freshly minted upstream ID token, not the - // original one returned at auth_code exchange. - require.NotEqual(t, resp["access_token"], refreshed["access_token"]) - // Refresh token rotates (new JWE wraps the rotated upstream refresh). - require.NotEmpty(t, refreshed["refresh_token"]) - require.NotEqual(t, resp["refresh_token"], refreshed["refresh_token"]) - require.Equal(t, "Bearer", refreshed["token_type"]) - // expires_in must reflect the rotated id_token's exp (1h), not the - // upstream access_token's expires_in (1800). Same rationale as the - // auth-code path above. - require.Greater(t, refreshed["expires_in"].(float64), float64(3500)) - require.LessOrEqual(t, refreshed["expires_in"].(float64), float64(3600)) - }) - - t.Run("idp_rotation_invalidates_rotated_out_mcp_refresh", func(t *testing.T) { - t.Parallel() - // MCP-side refresh tokens are stateless JWEs with no server-side - // reuse detection — the JWE itself stays decryptable until its exp. - // Security against replay therefore depends on the upstream IdP - // enforcing rotation. This test verifies that when the upstream IdP - // does enforce rotation (the production-recommended Auth0/Okta - // configuration), MCP correctly surfaces upstream's rejection as - // invalid_grant rather than silently issuing new tokens. - provider := newProvider(t) - app := newForwardModeRefreshTestApp(provider) - clientID, resp := doInitialFlow(t, app) - - // First refresh succeeds; upstream rotates the underlying refresh. - rr1 := exchangeRefreshToken(t, app, clientID, resp["refresh_token"].(string)) - require.Equal(t, http.StatusOK, rr1.Code, "first refresh should succeed: %s", rr1.Body.String()) - - // Second refresh with the original (now rotated-out) MCP refresh token: - // MCP decrypts the JWE successfully but the upstream IdP rejects the - // underlying refresh, and MCP must return invalid_grant. - rr2 := exchangeRefreshToken(t, app, clientID, resp["refresh_token"].(string)) - require.Equal(t, http.StatusBadRequest, rr2.Code) - require.Contains(t, rr2.Body.String(), "invalid_grant") - require.Contains(t, rr2.Body.String(), "upstream rejected the refresh token") - }) - - t.Run("malformed_refresh_token_rejected", func(t *testing.T) { - t.Parallel() - provider := newProvider(t) - app := newForwardModeRefreshTestApp(provider) - clientID, _ := doInitialFlow(t, app) - - rr := exchangeRefreshToken(t, app, clientID, "garbage-refresh-token") - require.Equal(t, http.StatusBadRequest, rr.Code) - require.Contains(t, rr.Body.String(), "invalid refresh token") - }) - - t.Run("rotating_gating_secret_invalidates_outstanding_refresh_tokens", func(t *testing.T) { - t.Parallel() - provider := newProvider(t) - app := newForwardModeRefreshTestApp(provider) - clientID, resp := doInitialFlow(t, app) - - // Rotate the symmetric secret used to encrypt the JWE. - app.config.Server.OAuth.SigningSecret = "different-secret-32-bytes-long!!" - app.mcpServer.Config.Server.OAuth.SigningSecret = "different-secret-32-bytes-long!!" - - // client_id is decrypted first in handleOAuthTokenRefresh, so a - // client_id JWE keyed by the prior secret fails before the refresh - // token is even inspected. - rr := exchangeRefreshToken(t, app, clientID, resp["refresh_token"].(string)) - require.Equal(t, http.StatusUnauthorized, rr.Code) - require.Contains(t, rr.Body.String(), "unknown OAuth client") - }) -} - -func TestOAuthAuthorizeOfflineAccessScope(t *testing.T) { - t.Parallel() - const ( - redirectURI = "http://127.0.0.1:3334/callback" - codeVerifier = "v" - ) - - scopeFromRedirect := func(t *testing.T, app *application) []string { - t.Helper() - clientID := registerOAuthBrowserClient(t, app, redirectURI) - authReq := httptest.NewRequest( - http.MethodGet, - "https://mcp.example.com/oauth/authorize?response_type=code&client_id="+url.QueryEscape(clientID)+ - "&redirect_uri="+url.QueryEscape(redirectURI)+ - "&scope=openid+email&state=cs"+ - "&code_challenge="+url.QueryEscape(pkceChallenge(codeVerifier))+ - "&code_challenge_method=S256", - nil, - ) - authRR := httptest.NewRecorder() - app.handleOAuthAuthorize(authRR, authReq) - require.Equal(t, http.StatusFound, authRR.Code) - loc, err := url.Parse(authRR.Header().Get("Location")) - require.NoError(t, err) - return strings.Fields(loc.Query().Get("scope")) - } - - t.Run("forward_mode_with_offline_access_appends_scope", func(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "irrelevant", - "token_type": "Bearer", - }, nil) - app := newForwardModeRefreshTestApp(provider) - scopes := scopeFromRedirect(t, app) - require.Contains(t, scopes, "offline_access", "forward mode + UpstreamOfflineAccess must request offline_access upstream") - }) - - t.Run("forward_mode_without_offline_access_omits_scope", func(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "irrelevant", - "token_type": "Bearer", - }, nil) - app := newForwardModeBrowserLoginTestApp(provider) - scopes := scopeFromRedirect(t, app) - require.NotContains(t, scopes, "offline_access", "default forward mode must not request offline_access") - }) - -} - -func TestOAuthRegistrationNegative(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, nil, nil) - app := newGatingModeTestApp(provider) - - post := func(body string) *httptest.ResponseRecorder { - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", strings.NewReader(body)) - rr := httptest.NewRecorder() - app.handleOAuthRegister(rr, req) - return rr - } - - t.Run("invalid_json", func(t *testing.T) { - t.Parallel() - rr := post("{broken") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("empty_redirect_uris", func(t *testing.T) { - t.Parallel() - rr := post(`{"redirect_uris":[]}`) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("http_non_localhost_redirect", func(t *testing.T) { - t.Parallel() - rr := post(`{"redirect_uris":["http://evil.com/cb"]}`) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("invalid_redirect_uri", func(t *testing.T) { - t.Parallel() - rr := post(`{"redirect_uris":["not-a-url"]}`) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("unsupported_auth_method", func(t *testing.T) { - t.Parallel() - // client_secret_post / client_secret_basic / none are now supported. - // Anything else (e.g. private_key_jwt) must still be rejected. - rr := post(`{"redirect_uris":["https://example.com/cb"],"token_endpoint_auth_method":"private_key_jwt"}`) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) -} - -func TestOAuthAuthorizeNegative(t *testing.T) { - t.Parallel() - provider := newTestForwardModeOIDCProvider(t, nil, nil) - app := newGatingModeTestApp(provider) - redirectURI := "http://127.0.0.1:3334/callback" - clientID := registerOAuthBrowserClient(t, app, redirectURI) - - get := func(query string) *httptest.ResponseRecorder { - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/authorize?"+query, nil) - rr := httptest.NewRecorder() - app.handleOAuthAuthorize(rr, req) - return rr - } - - t.Run("missing_client_id", func(t *testing.T) { - t.Parallel() - rr := get("redirect_uri=" + url.QueryEscape(redirectURI) + "&response_type=code&code_challenge=abc&code_challenge_method=S256") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("missing_redirect_uri", func(t *testing.T) { - t.Parallel() - rr := get("client_id=" + url.QueryEscape(clientID) + "&response_type=code&code_challenge=abc&code_challenge_method=S256") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("redirect_uri_mismatch", func(t *testing.T) { - t.Parallel() - rr := get("client_id=" + url.QueryEscape(clientID) + "&redirect_uri=" + url.QueryEscape("https://evil.com/cb") + "&response_type=code&code_challenge=abc&code_challenge_method=S256") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("missing_pkce_challenge", func(t *testing.T) { - t.Parallel() - rr := get("client_id=" + url.QueryEscape(clientID) + "&redirect_uri=" + url.QueryEscape(redirectURI) + "&response_type=code&code_challenge_method=S256") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("wrong_pkce_method", func(t *testing.T) { - t.Parallel() - rr := get("client_id=" + url.QueryEscape(clientID) + "&redirect_uri=" + url.QueryEscape(redirectURI) + "&response_type=code&code_challenge=abc&code_challenge_method=plain") - require.Equal(t, http.StatusBadRequest, rr.Code) - }) -} - -func TestOAuthCallbackNegative(t *testing.T) { +func TestOAuthMCPAuthInjector(t *testing.T) { t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - app := newGatingModeTestApp(provider) - - t.Run("missing_state", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=some-code", nil) - rr := httptest.NewRecorder() - app.handleOAuthCallback(rr, req) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - t.Run("missing_code", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?state=some-state", nil) - rr := httptest.NewRecorder() - app.handleOAuthCallback(rr, req) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("unknown_pending_state", func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=some-code&state=random-unknown-state", nil) - rr := httptest.NewRecorder() - app.handleOAuthCallback(rr, req) - require.Equal(t, http.StatusBadRequest, rr.Code) - }) - - t.Run("upstream_token_endpoint_500", func(t *testing.T) { - t.Parallel() - // Create a mock server that returns 500 from its token endpoint - errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/token" { - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } - http.NotFound(w, r) - })) - defer errorServer.Close() - - errorApp := &application{ - config: config.Config{ - Server: config.ServerConfig{ - OAuth: config.OAuthConfig{ - Enabled: true, - Mode: "gating", - Issuer: errorServer.URL, - TokenURL: errorServer.URL + "/token", - AuthURL: errorServer.URL + "/authorize", - ClientID: "upstream-client-id", - ClientSecret: "upstream-client-secret", - Scopes: []string{"openid", "email"}, - SigningSecret: "test-gating-secret-32-byte-key!!", - AccessTokenTTLSeconds: 300, - RefreshTokenTTLSeconds: 86400, - }, + app := &application{ + config: config.Config{ + Server: config.ServerConfig{ + JWE: config.JWEConfig{ + Enabled: true, + JWESecretKey: "this-is-a-32-byte-secret-key!!", + JWTSecretKey: "jwt-secret", + }, + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "gating", + Issuer: "https://accounts.example.com", + PublicAuthServerURL: "https://mcp.example.com", + Audience: "https://mcp.example.com", + SigningSecret: "test-gating-secret-32-byte-key!!", }, }, - } - errorApp.mcpServer = altinitymcp.NewClickHouseMCPServer(errorApp.config, "test") + }, + mcpServer: altinitymcp.NewClickHouseMCPServer(config.Config{Server: config.ServerConfig{JWE: config.JWEConfig{Enabled: true, JWESecretKey: "this-is-a-32-byte-secret-key!!", JWTSecretKey: "jwt-secret"}, OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "gating", + Issuer: "https://accounts.example.com", + PublicAuthServerURL: "https://mcp.example.com", + Audience: "https://mcp.example.com", + SigningSecret: "test-gating-secret-32-byte-key!!", + }}}, "test"), + } - redirectURI := "http://127.0.0.1:3334/callback" - clientID := registerOAuthBrowserClient(t, errorApp, redirectURI) - state := startOAuthBrowserLogin(t, errorApp, clientID, redirectURI, "s", "verifier") + jweToken, err := jwe_auth.GenerateJWEToken(map[string]interface{}{"host": "localhost", "port": 8123, "exp": time.Now().Add(time.Hour).Unix()}, []byte("this-is-a-32-byte-secret-key!!"), []byte("jwt-secret")) + require.NoError(t, err) + jweTokenWithCredentials, err := jwe_auth.GenerateJWEToken(map[string]interface{}{ + "host": "localhost", + "port": 8123, + "username": "default", + "password": "secret", + "exp": time.Now().Add(time.Hour).Unix(), + }, []byte("this-is-a-32-byte-secret-key!!"), []byte("jwt-secret")) + require.NoError(t, err) - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) + t.Run("missing_oauth_gets_challenge", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/"+jweToken, nil) + req.SetPathValue("token", jweToken) rr := httptest.NewRecorder() - errorApp.handleOAuthCallback(rr, req) - require.Equal(t, http.StatusBadGateway, rr.Code) + handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusUnauthorized, rr.Code) + require.Contains(t, rr.Header().Get("WWW-Authenticate"), "resource_metadata=") + require.Contains(t, rr.Header().Get("WWW-Authenticate"), "error=\"invalid_token\"") }) - t.Run("upstream_returns_empty_tokens", func(t *testing.T) { + t.Run("jwe_with_credentials_skips_oauth", func(t *testing.T) { t.Parallel() - emptyProvider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "token_type": "Bearer", - "expires_in": 1800, - }, nil) - emptyApp := newGatingModeTestApp(emptyProvider) - - redirectURI := "http://127.0.0.1:3334/callback" - clientID := registerOAuthBrowserClient(t, emptyApp, redirectURI) - state := startOAuthBrowserLogin(t, emptyApp, clientID, redirectURI, "s", "verifier") - - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/"+jweTokenWithCredentials, nil) + req.SetPathValue("token", jweTokenWithCredentials) rr := httptest.NewRecorder() - emptyApp.handleOAuthCallback(rr, req) - require.Equal(t, http.StatusBadGateway, rr.Code) + called := false + handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + require.Equal(t, jweTokenWithCredentials, r.Context().Value(altinitymcp.JWETokenKey)) + require.Nil(t, r.Context().Value(altinitymcp.OAuthTokenKey)) + w.WriteHeader(http.StatusOK) + })) + handler.ServeHTTP(rr, req) + require.True(t, called) + require.Equal(t, http.StatusOK, rr.Code) }) } -func TestOAuthTokenExchangeNegative(t *testing.T) { +func TestOAuthMCPAuthInjectorForwardModePassesOpaqueBearerToken(t *testing.T) { t.Parallel() - provider := newTestForwardModeOIDCProvider(t, map[string]interface{}{ - "access_token": "upstream-access-token", - "token_type": "Bearer", - "expires_in": 1800, - "scope": "openid email", - }, nil) - provider.tokenResponse["id_token"] = provider.issueIDToken(t, map[string]interface{}{ - "sub": "user-1", - "iss": provider.server.URL, - "aud": "upstream-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - "email": "user@example.com", - "email_verified": true, - }) - app := newGatingModeTestApp(provider) - redirectURI := "http://127.0.0.1:3334/callback" - clientID := registerOAuthBrowserClient(t, app, redirectURI) - - postToken := func(form url.Values) *httptest.ResponseRecorder { - req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rr := httptest.NewRecorder() - app.handleOAuthToken(rr, req) - return rr + token := "opaque-access-token" + app := &application{ + config: config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + }, + }, + }, + mcpServer: altinitymcp.NewClickHouseMCPServer(config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + }, + }, + }, "test"), } - t.Run("unknown_auth_code", func(t *testing.T) { - t.Parallel() - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", clientID) - form.Set("code", "random-unknown-code") - form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", "test-verifier") - rr := postToken(form) - require.Equal(t, http.StatusBadRequest, rr.Code) - require.Contains(t, rr.Body.String(), "invalid_grant") - }) - - t.Run("redirect_uri_mismatch", func(t *testing.T) { - t.Parallel() - codeVerifier := "test-code-verifier-neg" - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, "s", codeVerifier) + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + called := false - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - loc, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) + handler := app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + require.Equal(t, token, r.Context().Value(altinitymcp.OAuthTokenKey)) + require.Nil(t, r.Context().Value(altinitymcp.OAuthClaimsKey)) + w.WriteHeader(http.StatusOK) + })) - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", clientID) - form.Set("code", loc.Query().Get("code")) - form.Set("redirect_uri", "https://wrong.example.com/cb") - form.Set("code_verifier", codeVerifier) - rr := postToken(form) - require.Equal(t, http.StatusBadRequest, rr.Code) - require.Contains(t, rr.Body.String(), "invalid_grant") - }) + handler.ServeHTTP(rr, req) + require.True(t, called) + require.Equal(t, http.StatusOK, rr.Code) +} - t.Run("wrong_pkce_verifier", func(t *testing.T) { - t.Parallel() - codeVerifier := "correct-code-verifier" - state := startOAuthBrowserLogin(t, app, clientID, redirectURI, "s", codeVerifier) +// TestOAuthMCPAuthInjectorForwardModeValidatesJWT is the integration check +// for the C-1 fix: forward mode used to skip ValidateOAuthToken entirely, +// so any string in `Authorization: Bearer …` reached the inner handler +// and was forwarded to ClickHouse. After C-1 the auth layer validates JWT +// bearers when Issuer/JWKSURL is configured and rejects bad ones at 401. +func exchangeOAuthBrowserCode(t *testing.T, app *application, clientID, code, redirectURI, codeVerifier string) *httptest.ResponseRecorder { + t.Helper() - callbackReq := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/oauth/callback?code=upstream-auth-code&state="+url.QueryEscape(state), nil) - callbackRR := httptest.NewRecorder() - app.handleOAuthCallback(callbackRR, callbackReq) - require.Equal(t, http.StatusFound, callbackRR.Code) - loc, err := url.Parse(callbackRR.Header().Get("Location")) - require.NoError(t, err) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", clientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + form.Set("code_verifier", codeVerifier) - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", clientID) - form.Set("code", loc.Query().Get("code")) - form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", "wrong-verifier") - rr := postToken(form) - require.Equal(t, http.StatusBadRequest, rr.Code) - require.Contains(t, rr.Body.String(), "invalid_grant") - }) + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + app.handleOAuthToken(rr, req) + return rr +} - t.Run("unsupported_grant_type", func(t *testing.T) { - t.Parallel() - form := url.Values{} - form.Set("grant_type", "client_credentials") - form.Set("client_id", clientID) - rr := postToken(form) - require.Equal(t, http.StatusBadRequest, rr.Code) - require.Contains(t, rr.Body.String(), "unsupported_grant_type") - }) +// TestOAuthForwardModeTokenResourceMismatch pins the RFC 8707 §2.2 enforcement +// in forward mode: a /token (auth-code grant) request whose `resource` differs +// from the one already pinned at /authorize must be rejected with +// invalid_target, regardless of which mode we're running in. +func generateOAuthTokenForApp(claims map[string]interface{}) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + hashedSecret := jwe_auth.HashSHA256([]byte("test-gating-secret-32-byte-key!!")) + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: hashedSecret}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + return "", err + } + object, err := signer.Sign(payload) + if err != nil { + return "", err + } + return object.CompactSerialize() } -func TestOAuthMetadataAdvertisesRefreshToken(t *testing.T) { +func TestCanonicalResourceURL(t *testing.T) { t.Parallel() - provider := newTestForwardModeOIDCProvider(t, nil, nil) - app := newGatingModeTestApp(provider) - - for _, path := range []string{ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", - } { - t.Run(path, func(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com"+path, nil) - rr := httptest.NewRecorder() - if strings.Contains(path, "openid") { - app.handleOAuthOpenIDConfiguration(rr, req) - } else { - app.handleOAuthAuthorizationServerMetadata(rr, req) - } - require.Equal(t, http.StatusOK, rr.Code) + cases := []struct{ in, want string }{ + {"", ""}, + {" ", ""}, + {"https://mcp.example.com", "https://mcp.example.com/"}, + {"https://mcp.example.com/", "https://mcp.example.com/"}, + {"https://mcp.example.com//", "https://mcp.example.com/"}, + {" https://mcp.example.com ", "https://mcp.example.com/"}, + {"https://mcp.example.com/path", "https://mcp.example.com/path/"}, + {"https://mcp.example.com/path/", "https://mcp.example.com/path/"}, + } + for _, c := range cases { + require.Equal(t, c.want, canonicalResourceURL(c.in), "input=%q", c.in) + } +} - var meta map[string]interface{} - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &meta)) - grants, ok := meta["grant_types_supported"].([]interface{}) - require.True(t, ok) - var grantStrings []string - for _, g := range grants { - grantStrings = append(grantStrings, g.(string)) - } - require.Contains(t, grantStrings, "refresh_token") - }) +// newJWEStateTestApp builds a minimal application wired with a SigningSecret +// for exercising the stateless JWE encode/decode helpers in isolation. +func newJWEStateTestApp(secret string) *application { + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + SigningSecret: secret, + }, + }, } + return &application{config: cfg} +} + +// newGatingModeTestApp creates an application configured for gating mode OAuth. +// doGatingAuthCodeFlow runs the full authorize→callback→token exchange and +// returns the parsed token response. +func exchangeRefreshToken(t *testing.T, app *application, clientID, refreshToken string) *httptest.ResponseRecorder { + t.Helper() + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", clientID) + form.Set("refresh_token", refreshToken) + + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + app.handleOAuthToken(rr, req) + return rr } +// newForwardModeRefreshTestApp configures a forward-mode app with +// UpstreamOfflineAccess enabled, so the auth-code response carries a JWE +// refresh_token wrapping the upstream IdP's refresh token. func TestNormalizedPath(t *testing.T) { t.Parallel() tests := []struct { @@ -2125,104 +414,6 @@ func TestDecodeStringSlice(t *testing.T) { }) } -func TestAuthenticateClientSecret(t *testing.T) { - t.Parallel() - - t.Run("public_client_legacy_no_secret_required", func(t *testing.T) { - t.Parallel() - // Backward compat: client_id JWEs issued before this change have no - // client_secret claim; they continue to work with PKCE only. - client := &statelessRegisteredClient{} - req := httptest.NewRequest(http.MethodPost, "/oauth/token", nil) - require.NoError(t, req.ParseForm()) - require.NoError(t, authenticateClientSecret(client, req)) - }) - - t.Run("confidential_client_secret_via_form", func(t *testing.T) { - t.Parallel() - client := &statelessRegisteredClient{ClientSecret: "abc123"} - req := httptest.NewRequest(http.MethodPost, "/oauth/token", - strings.NewReader("client_secret=abc123")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, req.ParseForm()) - require.NoError(t, authenticateClientSecret(client, req)) - }) - - t.Run("confidential_client_secret_via_basic_auth", func(t *testing.T) { - t.Parallel() - client := &statelessRegisteredClient{ClientSecret: "abc123"} - req := httptest.NewRequest(http.MethodPost, "/oauth/token", nil) - req.SetBasicAuth("client-id-doesnt-matter", "abc123") - require.NoError(t, req.ParseForm()) - require.NoError(t, authenticateClientSecret(client, req)) - }) - - t.Run("confidential_client_secret_missing", func(t *testing.T) { - t.Parallel() - client := &statelessRegisteredClient{ClientSecret: "abc123"} - req := httptest.NewRequest(http.MethodPost, "/oauth/token", nil) - require.NoError(t, req.ParseForm()) - require.Error(t, authenticateClientSecret(client, req)) - }) - - t.Run("confidential_client_secret_mismatch", func(t *testing.T) { - t.Parallel() - client := &statelessRegisteredClient{ClientSecret: "abc123"} - req := httptest.NewRequest(http.MethodPost, "/oauth/token", - strings.NewReader("client_secret=wrong")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, req.ParseForm()) - require.Error(t, authenticateClientSecret(client, req)) - }) -} - -func TestParseStatelessRegisteredClient(t *testing.T) { - t.Parallel() - t.Run("all_fields", func(t *testing.T) { - t.Parallel() - claims := map[string]interface{}{ - "redirect_uris": []interface{}{"https://example.com/callback"}, - "token_endpoint_auth_method": "client_secret_post", - "grant_type": "authorization_code", - "exp": float64(time.Now().Add(time.Hour).Unix()), - } - client, err := parseStatelessRegisteredClient(claims) - require.NoError(t, err) - require.Equal(t, []string{"https://example.com/callback"}, client.RedirectURIs) - require.Equal(t, "client_secret_post", client.TokenEndpointAuthMethod) - require.Equal(t, "authorization_code", client.GrantType) - }) - - t.Run("defaults_applied", func(t *testing.T) { - t.Parallel() - claims := map[string]interface{}{ - "redirect_uris": []interface{}{"https://example.com/callback"}, - } - client, err := parseStatelessRegisteredClient(claims) - require.NoError(t, err) - require.Equal(t, "none", client.TokenEndpointAuthMethod) - require.Equal(t, "authorization_code", client.GrantType) - }) - - t.Run("missing_redirect_uris", func(t *testing.T) { - t.Parallel() - claims := map[string]interface{}{} - _, err := parseStatelessRegisteredClient(claims) - require.Error(t, err) - require.Contains(t, err.Error(), "missing redirect URIs") - }) - - t.Run("empty_redirect_uris", func(t *testing.T) { - t.Parallel() - claims := map[string]interface{}{ - "redirect_uris": []interface{}{}, - } - _, err := parseStatelessRegisteredClient(claims) - require.Error(t, err) - require.Contains(t, err.Error(), "missing redirect URIs") - }) -} - func TestOAuthClaimsFromUserInfo(t *testing.T) { t.Parallel() t.Run("all_standard_fields", func(t *testing.T) { diff --git a/go.mod b/go.mod index 3dba1ac..2414d2d 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/rs/zerolog v1.35.1 github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v3 v3.8.0 - golang.org/x/crypto v0.48.0 + golang.org/x/crypto v0.51.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -37,6 +37,8 @@ require ( go.opentelemetry.io/otel v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/net v0.54.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect - golang.org/x/sys v0.43.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index 91825b8..fe57bb8 100644 --- a/go.sum +++ b/go.sum @@ -70,10 +70,18 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/pkg/jwe_auth/jwe_auth.go b/pkg/jwe_auth/jwe_auth.go index e9eb7f7..f11b033 100644 --- a/pkg/jwe_auth/jwe_auth.go +++ b/pkg/jwe_auth/jwe_auth.go @@ -266,6 +266,7 @@ func validateClaimsWhitelist(claims map[string]interface{}) error { "upstream_refresh_token": true, "upstream_token_type": true, "upstream_pkce_verifier": true, + "upstream_auth_code": true, "resource": true, "access_token_exp": true, "email": true, From 03b19f6323655821bf6c136f76ee31f3dde5a17d Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 15 May 2026 13:22:28 +0200 Subject: [PATCH 2/4] =?UTF-8?q?oauth:=20refactor=20CIMD=20inbound=20?= =?UTF-8?q?=E2=80=94=20adopt=20oauthex=20types,=20drop=20duplication?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review follow-up to #116. No behaviour change; CIMD wire shape, JWE constants, HA replay semantics all unchanged. E2E reverified on otel-google-mcp with btyshkevich@gmail.com. - Adopt go-sdk's oauthex.ClientRegistrationResponse in parseCIMDMetadata instead of an inline anonymous struct. CIMD docs are the same JSON shape as RFC 7591 client registration; the SDK already types them with field- level documentation. Extra SDK fields (logo_uri, tos_uri, etc.) are ignored, same as before. Saves ~10 LOC; nothing on the wire changes. - Move resolveCIMDClient from package-level var-with-test-stub-hatch to a method on *application backed by a cimdResolver field. Drops the sync.Once singleton and the test-time var swap. Tests that need a fake resolver construct one via the existing testResolver helper and inject it through the application struct (oauth_ha_replay_test.go). - Trim three dead fields from statelessRegisteredClient (GrantType, ExpiresAt, ClientSecret) — all DCR-era; nothing reads them post-DCR. Now a two-field struct. - Simplify validateCIMDPath: collapse three overlapping dot-segment checks (raw segment / decoded segment / %2E expansion) into a single path.Clean(decoded) != decoded comparison, the exact formulation the issue calls for. Encoded slash / backslash check kept separate (path.Clean can't see them). Same reject-set, fewer code paths. - DRY the two .well-known handlers behind one oauthASMetadata helper; handleOAuthOpenIDConfiguration only tacks on its OIDC-specific extras. Eliminates field-by-field drift between RFC 8414 and OIDC discovery documents. - Drop the redundant isCIMDClientID prefix check at /authorize and /token. validateCIMDClientIDURL already enforces scheme==https inside the resolver, so the standalone helper was dead weight. - Replace the parallel ip.IsPrivate() + manual range arithmetic in isBlockedIP with a single audit-friendly []net.IPNet (ssrfBlockedCIDRs) parsed via net.ParseCIDR. One list to read, one list to maintain; same blocklist semantics. Adds an explicit comment naming the RFCs. - Small: apply truncateForLog(clientID, 80) to the three client_id log fields under /authorize, /callback, /token; drop _=resource in handleOAuthTokenAuthCode (replaced by a focused RFC 8707 mismatch check); thread the cache's logical clock through cimdCache.put for test consistency. Verification: go test ./... and go vet ./... green. Local image ghcr.io/altinity/altinity-mcp:cimd-115-refactor-dac3961-arm64 deployed to otel-google-mcp (forward mode); claude.ai connector cimd-rfx-703658 → whoami=btyshkevich@gmail.com, execute_query SELECT currentUser(), 1+1 → u=btyshkevich@gmail.com, two=2. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/altinity-mcp/cimd.go | 154 +++++++++++------------ cmd/altinity-mcp/main.go | 4 + cmd/altinity-mcp/oauth_ha_replay_test.go | 34 ++--- cmd/altinity-mcp/oauth_server.go | 106 +++++++--------- 4 files changed, 139 insertions(+), 159 deletions(-) diff --git a/cmd/altinity-mcp/cimd.go b/cmd/altinity-mcp/cimd.go index 873b475..52a5dd5 100644 --- a/cmd/altinity-mcp/cimd.go +++ b/cmd/altinity-mcp/cimd.go @@ -10,11 +10,13 @@ import ( "net" "net/http" "net/url" + "path" "strconv" "strings" "sync" "time" + "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/net/idna" ) @@ -47,12 +49,6 @@ var ( errCIMDInvalidMetadata = errors.New("cimd: invalid metadata document") ) -// isCIMDClientID reports whether s should be resolved as a CIMD client_id URL. -// Anything else is rejected as an unknown client. -func isCIMDClientID(s string) bool { - return strings.HasPrefix(s, "https://") -} - // validateCIMDClientIDURL parses and validates a CIMD client_id URL against the // strict rules from issue #115 § "CIMD client identifier URL validation". // Returns the parsed URL on success; on failure the error wraps errCIMDInvalidURL. @@ -100,56 +96,71 @@ func validateCIMDClientIDURL(raw string) (*url.URL, error) { } // validateCIMDPath rejects dot-segments (raw or percent-encoded), encoded -// slashes, and encoded backslashes. Operates on the raw escaped path so we -// inspect what was actually requested rather than a normalized form. +// slashes, and encoded backslashes. The dot-segment test uses the issue #115 +// formulation: "Reject paths where applying standard dot-segment removal +// would change the path." `path.Clean` does exactly that. Encoded slashes and +// backslashes are checked separately because path.Clean can't see them — they +// would change segment boundaries after decoding, which is the attack we're +// blocking. func validateCIMDPath(rawPath string) error { if !strings.HasPrefix(rawPath, "/") { return fmt.Errorf("%w: path must start with /", errCIMDInvalidURL) } - for _, raw := range strings.Split(rawPath[1:], "/") { - decoded, err := url.PathUnescape(raw) + upper := strings.ToUpper(rawPath) + if strings.Contains(upper, "%2F") || strings.Contains(upper, "%5C") { + return fmt.Errorf("%w: encoded slash or backslash in path", errCIMDInvalidURL) + } + decoded, err := url.PathUnescape(rawPath) + if err != nil { + return fmt.Errorf("%w: invalid percent-encoding in path", errCIMDInvalidURL) + } + if path.Clean(decoded) != decoded { + return fmt.Errorf("%w: dot-segment in path", errCIMDInvalidURL) + } + return nil +} + +// ssrfBlockedCIDRs is the single audit-friendly list of address ranges we +// refuse to dial during CIMD metadata fetch. Covers IANA special-use IPv4 +// (RFC 6890) + IPv6 (RFC 6890 / RFC 4291): loopback, RFC 1918 private, +// link-local, CGNAT, "this network", reserved 192.0.0.0/24, multicast, IPv6 +// loopback / link-local / unique-local / multicast. +var ssrfBlockedCIDRs = mustParseCIDRs( + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "169.254.0.0/16", + "100.64.0.0/10", + "0.0.0.0/8", + "192.0.0.0/24", + "224.0.0.0/4", + "::1/128", + "fe80::/10", + "fc00::/7", + "ff00::/8", +) + +func mustParseCIDRs(cidrs ...string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(cidrs)) + for _, c := range cidrs { + _, n, err := net.ParseCIDR(c) if err != nil { - return fmt.Errorf("%w: invalid percent-encoding in path segment", errCIMDInvalidURL) - } - if raw == "." || raw == ".." || decoded == "." || decoded == ".." { - return fmt.Errorf("%w: dot-segment in path", errCIMDInvalidURL) - } - upper := strings.ToUpper(raw) - if strings.Contains(upper, "%2F") || strings.Contains(upper, "%5C") { - return fmt.Errorf("%w: encoded slash or backslash in path", errCIMDInvalidURL) - } - // Catch %2E variants explicitly so url.PathUnescape's decoded form is - // not the only signal. - if strings.Contains(upper, "%2E") { - noEnc := strings.ReplaceAll(strings.ReplaceAll(upper, "%2E", "."), "%2e", ".") - if noEnc == "." || noEnc == ".." { - return fmt.Errorf("%w: encoded dot-segment in path", errCIMDInvalidURL) - } + panic(fmt.Sprintf("cimd: bad SSRF CIDR %q: %v", c, err)) } + out = append(out, n) } - return nil + return out } // isBlockedIP reports whether ip falls in a special-use range we must refuse -// to dial during CIMD metadata fetch. The blocklist covers RFC1918, loopback, -// link-local, multicast, unspecified, IPv6 ULA/loopback/link-local, CGNAT, and -// 0.0.0.0/8. +// to dial during CIMD metadata fetch. func isBlockedIP(ip net.IP) bool { - if ip == nil { - return true - } - if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || - ip.IsMulticast() || ip.IsUnspecified() || ip.IsPrivate() { + if ip == nil || ip.IsUnspecified() { return true } - if ip4 := ip.To4(); ip4 != nil { - if ip4[0] == 0 { // 0.0.0.0/8 - return true - } - if ip4[0] == 100 && ip4[1]&0xc0 == 64 { // CGNAT 100.64/10 - return true - } - if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 { // 192.0.0.0/24 + for _, n := range ssrfBlockedCIDRs { + if n.Contains(ip) { return true } } @@ -167,20 +178,6 @@ type cimdResolver struct { now func() time.Time } -var ( - defaultCIMDResolverOnce sync.Once - defaultCIMDResolver *cimdResolver -) - -// cimdDefaultResolver returns the package-level singleton resolver used by -// production handlers. Initialised lazily on first use. -func cimdDefaultResolver() *cimdResolver { - defaultCIMDResolverOnce.Do(func() { - defaultCIMDResolver = newCIMDResolver(nil) - }) - return defaultCIMDResolver -} - // newCIMDResolver constructs a resolver with an SSRF-safe http.Client. If // resolveIP is nil it uses net.DefaultResolver. func newCIMDResolver(resolveIP func(ctx context.Context, host string) ([]net.IP, error)) *cimdResolver { @@ -249,11 +246,11 @@ func (r *cimdResolver) ssrfSafeDial(ctx context.Context, network, addr string) ( return conn, nil } -// resolveCIMDClient is the package-level entry point used by oauth_server.go. -// Defined as a var so tests can swap in an in-process resolver pointed at an -// httptest.Server without doing real DNS. -var resolveCIMDClient = func(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { - return cimdDefaultResolver().resolve(ctx, clientIDURL) +// resolveCIMDClient is the entry point used by handlers. It delegates to the +// resolver owned by the application; tests construct the application with a +// resolver pointed at an in-process httptest.Server (see cimd_test.go). +func (a *application) resolveCIMDClient(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { + return a.cimdResolver.resolve(ctx, clientIDURL) } func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { @@ -269,11 +266,11 @@ func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statel client, ttl, err := r.fetchAndValidate(ctx, clientIDURL) now := r.now() if err != nil { - r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: now.Add(cimdNegativeCacheTTL)}) + r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: now.Add(cimdNegativeCacheTTL)}, now) return nil, err } if ttl > 0 { - r.cache.put(clientIDURL, &cimdCacheEntry{client: client, expiresAt: now.Add(ttl)}) + r.cache.put(clientIDURL, &cimdCacheEntry{client: client, expiresAt: now.Add(ttl)}, now) } return client, nil } @@ -350,17 +347,14 @@ func extractMaxAge(cc string) time.Duration { // parseCIMDMetadata decodes the document and applies the schema rules from // issue #115 §"Metadata schema validation". Treats the body as untrusted. +// +// The wire shape of a CIMD document matches RFC 7591 §3.2.1 client registration +// response — same field names, types, and JSON tags — so we reuse the SDK's +// `oauthex.ClientRegistrationResponse` rather than maintaining a parallel +// struct. Extra fields the SDK knows about (logo_uri, tos_uri, jwks, etc.) are +// safely ignored because we don't read them. func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredClient, error) { - var doc struct { - ClientID string `json:"client_id"` - ClientName string `json:"client_name"` - RedirectURIs []string `json:"redirect_uris"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - GrantTypes []string `json:"grant_types"` - ResponseTypes []string `json:"response_types"` - ClientSecret string `json:"client_secret"` - ClientSecretExpiresAt *int64 `json:"client_secret_expires_at,omitempty"` - } + var doc oauthex.ClientRegistrationResponse dec := json.NewDecoder(strings.NewReader(string(body))) dec.UseNumber() if err := dec.Decode(&doc); err != nil { @@ -375,7 +369,7 @@ func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredCli if doc.ClientName == "" || len(doc.ClientName) > cimdMaxClientNameLength { return nil, fmt.Errorf("%w: client_name length out of range", errCIMDInvalidMetadata) } - if doc.ClientSecret != "" || doc.ClientSecretExpiresAt != nil { + if doc.ClientSecret != "" || !doc.ClientSecretExpiresAt.IsZero() { return nil, fmt.Errorf("%w: client_secret not allowed for CIMD public client", errCIMDInvalidMetadata) } if doc.TokenEndpointAuthMethod != "none" { @@ -429,7 +423,6 @@ func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredCli return &statelessRegisteredClient{ RedirectURIs: doc.RedirectURIs, TokenEndpointAuthMethod: "none", - GrantType: "authorization_code", }, nil } @@ -488,12 +481,15 @@ func (c *cimdCache) get(key string, now time.Time) (*cimdCacheEntry, bool) { } // put inserts/updates a cache entry. Negative entries do NOT override an -// existing unexpired positive entry (per issue #115 cache requirements). -func (c *cimdCache) put(key string, e *cimdCacheEntry) { +// existing unexpired positive entry (per issue #115 cache requirements). The +// now argument is the cache's logical clock — the caller passes the same +// value it uses for cache.get expiry, so put/get stay coherent under tests +// that fix time. +func (c *cimdCache) put(key string, e *cimdCacheEntry, now time.Time) { c.mu.Lock() defer c.mu.Unlock() if e.err != nil { - if existing, ok := c.entries[key]; ok && existing.err == nil && existing.expiresAt.After(time.Now()) { + if existing, ok := c.entries[key]; ok && existing.err == nil && existing.expiresAt.After(now) { return } } diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 577ef32..3fcdd27 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -979,6 +979,9 @@ type application struct { configFile string configMutex sync.RWMutex stopConfigReload chan struct{} + // cimdResolver fetches and caches inbound CIMD client metadata documents. + // Constructed in newApplication; tests inject an alternative resolver. + cimdResolver *cimdResolver } // setHTTPServer sets the HTTP server with proper synchronization @@ -1055,6 +1058,7 @@ func newApplication(ctx context.Context, cfg config.Config, cmd CommandInterface mcpServer: mcpServer, configFile: cmd.String("config"), stopConfigReload: make(chan struct{}), + cimdResolver: newCIMDResolver(nil), } // Start config reload goroutine if enabled diff --git a/cmd/altinity-mcp/oauth_ha_replay_test.go b/cmd/altinity-mcp/oauth_ha_replay_test.go index 7236268..1cacbc5 100644 --- a/cmd/altinity-mcp/oauth_ha_replay_test.go +++ b/cmd/altinity-mcp/oauth_ha_replay_test.go @@ -1,8 +1,8 @@ package main import ( - "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -69,20 +69,19 @@ func TestHAReplay_UpstreamInvalidGrantOnReplay(t *testing.T) { })) defer upstream.Close() - // Stub CIMD resolver: skip real network and return a client allowing the - // downstream redirect URI. - origResolver := resolveCIMDClient - t.Cleanup(func() { resolveCIMDClient = origResolver }) - resolveCIMDClient = func(_ context.Context, raw string) (*statelessRegisteredClient, error) { - if raw != downstreamClient { - return nil, errCIMDInvalidURL - } - return &statelessRegisteredClient{ - RedirectURIs: []string{downstreamRedir}, - TokenEndpointAuthMethod: "none", - GrantType: "authorization_code", - }, nil - } + // Spin up a TLS httptest server serving the CIMD metadata document for the + // downstream client_id URL. We point the resolver's transport at it via + // testResolver (which keeps the rest of the parse/cache/SSRF logic alive). + cimdServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{ + "client_id": %q, + "client_name": "Demo", + "redirect_uris": [%q], + "token_endpoint_auth_method": "none" + }`, downstreamClient, downstreamRedir) + })) + defer cimdServer.Close() cfg := config.Config{ Server: config.ServerConfig{ @@ -104,8 +103,9 @@ func TestHAReplay_UpstreamInvalidGrantOnReplay(t *testing.T) { }, } app := &application{ - config: cfg, - mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + cimdResolver: testResolver(t, cimdServer), } // Build a valid downstream auth code JWE by exercising encodeAuthCode. diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index 3d1a162..d91e700 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -47,13 +47,6 @@ const ( type statelessRegisteredClient struct { RedirectURIs []string `json:"redirect_uris"` TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - GrantType string `json:"grant_type"` - ExpiresAt int64 `json:"exp"` - // ClientSecret is the per-registration secret issued during DCR for - // confidential clients (token_endpoint_auth_method: client_secret_post | - // client_secret_basic). When empty, the client is public (PKCE-only) — - // retained for backward compat with previously-issued client_ids. - ClientSecret string `json:"client_secret,omitempty"` } type oauthPendingAuth struct { @@ -971,30 +964,37 @@ func (a *application) handleOAuthProtectedResource(w http.ResponseWriter, r *htt _ = json.NewEncoder(w).Encode(resp) } +// oauthASMetadata returns the field set shared by RFC 8414 (oauth-authorization-server) +// and OIDC Discovery (openid-configuration). Both endpoints serve the same +// AS-side advertisement; OIDC adds two extra fields under gating mode (see +// handleOAuthOpenIDConfiguration). +// +// `issuer` is published without a trailing slash to match RFC 8414 §2 +// (issuer == authorization_servers[i] in the resource document). The /token +// response mints `iss` in the same form and validateOAuthClaims normalises +// slashes defensively. +func (a *application) oauthASMetadata(r *http.Request) map[string]interface{} { + baseURL := a.oauthAuthorizationServerBaseURL(r) + return map[string]interface{}{ + "issuer": strings.TrimRight(baseURL, "/"), + "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), + "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), + "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "code_challenge_methods_supported": []string{"S256"}, + "client_id_metadata_document_supported": true, + } +} + func (a *application) handleOAuthAuthorizationServerMetadata(w http.ResponseWriter, r *http.Request) { if !a.oauthEnabled() { http.NotFound(w, r) return } - baseURL := a.oauthAuthorizationServerBaseURL(r) - // `issuer` is published without a trailing slash to match the RFC 8414 §2 - // convention (issuer == authorization_servers[i] in the resource document). - // mintGatingTokenResponse mints `iss` in the same form, and - // validateOAuthClaims still normalises slashes defensively. - issuer := strings.TrimRight(baseURL, "/") - resp := map[string]interface{}{ - "issuer": issuer, - "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), - "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), - "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), - "response_types_supported": []string{"code"}, - "grant_types_supported": []string{"authorization_code"}, - "token_endpoint_auth_methods_supported": []string{"none"}, - "code_challenge_methods_supported": []string{"S256"}, - "client_id_metadata_document_supported": true, - } w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(a.oauthASMetadata(r)) } func (a *application) handleOAuthOpenIDConfiguration(w http.ResponseWriter, r *http.Request) { @@ -1002,19 +1002,7 @@ func (a *application) handleOAuthOpenIDConfiguration(w http.ResponseWriter, r *h http.NotFound(w, r) return } - baseURL := a.oauthAuthorizationServerBaseURL(r) - issuer := strings.TrimRight(baseURL, "/") - resp := map[string]interface{}{ - "issuer": issuer, - "authorization_endpoint": joinURLPath(baseURL, a.oauthAuthorizationPath()), - "token_endpoint": joinURLPath(baseURL, a.oauthTokenPath()), - "scopes_supported": oidcScopesForAdvertisement(a.GetCurrentConfig().Server.OAuth), - "response_types_supported": []string{"code"}, - "grant_types_supported": []string{"authorization_code"}, - "token_endpoint_auth_methods_supported": []string{"none"}, - "code_challenge_methods_supported": []string{"S256"}, - "client_id_metadata_document_supported": true, - } + resp := a.oauthASMetadata(r) if !a.oauthForwardMode() { resp["subject_types_supported"] = []string{"public"} resp["id_token_signing_alg_values_supported"] = []string{"HS256"} @@ -1040,17 +1028,13 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques return } // CIMD inbound (#115): client_id is the HTTPS URL of the MCP client's - // metadata document. resolveCIMDClient validates the URL, fetches the - // document under SSRF-safe constraints, and synthesises the registered - // client. DCR was removed in the same change; non-https client_ids are - // rejected as unknown. - if !isCIMDClientID(clientID) { - http.Error(w, "Unknown OAuth client", http.StatusBadRequest) - return - } - client, err := resolveCIMDClient(r.Context(), clientID) + // metadata document. The resolver validates the URL, fetches the document + // under SSRF-safe constraints, and synthesises the registered client. DCR + // was removed in the same change; non-https client_ids are rejected as + // invalid URLs by validateCIMDClientIDURL inside the resolver. + client, err := a.resolveCIMDClient(r.Context(), clientID) if err != nil { - log.Debug().Err(err).Str("client_id", clientID).Msg("OAuth /authorize rejected: CIMD resolution failed") + log.Debug().Err(err).Str("client_id", truncateForLog(clientID, 80)).Msg("OAuth /authorize rejected: CIMD resolution failed") http.Error(w, "Unknown OAuth client", http.StatusBadRequest) return } @@ -1174,7 +1158,7 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request } log.Info(). - Str("client_id", pending.ClientID). + Str("client_id", truncateForLog(pending.ClientID, 80)). Bool("forward_mode", a.oauthForwardMode()). Msg("OAuth /callback wrapped upstream auth code in downstream JWE; awaiting /token redemption") @@ -1225,19 +1209,15 @@ func (a *application) handleOAuthToken(w http.ResponseWriter, r *http.Request) { func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Request) { clientID := r.Form.Get("client_id") - if !isCIMDClientID(clientID) { - writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") - return - } // Public CIMD clients reject any client_secret / client_assertion on /token // per RFC 7591 token_endpoint_auth_method=none + CIMD spec. if r.Form.Get("client_secret") != "" || r.Form.Get("client_assertion") != "" { writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication not supported for public CIMD clients") return } - client, err := resolveCIMDClient(r.Context(), clientID) + client, err := a.resolveCIMDClient(r.Context(), clientID) if err != nil { - log.Debug().Err(err).Str("client_id", clientID).Msg("OAuth /token rejected: CIMD resolution failed") + log.Debug().Err(err).Str("client_id", truncateForLog(clientID, 80)).Msg("OAuth /token rejected: CIMD resolution failed") writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "unknown OAuth client") return } @@ -1272,16 +1252,16 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re // RFC 8707 §2.2: when `resource` was pinned at /authorize, /token must // match. When /authorize omitted it but /token includes one, accept and // use the latter for downstream advisory only. - resource := issued.Resource - if formResource := r.Form.Get("resource"); formResource != "" { - if resource == "" { - resource = formResource - } else if strings.TrimRight(formResource, "/") != strings.TrimRight(resource, "/") { + // RFC 8707 §2.2 cross-check between the resource pinned at /authorize and + // the one (optionally) re-sent at /token. Mismatch → invalid_target. v1 + // doesn't otherwise act on the resource value (audience binding is a + // separate issue). + if formResource := r.Form.Get("resource"); formResource != "" && issued.Resource != "" { + if strings.TrimRight(formResource, "/") != strings.TrimRight(issued.Resource, "/") { writeOAuthTokenError(w, http.StatusBadRequest, "invalid_target", "resource indicator does not match the one used at /authorize") return } } - _ = resource // HA replay model: redeem the upstream auth code with the upstream IdP // *now*, not at /callback. The upstream IdP's `invalid_grant` on a second @@ -1330,7 +1310,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re Int("status", upstreamResp.StatusCode). Str("upstream_error", errCode). Int("body_len", bodyLen). - Str("client_id", clientID). + Str("client_id", truncateForLog(clientID, 80)). Msg("OAuth /token: upstream code exchange rejected — likely replay") // Map upstream invalid_grant (replay-detected, expired, already used) // to a downstream invalid_grant per RFC 6749 §5.2. @@ -1360,7 +1340,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re Bool("forward_mode", a.oauthForwardMode()). Str("scope", tokenResp.Scope). Int64("expires_in", tokenResp.ExpiresIn). - Str("client_id", clientID). + Str("client_id", truncateForLog(clientID, 80)). Msg("OAuth /token: upstream code exchange succeeded") var identityClaims *altinitymcp.OAuthClaims From 6f8bbed2c2f82cbc3ead3c4399e40a380dcdf38f Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 15 May 2026 15:00:19 +0200 Subject: [PATCH 3/4] oauth/cimd: address self-review nits (max-age=0 bug + 8 follow-ups) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review notes returned nine concrete findings; tightening each. 1. **max-age=0 bug (real)**. extractMaxAge returned 0 for max-age=0, which was then treated as "directive absent" and fell through to the 5-minute default TTL — opposite of RFC 7234 semantics. Replaced with a proper directive-aware parser cacheTTLFromHeader: max-age=0 (and negative) correctly returns ttl=0, treated identically to no-store / no-cache. Regression test TestCIMDResolve_MaxAgeZeroSkipsCache plus a TestCacheTTLFromHeader matrix. 2. **Cache-Control parsing too loose**. strings.Contains on the lowercased header would have matched x-custom-no-storage as no-store. Same new parser does directive-level matching: it splits on ',', trims, and compares each directive exactly. Test rows in TestCacheTTLFromHeader cover x-custom-no-storage and "no-storage" alone — both fall through to the default TTL as expected. 3. **_ = identityClaims looked like dead code**. Replaced with a comment explaining that the validation has a 502-side-effect that's the whole point, and that audience binding is deferred to a follow-up — so the line must not be pruned without re-introducing claim binding. 4. **FIFO/LRU misnomer**. Cache evicts oldest-inserted and never reorders on get — that's FIFO. Comment + the cap field renamed accordingly; added a sentence on why FIFO is fine here (cap >> unique CIMD URLs). 5. **_ = host vestige** in testResolver removed. 6. **Post-dial recheck comment** clarified — explicit-IP dial means the recheck is defense against future refactors, not active rebinding protection in the current code path. 7. **cap → capacity rename** on cimdCache to stop shadowing the builtin. 8. **refresh_token tolerate-but-ignore** now has a comment in parseCIMDMetadata warning future refresh-token implementers that the CIMD grant_types array is NOT authoritative for what we issue — the .well-known AS metadata is. 9. **SSRF blocklist extended** to the IANA Special-Purpose Address registries (RFC 6890): added 192.0.2.0/24 (TEST-NET-1), 198.18.0.0/15 (benchmarking), 198.51.100.0/24 (TEST-NET-2), 203.0.113.0/24 (TEST-NET-3), 240.0.0.0/4 (reserved, includes 255.255.255.255), 2001:db8::/32 (IPv6 docs), 64:ff9b::/96 (IPv4/IPv6 translation), 100::/64 (IPv6 discard). Each entry annotated with its RFC. Tests in TestIsBlockedIP extended to cover all new ranges. Also: TestCIMDResolve_DefaultTTLWhenNoDirectives pins the 5-minute default-TTL path with a Cache-Control: private header (neither no-store nor max-age). All tests + go vet green. No behaviour change beyond the max-age=0 bug fix; no API or wire change. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/altinity-mcp/cimd.go | 148 ++++++++++++++++++++----------- cmd/altinity-mcp/cimd_test.go | 98 +++++++++++++++++++- cmd/altinity-mcp/oauth_server.go | 9 +- 3 files changed, 198 insertions(+), 57 deletions(-) diff --git a/cmd/altinity-mcp/cimd.go b/cmd/altinity-mcp/cimd.go index 52a5dd5..a77e3ee 100644 --- a/cmd/altinity-mcp/cimd.go +++ b/cmd/altinity-mcp/cimd.go @@ -121,24 +121,34 @@ func validateCIMDPath(rawPath string) error { } // ssrfBlockedCIDRs is the single audit-friendly list of address ranges we -// refuse to dial during CIMD metadata fetch. Covers IANA special-use IPv4 -// (RFC 6890) + IPv6 (RFC 6890 / RFC 4291): loopback, RFC 1918 private, -// link-local, CGNAT, "this network", reserved 192.0.0.0/24, multicast, IPv6 -// loopback / link-local / unique-local / multicast. +// refuse to dial during CIMD metadata fetch. Tracks the IANA special-purpose +// address registry (RFC 6890 + IPv6 registry RFC 8190). Comments give the +// RFC and human name for each entry so future audits can read it +// linearly. var ssrfBlockedCIDRs = mustParseCIDRs( - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "169.254.0.0/16", - "100.64.0.0/10", - "0.0.0.0/8", - "192.0.0.0/24", - "224.0.0.0/4", - "::1/128", - "fe80::/10", - "fc00::/7", - "ff00::/8", + // IPv4 — IANA IPv4 Special-Purpose Address Registry + "0.0.0.0/8", // RFC 1122 — "this network" + "10.0.0.0/8", // RFC 1918 — private + "100.64.0.0/10", // RFC 6598 — Carrier-Grade NAT + "127.0.0.0/8", // RFC 1122 — loopback + "169.254.0.0/16", // RFC 3927 — link-local + "172.16.0.0/12", // RFC 1918 — private + "192.0.0.0/24", // RFC 6890 — IETF protocol assignments + "192.0.2.0/24", // RFC 5737 — TEST-NET-1 (documentation) + "192.168.0.0/16", // RFC 1918 — private + "198.18.0.0/15", // RFC 2544 — benchmarking + "198.51.100.0/24", // RFC 5737 — TEST-NET-2 (documentation) + "203.0.113.0/24", // RFC 5737 — TEST-NET-3 (documentation) + "224.0.0.0/4", // RFC 5771 — multicast + "240.0.0.0/4", // RFC 1112 — reserved (includes 255.255.255.255 broadcast) + // IPv6 — IANA IPv6 Special-Purpose Address Registry + "::1/128", // RFC 4291 — loopback + "64:ff9b::/96", // RFC 6052 — IPv4/IPv6 translation + "100::/64", // RFC 6666 — discard prefix + "2001:db8::/32", // RFC 3849 — documentation + "fc00::/7", // RFC 4193 — unique local + "fe80::/10", // RFC 4291 — link-local + "ff00::/8", // RFC 4291 — multicast ) func mustParseCIDRs(cidrs ...string) []*net.IPNet { @@ -212,6 +222,12 @@ func newCIMDResolver(resolveIP func(ctx context.Context, host string) ([]net.IP, // ssrfSafeDial resolves the host explicitly, pins the dial to a validated IP, // and re-checks the connected remote address before returning. +// +// Why the post-dial check is essentially belt-and-suspenders here: we dial +// JoinHostPort(pinned.String(), port) — an explicit IP literal — so the +// resolver cannot rebind to a different address. The re-check survives only +// as defense against future refactors that swap the dial target back to the +// hostname (e.g. for SNI symmetry). Cheap; keep it. func (r *cimdResolver) ssrfSafeDial(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { @@ -309,40 +325,53 @@ func (r *cimdResolver) fetchAndValidate(ctx context.Context, clientIDURL string) if err != nil { return nil, 0, err } - ttl := cimdDefaultCacheTTL - if cc := resp.Header.Get("Cache-Control"); cc != "" { - lc := strings.ToLower(cc) - switch { - case strings.Contains(lc, "no-store"): - ttl = 0 - case strings.Contains(lc, "no-cache"): - ttl = 0 - default: - if ma := extractMaxAge(lc); ma > 0 { - if ma > cimdMaxCacheTTL { - ma = cimdMaxCacheTTL - } - ttl = ma - } - } - } - return client, ttl, nil + return client, cacheTTLFromHeader(resp.Header.Get("Cache-Control")), nil } -func extractMaxAge(cc string) time.Duration { - for _, p := range strings.Split(cc, ",") { - p = strings.TrimSpace(p) - if !strings.HasPrefix(p, "max-age=") { +// cacheTTLFromHeader maps the response's Cache-Control header to a positive +// cache TTL or zero (do-not-cache). Returns cimdDefaultCacheTTL when the +// header is absent or carries no relevant directives, capped at +// cimdMaxCacheTTL. +// +// Semantics per RFC 7234 §5.2: +// - no-store / no-cache → ttl = 0 (do not reuse from cache) +// - max-age=0 or negative → ttl = 0 (RFC 7234 treats negative as 0) +// - max-age=N → ttl = min(N, cap) +// - none of the above → ttl = default +// +// Directive matching is exact: a stray substring like "x-no-storage" does +// NOT trigger no-store. +func cacheTTLFromHeader(cc string) time.Duration { + if cc == "" { + return cimdDefaultCacheTTL + } + maxAge := time.Duration(-1) // sentinel: directive absent + for _, raw := range strings.Split(cc, ",") { + directive := strings.TrimSpace(strings.ToLower(raw)) + if directive == "" { continue } - v := strings.TrimPrefix(p, "max-age=") - n, err := strconv.Atoi(v) - if err != nil || n <= 0 { + switch { + case directive == "no-store" || directive == "no-cache": return 0 + case strings.HasPrefix(directive, "max-age="): + n, err := strconv.Atoi(strings.TrimPrefix(directive, "max-age=")) + if err != nil { + continue // malformed value; ignore directive + } + if n <= 0 { + return 0 // RFC 7234: max-age=0 (or negative) means uncached. + } + maxAge = time.Duration(n) * time.Second } - return time.Duration(n) * time.Second } - return 0 + if maxAge < 0 { + return cimdDefaultCacheTTL + } + if maxAge > cimdMaxCacheTTL { + return cimdMaxCacheTTL + } + return maxAge } // parseCIMDMetadata decodes the document and applies the schema rules from @@ -398,7 +427,14 @@ func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredCli case "authorization_code": hasAuthCode = true case "refresh_token": - // Tolerated in metadata; not honored — v1 issues no refresh tokens. + // Tolerated in metadata, deliberately NOT honored in v1: a client + // publishing ["authorization_code","refresh_token"] (which + // claude.ai does today) silently gets no refresh capability — + // .well-known/oauth-authorization-server only advertises + // authorization_code and /token returns unsupported_grant_type + // for refresh. If/when refresh ships, do NOT treat the CIMD + // grant_types array as authoritative for what we issue — the AS + // metadata is the source of truth. default: return nil, fmt.Errorf("%w: unsupported grant_type %q", errCIMDInvalidMetadata, gt) } @@ -452,18 +488,22 @@ type cimdCacheEntry struct { expiresAt time.Time } +// cimdCache is a bounded FIFO with TTL. Eviction order is insertion order: +// on overflow we drop the oldest-inserted entry. `get` does NOT promote, so +// this is FIFO, not LRU. The distinction doesn't matter at our scale (cap +// ≫ unique CIMD URLs in practice) and FIFO has a simpler invariant. type cimdCache struct { - mu sync.Mutex - entries map[string]*cimdCacheEntry - order []string - cap int + mu sync.Mutex + entries map[string]*cimdCacheEntry + order []string + capacity int } -func newCIMDCache(cap int) *cimdCache { - if cap <= 0 { - cap = 1 +func newCIMDCache(capacity int) *cimdCache { + if capacity <= 0 { + capacity = 1 } - return &cimdCache{entries: make(map[string]*cimdCacheEntry, cap), cap: cap} + return &cimdCache{entries: make(map[string]*cimdCacheEntry, capacity), capacity: capacity} } func (c *cimdCache) get(key string, now time.Time) (*cimdCacheEntry, bool) { @@ -494,7 +534,7 @@ func (c *cimdCache) put(key string, e *cimdCacheEntry, now time.Time) { } } if _, exists := c.entries[key]; !exists { - if len(c.entries) >= c.cap { + if len(c.entries) >= c.capacity { oldest := c.order[0] c.order = c.order[1:] delete(c.entries, oldest) diff --git a/cmd/altinity-mcp/cimd_test.go b/cmd/altinity-mcp/cimd_test.go index d378e04..8ad31d9 100644 --- a/cmd/altinity-mcp/cimd_test.go +++ b/cmd/altinity-mcp/cimd_test.go @@ -74,9 +74,20 @@ func TestValidateCIMDClientIDURL_OversizeRejected(t *testing.T) { func TestIsBlockedIP(t *testing.T) { blocked := []string{ + // Pre-existing coverage "127.0.0.1", "10.0.0.1", "192.168.1.1", "172.16.0.1", "169.254.169.254", "100.64.0.1", "0.0.0.0", "224.0.0.1", "::1", "fe80::1", "fc00::1", "192.0.0.1", + // Extended IANA special-purpose ranges added 2026-05-15. + "192.0.2.1", // TEST-NET-1 + "198.18.0.1", // benchmarking + "198.51.100.1", // TEST-NET-2 + "203.0.113.1", // TEST-NET-3 + "240.0.0.1", // reserved + "255.255.255.255", // broadcast (inside 240/4) + "2001:db8::1", // IPv6 documentation + "64:ff9b::1", // IPv4/IPv6 translation + "100::1", // IPv6 discard prefix } ok := []string{ "8.8.8.8", "1.1.1.1", "93.184.216.34", "2606:4700:4700::1111", @@ -93,6 +104,36 @@ func TestIsBlockedIP(t *testing.T) { } } +func TestCacheTTLFromHeader(t *testing.T) { + cases := []struct { + name string + header string + want time.Duration + }{ + {"empty header → default", "", cimdDefaultCacheTTL}, + {"no-store → 0", "no-store", 0}, + {"no-cache → 0", "no-cache", 0}, + {"public, no-store mixed → 0", "public, no-store", 0}, + {"max-age=0 → 0 (RFC 7234)", "max-age=0", 0}, + {"max-age=-5 → 0 (RFC 7234: negatives treated as 0)", "max-age=-5", 0}, + {"max-age=300 → 5m", "max-age=300", 5 * time.Minute}, + {"max-age=999999999 → cap", "max-age=999999999", cimdMaxCacheTTL}, + {"max-age with public", "public, max-age=120", 2 * time.Minute}, + {"unknown directive → default", "private", cimdDefaultCacheTTL}, + {"x-custom-no-storage must NOT match no-store", "x-custom-no-storage", cimdDefaultCacheTTL}, + {"no-storage (substring of no-store) must NOT match", "no-storage", cimdDefaultCacheTTL}, + {"malformed max-age value ignored, falls to default", "max-age=banana", cimdDefaultCacheTTL}, + {"no-store wins over max-age", "max-age=300, no-store", 0}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := cacheTTLFromHeader(tc.header); got != tc.want { + t.Errorf("cacheTTLFromHeader(%q) = %v, want %v", tc.header, got, tc.want) + } + }) + } +} + // --- schema validation -------------------------------------------------- func TestParseCIMDMetadata_OK(t *testing.T) { @@ -164,11 +205,10 @@ func testResolver(t *testing.T, server *httptest.Server) *cimdResolver { if err != nil { t.Fatalf("server URL parse: %v", err) } - host, port, err := net.SplitHostPort(su.Host) + _, port, err := net.SplitHostPort(su.Host) if err != nil { t.Fatalf("split host port: %v", err) } - _ = host r := newCIMDResolver(nil) // Replace the Transport with one that always dials the httptest server // instead of doing real DNS. This keeps the rest of the fetch / parse / @@ -231,6 +271,60 @@ func TestCIMDResolve_HappyPath_Cached(t *testing.T) { } } +func TestCIMDResolve_MaxAgeZeroSkipsCache(t *testing.T) { + hits := int32(0) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "max-age=0") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":["https://d.example.com/cb"],"token_endpoint_auth_method":"none"}`, cimdTestURL("d.example.com", "/zero.json")) + })) + defer server.Close() + r := testResolver(t, server) + u := cimdTestURL("d.example.com", "/zero.json") + for i := 0; i < 3; i++ { + if _, err := r.resolve(context.Background(), u); err != nil { + t.Fatalf("resolve %d: %v", i, err) + } + } + if atomic.LoadInt32(&hits) != 3 { + t.Errorf("expected 3 fetches (max-age=0), got %d", hits) + } +} + +func TestCIMDResolve_DefaultTTLWhenNoDirectives(t *testing.T) { + hits := int32(0) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.Header().Set("Content-Type", "application/json") + // Cache-Control present but with no directives we care about. + w.Header().Set("Cache-Control", "private") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":["https://d.example.com/cb"],"token_endpoint_auth_method":"none"}`, cimdTestURL("d.example.com", "/private.json")) + })) + defer server.Close() + r := testResolver(t, server) + now := time.Now() + r.now = func() time.Time { return now } + u := cimdTestURL("d.example.com", "/private.json") + if _, err := r.resolve(context.Background(), u); err != nil { + t.Fatalf("first resolve: %v", err) + } + if _, err := r.resolve(context.Background(), u); err != nil { + t.Fatalf("second resolve: %v", err) + } + if atomic.LoadInt32(&hits) != 1 { + t.Errorf("expected 1 fetch (default TTL serves the second from cache), got %d", hits) + } + e, ok := r.cache.get(u, now) + if !ok { + t.Fatalf("expected cache entry") + } + wantExp := now.Add(cimdDefaultCacheTTL) + if e.expiresAt.Before(wantExp.Add(-time.Second)) || e.expiresAt.After(wantExp.Add(time.Second)) { + t.Errorf("expected default TTL ~%v, got expiresAt=%v (now=%v)", cimdDefaultCacheTTL, e.expiresAt, now) + } +} + func TestCIMDResolve_NoStoreSkipsCache(t *testing.T) { hits := int32(0) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index d91e700..0ed105c 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -1343,6 +1343,13 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re Str("client_id", truncateForLog(clientID, 80)). Msg("OAuth /token: upstream code exchange succeeded") + // Validate the upstream identity before handing the bearer to the MCP + // client. We do NOT bind these claims into the downstream token in v1 + // (audience binding is deferred — see #115 § Non-goals); validation here + // exists purely to fail-fast on a malformed upstream response with a + // proper 502. Do not delete the underscored assignment without first + // re-introducing claim binding, or this side-effecting validation will + // look like dead code and get pruned. var identityClaims *altinitymcp.OAuthClaims if tokenResp.IDToken != "" { identityClaims, err = a.mcpServer.ValidateUpstreamIdentityToken(tokenResp.IDToken, cfg.Server.OAuth.ClientID) @@ -1359,7 +1366,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re return } } - _ = identityClaims + _ = identityClaims // validation-only; intentionally unused in v1. if tokenResp.Scope == "" { tokenResp.Scope = issued.Scope From 85f468f07f431f55876419d76080c2373cffc6fd Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 15 May 2026 16:27:54 +0200 Subject: [PATCH 4/4] oauth/cimd: 10 minor-finding follow-ups (bugs + dead-code + docs) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review #3 turned up four real bugs, three documentation drifts, and three dead-code / consistency issues. No wire-shape change beyond the bug fixes; all behaviour-affecting fixes have regression tests. Bugs: - **identityClaims comment lied** (#5-dup). The "intentionally unused" comment + bare `_ = identityClaims` next to a block that DOES read identityClaims.ExpiresAt (used to derive `expires_in`) was an invitation for a future "clean up unused vars" PR to silently drop the upstream identity validation. Comment rewritten to name the three jobs the validation actually does; underscore drop assigned. - **cacheTTLFromHeader int64 overflow** (#7). `time.Duration(n) * time.Second` overflows when n > ~9.22e9, wrapping to a negative Duration that hit the "directive absent" sentinel branch and silently returned cimdDefaultCacheTTL (5m) instead of the intended cimdMaxCacheTTL (1h). Clamp n against int(cimdMaxCacheTTL/time.Second) before the multiply. Test rows added in TestCacheTTLFromHeader for max-age=9999999999 and max-age=int64.max. - **upstream 200 OK with RFC 6749 error body mapped to 502 server_error** (#8). Non-RFC-compliant IdPs and test stubs that return HTTP 200 + {"error":"invalid_grant"} hit our "no usable token" branch → 502 server_error. Replay-oracle contract assumes downstream sees invalid_grant. tokenResp struct gains Error / ErrorDescription fields; non-empty Error → downstream invalid_grant regardless of status. New TestOAuthTokenUpstream200WithErrorBody covers this. - **Content-Type prefix-match too permissive** (#9). strings.HasPrefix accepted application/json-ld, application/jsonpatch+json, etc. Replaced with isApplicationJSON helper that splits on ";", trims, case-folds, and exact-matches "application/json". TestIsApplicationJSON covers the new behaviour. UX / contract: - **/oauth/register 404 → 410 Gone + JSON** (#11). Legacy DCR clients hitting the route now get an RFC 7591 §3.2.2-shaped error body pointing them at CIMD, not Go's bare text/plain 404. New handleOAuthRegisterRemoved handler. TestOAuthRegisterGone. - **upstream_offline_access description rewritten** (#6). Previous desc tag claimed the flag "issues JWE-wrapped refresh tokens" — post-#115 it doesn't, and can't. Rewritten to explain the flag is upstream-scope-only and v1 issues no downstream refresh tokens regardless of its value. Dead code: - **#12 removals**: defaultRefreshTokenTTLSeconds (oauth_server.go), oauthRegistrationPath() method, OAuthConfig.RegistrationPath + YAML fixture + 2 test assertions, decodeStringSlice + TestDecodeStringSlice, statelessRegisteredClient.TokenEndpointAuthMethod (write-only field). statelessRegisteredClient is now one field. - **#14 dec.UseNumber no-op + string copy**: oauthex.ClientRegistrationResponse has a custom UnmarshalJSON that bypasses outer-decoder settings, so UseNumber() was a no-op and strings.NewReader(string(body)) was a wasted copy. json.Unmarshal(body, &doc) directly. Style: - **Named upstream HTTP timeout** (#1). oauthUpstreamHTTPTimeout = 10 * time.Second constant; both call sites use it. Docs / comments: - **oauthKidV1 comment** (#13). Said "client_id / refresh-token JWE artifacts" and "30-day window" — neither exists post-#115. Rewritten to name pending-auth (10 min) as the longest live legacy artifact and to indicate the SHA256(secret) fallback can be deleted in a follow-up after one rolling restart. - **mustJWESecret error string**: said "OAuth client registration and forward-mode token wrapping" → now "JWE-wrapped pending-auth state and downstream auth-code minting". - **oauth_server_test.go stale tombstone** removed. - **oidcScopesForAdvertisement** doc no longer mentions DCR responses. - **docs/oauth_authorization.md** banner explains the #115 cutover (DCR removed → CIMD only, no downstream refresh tokens, HA replay). Tests added or strengthened: TestCacheTTLFromHeader (+2 overflow rows), TestIsApplicationJSON (8 cases), TestValidateCIMDClientIDURL_Reject (+data:/javascript:/file: scheme rejects), TestOAuthRegisterGone, TestOAuthTokenRefreshGrantUnsupported, TestOAuthASMetadataShape, TestOAuthTokenUpstream200WithErrorBody. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/altinity-mcp/cimd.go | 66 +- cmd/altinity-mcp/cimd_test.go | 102 ++- cmd/altinity-mcp/oauth_regression_test.go | 750 ++++++++++++++++++++++ cmd/altinity-mcp/oauth_server.go | 126 ++-- cmd/altinity-mcp/oauth_server_test.go | 43 -- docs/oauth_authorization.md | 25 + pkg/config/config.go | 16 +- pkg/config/config_test.go | 4 - 8 files changed, 1008 insertions(+), 124 deletions(-) create mode 100644 cmd/altinity-mcp/oauth_regression_test.go diff --git a/cmd/altinity-mcp/cimd.go b/cmd/altinity-mcp/cimd.go index a77e3ee..397aef9 100644 --- a/cmd/altinity-mcp/cimd.go +++ b/cmd/altinity-mcp/cimd.go @@ -114,7 +114,11 @@ func validateCIMDPath(rawPath string) error { if err != nil { return fmt.Errorf("%w: invalid percent-encoding in path", errCIMDInvalidURL) } - if path.Clean(decoded) != decoded { + // path.Clean strips trailing slashes ("/a/" → "/a"), but RFC 3986 + // treats a trailing slash as a significant, legal path. Accept the path + // if it differs from its Clean form only by a single trailing slash. + cleaned := path.Clean(decoded) + if cleaned != decoded && cleaned+"/" != decoded { return fmt.Errorf("%w: dot-segment in path", errCIMDInvalidURL) } return nil @@ -271,6 +275,7 @@ func (a *application) resolveCIMDClient(ctx context.Context, clientIDURL string) func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, error) { if _, err := validateCIMDClientIDURL(clientIDURL); err != nil { + r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: r.now().Add(cimdNegativeCacheTTL)}, r.now()) return nil, err } if e, ok := r.cache.get(clientIDURL, r.now()); ok { @@ -282,7 +287,17 @@ func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statel client, ttl, err := r.fetchAndValidate(ctx, clientIDURL) now := r.now() if err != nil { - r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: now.Add(cimdNegativeCacheTTL)}, now) + // Negative-cache only stably-wrong outcomes (abuse control per #115 + // § Caching). Transient fetch failures — upstream 5xx, timeouts, + // client disconnects that propagate as context.Canceled — must NOT + // poison the cache: a single bad fetch from one user would lock all + // users of that client_id URL out for cimdNegativeCacheTTL. + switch { + case errors.Is(err, errCIMDInvalidMetadata), + errors.Is(err, errCIMDInvalidURL), + errors.Is(err, errCIMDSSRFBlocked): + r.cache.put(clientIDURL, &cimdCacheEntry{err: err, expiresAt: now.Add(cimdNegativeCacheTTL)}, now) + } return nil, err } if ttl > 0 { @@ -292,7 +307,11 @@ func (r *cimdResolver) resolve(ctx context.Context, clientIDURL string) (*statel } func (r *cimdResolver) fetchAndValidate(ctx context.Context, clientIDURL string) (*statelessRegisteredClient, time.Duration, error) { - ctx, cancel := context.WithTimeout(ctx, cimdFetchTimeout) + // Detach the fetch from the inbound request's cancellation. The fetch is + // shared across goroutines via the cache, so an inbound disconnect must + // not abort it (and produce a context.Canceled error that other waiters + // observe). The dedicated cimdFetchTimeout still bounds the call. + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cimdFetchTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, clientIDURL, nil) if err != nil { @@ -310,9 +329,8 @@ func (r *cimdResolver) fetchAndValidate(ctx context.Context, clientIDURL string) if resp.StatusCode != http.StatusOK { return nil, 0, fmt.Errorf("%w: HTTP %d", errCIMDFetch, resp.StatusCode) } - ct := resp.Header.Get("Content-Type") - if !strings.HasPrefix(ct, "application/json") { - return nil, 0, fmt.Errorf("%w: content-type %q not application/json", errCIMDFetch, ct) + if !isApplicationJSON(resp.Header.Get("Content-Type")) { + return nil, 0, fmt.Errorf("%w: content-type %q not application/json", errCIMDFetch, resp.Header.Get("Content-Type")) } body, err := io.ReadAll(io.LimitReader(resp.Body, int64(cimdMaxBodyBytes+1))) if err != nil { @@ -341,6 +359,16 @@ func (r *cimdResolver) fetchAndValidate(ctx context.Context, clientIDURL string) // // Directive matching is exact: a stray substring like "x-no-storage" does // NOT trigger no-store. +// isApplicationJSON matches RFC 7231 §3.1.1.5 media-type syntax: the bare +// type is "application/json", optionally followed by ";" parameters +// (charset, boundary, etc.). A bare prefix match would falsely accept +// "application/json-ld", "application/jsonpatch+json", and similar +// distinct media types whose bodies don't shape-match CIMD documents. +func isApplicationJSON(ct string) bool { + mt, _, _ := strings.Cut(ct, ";") + return strings.EqualFold(strings.TrimSpace(mt), "application/json") +} + func cacheTTLFromHeader(cc string) time.Duration { if cc == "" { return cimdDefaultCacheTTL @@ -362,6 +390,15 @@ func cacheTTLFromHeader(cc string) time.Duration { if n <= 0 { return 0 // RFC 7234: max-age=0 (or negative) means uncached. } + // Clamp n before the *time.Second multiply so we don't overflow + // int64 nanoseconds for absurd max-age values (n*1e9 overflows + // when n > ~9.22e9). Without this, a CIMD doc with + // "Cache-Control: max-age=9999999999" would wrap to negative + // and silently fall back to cimdDefaultCacheTTL. + const maxSeconds = int(cimdMaxCacheTTL / time.Second) + if n > maxSeconds { + return cimdMaxCacheTTL + } maxAge = time.Duration(n) * time.Second } } @@ -383,15 +420,15 @@ func cacheTTLFromHeader(cc string) time.Duration { // struct. Extra fields the SDK knows about (logo_uri, tos_uri, jwks, etc.) are // safely ignored because we don't read them. func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredClient, error) { + // json.Unmarshal here rather than json.Decoder: oauthex.ClientRegistrationResponse + // has a custom UnmarshalJSON that bypasses outer-decoder settings (UseNumber + // would be a no-op), and we don't need trailing-token detection — a + // well-formed CIMD doc is a single JSON object. The body was already + // bounded by io.LimitReader at fetch time. var doc oauthex.ClientRegistrationResponse - dec := json.NewDecoder(strings.NewReader(string(body))) - dec.UseNumber() - if err := dec.Decode(&doc); err != nil { + if err := json.Unmarshal(body, &doc); err != nil { return nil, fmt.Errorf("%w: decode: %v", errCIMDInvalidMetadata, err) } - if dec.More() { - return nil, fmt.Errorf("%w: trailing tokens after object", errCIMDInvalidMetadata) - } if doc.ClientID != clientIDURL { return nil, fmt.Errorf("%w: client_id mismatch", errCIMDInvalidMetadata) } @@ -456,10 +493,7 @@ func parseCIMDMetadata(clientIDURL string, body []byte) (*statelessRegisteredCli return nil, fmt.Errorf("%w: response_types must include code", errCIMDInvalidMetadata) } } - return &statelessRegisteredClient{ - RedirectURIs: doc.RedirectURIs, - TokenEndpointAuthMethod: "none", - }, nil + return &statelessRegisteredClient{RedirectURIs: doc.RedirectURIs}, nil } // validateCIMDRedirectURI: v1 requires https for all redirect URIs. Loopback diff --git a/cmd/altinity-mcp/cimd_test.go b/cmd/altinity-mcp/cimd_test.go index 8ad31d9..17ee22a 100644 --- a/cmd/altinity-mcp/cimd_test.go +++ b/cmd/altinity-mcp/cimd_test.go @@ -22,6 +22,10 @@ func TestValidateCIMDClientIDURL_OK(t *testing.T) { "https://chatgpt.com/.well-known/oauth-client-id", "https://example.com:443/x.json", "https://example.com/a/b/c.json", + // RFC 3986 trailing slashes are legal and significant; regression + // guard for the path.Clean rejection bug fixed in this PR. + "https://example.com/oauth/cimd/", + "https://example.com/a/b/c.json/", } for _, c := range cases { if _, err := validateCIMDClientIDURL(c); err != nil { @@ -51,6 +55,13 @@ func TestValidateCIMDClientIDURL_Reject(t *testing.T) { "encoded_slash": "https://example.com/a%2fb/x.json", "encoded_backslash": "https://example.com/a%5cb/x.json", "uppercase_host": "https://Example.com/x.json", + // IDN normalization: Cyrillic 'а' (U+0430) is a confusable for ASCII + // 'a'. idna.Lookup.ToASCII converts the IDN to its xn-- form, which + // won't equal the raw input, so we reject. + "cyrillic_a_idn": "https://exаmple.com/x.json", + "data_scheme": "data:application/json,{}", + "javascript_scheme": "javascript:alert(1)", + "file_scheme": "file:///etc/passwd", } for name, raw := range cases { t.Run(name, func(t *testing.T) { @@ -124,6 +135,12 @@ func TestCacheTTLFromHeader(t *testing.T) { {"no-storage (substring of no-store) must NOT match", "no-storage", cimdDefaultCacheTTL}, {"malformed max-age value ignored, falls to default", "max-age=banana", cimdDefaultCacheTTL}, {"no-store wins over max-age", "max-age=300, no-store", 0}, + // Regression for int64 overflow on n * time.Second when n is huge. + // Pre-fix this returned cimdDefaultCacheTTL (5m) because the multiply + // wrapped to a negative time.Duration, hitting the "sentinel = absent" + // branch. Now we clamp before multiplying. + {"max-age=9999999999 (n*1e9 overflows int64) → cap", "max-age=9999999999", cimdMaxCacheTTL}, + {"max-age=int64.max → cap", "max-age=9223372036854775807", cimdMaxCacheTTL}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -151,7 +168,7 @@ func TestParseCIMDMetadata_OK(t *testing.T) { if err != nil { t.Fatalf("expected ok, got %v", err) } - if c.TokenEndpointAuthMethod != "none" || len(c.RedirectURIs) != 1 { + if len(c.RedirectURIs) != 1 || c.RedirectURIs[0] != "https://claude.ai/api/mcp/auth_callback" { t.Errorf("unexpected client: %#v", c) } } @@ -384,6 +401,25 @@ func TestCIMDResolve_OversizeBodyRejected(t *testing.T) { } } +func TestIsApplicationJSON(t *testing.T) { + cases := map[string]bool{ + "application/json": true, + "application/json; charset=utf-8": true, + "APPLICATION/JSON": true, // RFC 7231: media types are case-insensitive + "application/json-ld": false, + "application/jsonpatch+json": false, + "application/json5": false, + "application/jose": false, + "text/json": false, + "": false, + } + for ct, want := range cases { + if got := isApplicationJSON(ct); got != want { + t.Errorf("isApplicationJSON(%q) = %v, want %v", ct, got, want) + } + } +} + func TestCIMDResolve_NonJSONRejected(t *testing.T) { server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") @@ -409,19 +445,77 @@ func TestCIMDResolve_RedirectRejected(t *testing.T) { } } -func TestCIMDResolve_NegativeCache(t *testing.T) { +func TestCIMDResolve_TransientFetchErrorDoesNotPoisonCache(t *testing.T) { + // Issue #115 review-followup: a transient upstream 5xx (or client + // disconnect propagating as context.Canceled) must NOT write a negative + // cache entry — otherwise one bad fetch locks all users of that + // client_id URL out for cimdNegativeCacheTTL. errCIMDFetch is the + // "transient" bucket and is explicitly skipped by resolve(). server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() r := testResolver(t, server) - u := cimdTestURL("d.example.com", "/x.json") + u := cimdTestURL("d.example.com", "/transient.json") + if _, err := r.resolve(context.Background(), u); err == nil { + t.Fatal("expected error") + } + if _, ok := r.cache.get(u, time.Now()); ok { + t.Errorf("transient fetch error must NOT produce a cache entry") + } +} + +func TestCIMDResolve_InvalidMetadataNegativeCache(t *testing.T) { + // Conversely, a stably-malformed metadata response IS cached + // (abuse-control per issue #115 § Caching) — the document is wrong + // every time, so hammering the upstream is pointless. + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // client_id in doc deliberately mismatches the requested URL. + w.Write([]byte(`{"client_id":"https://other/x","client_name":"X","redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"}`)) + })) + defer server.Close() + r := testResolver(t, server) + u := cimdTestURL("d.example.com", "/invalid.json") if _, err := r.resolve(context.Background(), u); err == nil { t.Fatal("expected error") } e, ok := r.cache.get(u, time.Now()) if !ok || e.err == nil { - t.Errorf("expected negative cache entry") + t.Errorf("invalid-metadata error must produce a negative cache entry") + } +} + +func TestCIMDResolve_ClientCancellationDoesNotPoisonCache(t *testing.T) { + // Mid-fetch client disconnect: simulate by cancelling the caller's + // context before the response can be sent. fetchAndValidate uses + // context.WithoutCancel internally so the fetch survives, but if a + // future regression re-couples the contexts, this test catches the + // "first-user-cancel locks everyone out" failure mode. + release := make(chan struct{}) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-release // hold until the test cancels the caller's ctx + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":["https://d.example.com/cb"],"token_endpoint_auth_method":"none"}`, cimdTestURL("d.example.com", "/cancel.json")) + })) + defer server.Close() + defer close(release) + r := testResolver(t, server) + u := cimdTestURL("d.example.com", "/cancel.json") + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = r.resolve(ctx, u) + }() + cancel() + release <- struct{}{} + <-done + // The fetch detached its context from the caller, so it succeeded and + // the resolver wrote a positive cache entry. Either way, no negative + // poisoning: subsequent callers either hit a valid cache or refetch. + if e, ok := r.cache.get(u, time.Now()); ok && e.err != nil { + t.Errorf("client cancellation must NOT produce a negative cache entry; got err=%v", e.err) } } diff --git a/cmd/altinity-mcp/oauth_regression_test.go b/cmd/altinity-mcp/oauth_regression_test.go new file mode 100644 index 0000000..d823d79 --- /dev/null +++ b/cmd/altinity-mcp/oauth_regression_test.go @@ -0,0 +1,750 @@ +package main + +// Regression coverage for behaviour that survived the DCR cleanup but lost +// its tests in #116. See PR review (commit 6f8bbed → 03b19f6 → dac3961). +// +// Test groups: +// - HKDF info-label isolation between pending-auth and auth-code JWEs +// - Field-by-field round-trip of encodePendingAuth / encodeAuthCode +// - Forward-mode JWT validation in the MCP auth injector +// - /token RFC 8707 invalid_target check +// - /.well-known/* alias-path registration +// - End-to-end /authorize → /callback → /token through the CIMD resolver + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/altinity/altinity-mcp/pkg/config" + "github.com/altinity/altinity-mcp/pkg/jwe_auth" + altinitymcp "github.com/altinity/altinity-mcp/pkg/server" + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" +) + +// --- HKDF + JWE round-trip ---------------------------------------------- + +func TestOAuthJWEHKDFRoundtripAndIsolation(t *testing.T) { + t.Parallel() + secret := []byte("regression-hkdf-secret-32-bytes!") + + t.Run("pending-auth roundtrip", func(t *testing.T) { + t.Parallel() + in := map[string]interface{}{ + "client_id": "https://x.example/y.json", + "redirect_uri": "https://x.example/cb", + "exp": time.Now().Add(time.Hour).Unix(), + } + token, err := encodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, in) + require.NoError(t, err) + out, err := decodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, token) + require.NoError(t, err) + require.Equal(t, in["client_id"], out["client_id"]) + require.Equal(t, in["redirect_uri"], out["redirect_uri"]) + }) + + t.Run("auth-code roundtrip", func(t *testing.T) { + t.Parallel() + in := map[string]interface{}{ + "client_id": "https://x.example/y.json", + "upstream_auth_code": "abc", + "exp": time.Now().Add(60 * time.Second).Unix(), + } + token, err := encodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, in) + require.NoError(t, err) + out, err := decodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, token) + require.NoError(t, err) + require.Equal(t, in["upstream_auth_code"], out["upstream_auth_code"]) + }) + + t.Run("HKDF info-label isolation: pending JWE will not decode as auth-code", func(t *testing.T) { + t.Parallel() + token, err := encodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, map[string]interface{}{ + "client_id": "https://x.example/y.json", + "exp": time.Now().Add(time.Hour).Unix(), + }) + require.NoError(t, err) + _, err = decodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, token) + require.Error(t, err, "JWE minted under pending-auth/v1 must NOT decrypt under auth-code/v2 — that's the whole point of HKDF info-label separation") + }) + + t.Run("legacy kid=\"\" fallback path is reachable via jwe_auth.ParseAndDecryptJWE", func(t *testing.T) { + t.Parallel() + // Mint a legacy JWE the same way pre-HKDF artifacts were minted + // (SHA256(secret) → A256GCM, no kid). + legacy, err := jwe_auth.GenerateJWEToken(map[string]interface{}{ + "sub": "legacy-subject", + "exp": time.Now().Add(time.Hour).Unix(), + }, secret, nil) + require.NoError(t, err) + // decodeOAuthJWE inspects the kid header: kid="" routes to the + // legacy SHA256 path. Any info label works because the legacy path + // doesn't HKDF-derive — passing the auth-code label here mirrors + // what a production token-handler call would do. + claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, legacy) + require.NoError(t, err) + require.Equal(t, "legacy-subject", claims["sub"]) + }) +} + +// --- pending-auth / auth-code field-by-field round-trip ----------------- + +func TestOAuthPendingAuthAndAuthCodeRoundTrip(t *testing.T) { + t.Parallel() + app := &application{ + config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + SigningSecret: "regression-roundtrip-32-bytes!!", + }}}, + } + + t.Run("oauthPendingAuth", func(t *testing.T) { + t.Parallel() + in := oauthPendingAuth{ + ClientID: "https://claude.ai/oauth/x", + RedirectURI: "https://claude.ai/cb", + Scope: "openid email", + ClientState: "csrf-state", + CodeChallenge: "ZH-pVPpAjHk", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example.com/", + UpstreamPKCEVerifier: "upstream-verifier", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + } + token, err := app.encodePendingAuth(in) + require.NoError(t, err) + out, ok := app.decodePendingAuth(token) + require.True(t, ok) + require.Equal(t, in.ClientID, out.ClientID) + require.Equal(t, in.RedirectURI, out.RedirectURI) + require.Equal(t, in.Scope, out.Scope) + require.Equal(t, in.ClientState, out.ClientState) + require.Equal(t, in.CodeChallenge, out.CodeChallenge) + require.Equal(t, in.CodeChallengeMethod, out.CodeChallengeMethod) + require.Equal(t, in.Resource, out.Resource) + require.Equal(t, in.UpstreamPKCEVerifier, out.UpstreamPKCEVerifier) + }) + + t.Run("oauthIssuedCode", func(t *testing.T) { + t.Parallel() + in := oauthIssuedCode{ + ClientID: "https://claude.ai/oauth/x", + RedirectURI: "https://claude.ai/cb", + Scope: "openid email", + CodeChallenge: "ZH-pVPpAjHk", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example.com/", + UpstreamAuthCode: "upstream-code-abc", + UpstreamPKCEVerifier: "upstream-verifier", + ExpiresAt: time.Now().Add(60 * time.Second).Truncate(time.Second), + } + token, err := app.encodeAuthCode(in) + require.NoError(t, err) + out, ok := app.decodeAuthCode(token) + require.True(t, ok) + require.Equal(t, in.ClientID, out.ClientID) + require.Equal(t, in.RedirectURI, out.RedirectURI) + require.Equal(t, in.Scope, out.Scope) + require.Equal(t, in.CodeChallenge, out.CodeChallenge) + require.Equal(t, in.CodeChallengeMethod, out.CodeChallengeMethod) + require.Equal(t, in.Resource, out.Resource) + require.Equal(t, in.UpstreamAuthCode, out.UpstreamAuthCode) + require.Equal(t, in.UpstreamPKCEVerifier, out.UpstreamPKCEVerifier) + }) +} + +// --- forward-mode JWT validation in the MCP auth injector ---------------- + +func TestOAuthMCPAuthInjectorForwardModeValidatesJWT(t *testing.T) { + t.Parallel() + provider := newRegressionOIDCProvider(t, nil, nil) + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: provider.server.URL, + JWKSURL: provider.server.URL + "/jwks", + Audience: "clickhouse-api", + }, + }, + } + app := &application{ + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + } + + mkReq := func(token string) (*httptest.ResponseRecorder, *http.Request) { + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", nil) + req.Header.Set("Authorization", "Bearer "+token) + return httptest.NewRecorder(), req + } + + t.Run("valid JWT reaches handler with claims", func(t *testing.T) { + t.Parallel() + tok := provider.issueIDToken(t, map[string]interface{}{ + "sub": "u-good", + "iss": provider.server.URL, + "aud": "clickhouse-api", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + rr, req := mkReq(tok) + called := false + app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + require.Equal(t, tok, r.Context().Value(altinitymcp.OAuthTokenKey)) + claims, ok := r.Context().Value(altinitymcp.OAuthClaimsKey).(*altinitymcp.OAuthClaims) + require.True(t, ok) + require.Equal(t, "u-good", claims.Subject) + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + require.True(t, called) + require.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("wrong-audience JWT rejected with 401", func(t *testing.T) { + t.Parallel() + tok := provider.issueIDToken(t, map[string]interface{}{ + "sub": "u-bad-aud", + "iss": provider.server.URL, + "aud": "some-other-api", + "exp": time.Now().Add(time.Hour).Unix(), + }) + rr, req := mkReq(tok) + called := false + app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + require.False(t, called, "wrong-aud token must not reach inner handler") + require.Equal(t, http.StatusUnauthorized, rr.Code) + require.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) + }) + + t.Run("expired JWT rejected with 401", func(t *testing.T) { + t.Parallel() + tok := provider.issueIDToken(t, map[string]interface{}{ + "sub": "u-expired", + "iss": provider.server.URL, + "aud": "clickhouse-api", + "exp": time.Now().Add(-2 * time.Hour).Unix(), + "iat": time.Now().Add(-3 * time.Hour).Unix(), + }) + rr, req := mkReq(tok) + called := false + app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + require.False(t, called, "expired token must not reach inner handler") + require.Equal(t, http.StatusUnauthorized, rr.Code) + }) +} + +// --- /token RFC 8707 invalid_target mismatch --------------------------- + +func TestOAuthForwardModeTokenResourceMismatch(t *testing.T) { + t.Parallel() + const ( + cimdURL = "https://demo.example.com/cimd.json" + redirectURI = "https://demo.example.com/cb" + boundResource = "https://mcp.example.com/" + clashResource = "https://other.example.com/" + signingSecret = "regression-resource-32-bytes!!!!" + ) + // CIMD doc server so the resolver can satisfy /token. + cimdServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":[%q],"token_endpoint_auth_method":"none"}`, cimdURL, redirectURI) + })) + defer cimdServer.Close() + + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: "https://idp.example.com", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: signingSecret, + }, + }, + } + app := &application{ + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + cimdResolver: testResolver(t, cimdServer), + } + + verifier, err := newPKCEVerifier() + require.NoError(t, err) + issued := oauthIssuedCode{ + ClientID: cimdURL, + RedirectURI: redirectURI, + Scope: "openid email", + CodeChallenge: pkceChallenge(verifier), + CodeChallengeMethod: "S256", + Resource: boundResource, + UpstreamAuthCode: "unused-this-test-rejects-pre-upstream", + UpstreamPKCEVerifier: "uv", + ExpiresAt: time.Now().Add(60 * time.Second), + } + code, err := app.encodeAuthCode(issued) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cimdURL) + form.Set("redirect_uri", redirectURI) + form.Set("code", code) + form.Set("code_verifier", verifier) + form.Set("resource", clashResource) // <- mismatch + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + rr := httptest.NewRecorder() + app.handleOAuthTokenAuthCode(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code, "body=%s", rr.Body.String()) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) + require.Equal(t, "invalid_target", resp["error"]) +} + +// --- .well-known alias paths --------------------------------------------- + +func TestRegisterOAuthHTTPRoutesAliases(t *testing.T) { + t.Parallel() + app := &application{ + config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: "https://idp.example.com", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: "regression-aliases-32-bytes!!!!!", + }}}, + } + mux := http.NewServeMux() + app.registerOAuthHTTPRoutes(mux) + + // Each alias path must return the same JSON document (modulo OIDC's + // id_token_signing_alg_values_supported extra in gating-mode openid + // configuration — we run forward so neither path adds it). + for _, path := range []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/oauth-authorization-server/oauth", + "/oauth/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", + "/.well-known/openid-configuration/oauth", + "/oauth/.well-known/openid-configuration", + } { + t.Run(path, func(t *testing.T) { + t.Parallel() + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "https://mcp.example.com"+path, nil)) + require.Equal(t, http.StatusOK, rr.Code, "alias %s should serve metadata", path) + var doc map[string]interface{} + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &doc)) + require.Equal(t, true, doc["client_id_metadata_document_supported"], "alias %s must advertise CIMD support", path) + require.NotContains(t, doc, "registration_endpoint", "alias %s must not advertise DCR endpoint", path) + }) + } +} + +// --- CIMD end-to-end happy path ------------------------------------------ + +// TestCIMDFullAuthCodeFlow walks /authorize → /callback → /token through a +// resolver that fetches a fake CIMD doc, a fake upstream IdP that mints +// access_tokens, and an in-process userinfo endpoint. Closes the gap left +// by oauth_ha_replay_test.go, which short-circuits to /token from a hand- +// built JWE auth-code. +func TestCIMDFullAuthCodeFlow(t *testing.T) { + t.Parallel() + const ( + downstreamClient = "https://demo.example.com/cimd.json" + downstreamRedir = "https://demo.example.com/cb" + upstreamClient = "broker-client" + upstreamSecret = "broker-secret" + signingSecret = "regression-fullflow-32-bytes!!!!" + ) + + // Fake upstream IdP. + tokenRedemptions := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/authorize": + // Bounce straight to the broker's /oauth/callback with the + // pending state preserved. In a real run the user logs in here. + cb := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, cb+"?code=upstream-code&state="+state, http.StatusFound) + case "/token": + atomic.AddInt32(&tokenRedemptions, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "upstream-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid email", + }) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "sub": "u-1", "email": "u1@example.com", "email_verified": true, + }) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + // Fake CIMD doc server. + cimdServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"Demo","redirect_uris":[%q],"token_endpoint_auth_method":"none"}`, downstreamClient, downstreamRedir) + })) + defer cimdServer.Close() + + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: upstream.URL, + JWKSURL: upstream.URL + "/jwks", + AuthURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + ClientID: upstreamClient, + ClientSecret: upstreamSecret, + Audience: upstreamClient, + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: signingSecret, + Scopes: []string{"openid", "email"}, + }, + }, + } + app := &application{ + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + cimdResolver: testResolver(t, cimdServer), + } + + verifier, err := newPKCEVerifier() + require.NoError(t, err) + challenge := pkceChallenge(verifier) + + // 1. /oauth/authorize — should produce a 302 to upstream /authorize with + // a JWE state parameter. + authReq := httptest.NewRequest(http.MethodGet, + "https://mcp.example.com/oauth/authorize?"+url.Values{ + "client_id": {downstreamClient}, + "redirect_uri": {downstreamRedir}, + "response_type": {"code"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + "state": {"client-state"}, + }.Encode(), nil) + authRR := httptest.NewRecorder() + app.handleOAuthAuthorize(authRR, authReq) + require.Equal(t, http.StatusFound, authRR.Code, "body=%s", authRR.Body.String()) + upstreamRedirect, err := url.Parse(authRR.Header().Get("Location")) + require.NoError(t, err) + require.True(t, strings.HasPrefix(upstreamRedirect.String(), upstream.URL+"/authorize")) + state := upstreamRedirect.Query().Get("state") + require.NotEmpty(t, state) + + // 2. /oauth/callback — simulating the upstream IdP's redirect back to + // us. Our handler wraps the upstream auth code into a downstream + // JWE and 302s the user to the downstream redirect_uri. + cbReq := httptest.NewRequest(http.MethodGet, + "https://mcp.example.com/oauth/callback?code=upstream-code&state="+url.QueryEscape(state), nil) + cbRR := httptest.NewRecorder() + app.handleOAuthCallback(cbRR, cbReq) + require.Equal(t, http.StatusFound, cbRR.Code, "body=%s", cbRR.Body.String()) + downstreamRedirect, err := url.Parse(cbRR.Header().Get("Location")) + require.NoError(t, err) + require.Equal(t, "demo.example.com", downstreamRedirect.Host) + downstreamCode := downstreamRedirect.Query().Get("code") + require.NotEmpty(t, downstreamCode) + require.Equal(t, "client-state", downstreamRedirect.Query().Get("state")) + require.Equal(t, int32(0), atomic.LoadInt32(&tokenRedemptions), "/callback must NOT redeem upstream (HA replay model)") + + // 3. /oauth/token — the broker now redeems upstream and hands the + // bearer to the MCP client. + tokenForm := url.Values{} + tokenForm.Set("grant_type", "authorization_code") + tokenForm.Set("client_id", downstreamClient) + tokenForm.Set("redirect_uri", downstreamRedir) + tokenForm.Set("code", downstreamCode) + tokenForm.Set("code_verifier", verifier) + tokenReq := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(tokenForm.Encode())) + tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + tokenRR := httptest.NewRecorder() + app.handleOAuthToken(tokenRR, tokenReq) + require.Equal(t, http.StatusOK, tokenRR.Code, "body=%s", tokenRR.Body.String()) + var tokenResp map[string]interface{} + require.NoError(t, json.Unmarshal(tokenRR.Body.Bytes(), &tokenResp)) + require.Equal(t, "upstream-access-token", tokenResp["access_token"]) + require.NotContains(t, tokenResp, "refresh_token", "v1 issues no refresh tokens to CIMD clients") + require.Equal(t, int32(1), atomic.LoadInt32(&tokenRedemptions), "exactly one upstream /token call per /oauth/token attempt") +} + +// --- /oauth/register 410 Gone -------------------------------------------- + +// TestOAuthRegisterGone confirms the DCR-tombstone handler returns a +// diagnosable RFC 7591 §3.2.2-shaped JSON error rather than the bare mux +// 404 a DCR client would otherwise see. +func TestOAuthRegisterGone(t *testing.T) { + t.Parallel() + app := &application{ + config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: "https://idp.example.com", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: "regression-410-32bytes!!!!!!!!!!!", + }}}, + } + mux := http.NewServeMux() + app.registerOAuthHTTPRoutes(mux) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/register", strings.NewReader(`{"redirect_uris":["https://x/cb"],"token_endpoint_auth_method":"none"}`))) + require.Equal(t, http.StatusGone, rr.Code) + require.Contains(t, rr.Header().Get("Content-Type"), "application/json") + var body map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "registration_not_supported", body["error"]) + require.Contains(t, body["error_description"], "CIMD") +} + +// --- /oauth/token refresh grant unsupported ----------------------------- + +func TestOAuthTokenRefreshGrantUnsupported(t *testing.T) { + t.Parallel() + app := &application{ + config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: "https://idp.example.com", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: "regression-refresh-32bytes!!!!!!!", + }}}, + } + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "https://x.example.com/cimd.json") + form.Set("refresh_token", "anything") + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + app.handleOAuthToken(rr, req) + require.Equal(t, http.StatusBadRequest, rr.Code, "body=%s", rr.Body.String()) + var body map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "unsupported_grant_type", body["error"]) +} + +// --- AS metadata shape --------------------------------------------------- + +func TestOAuthASMetadataShape(t *testing.T) { + t.Parallel() + app := &application{ + config: config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: "https://idp.example.com", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: "regression-shape-32bytes!!!!!!!!!", + }}}, + } + rr := httptest.NewRecorder() + app.handleOAuthAuthorizationServerMetadata(rr, httptest.NewRequest(http.MethodGet, "https://mcp.example.com/.well-known/oauth-authorization-server", nil)) + require.Equal(t, http.StatusOK, rr.Code) + var doc map[string]interface{} + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &doc)) + + require.Equal(t, true, doc["client_id_metadata_document_supported"]) + require.NotContains(t, doc, "registration_endpoint") + require.Equal(t, []interface{}{"none"}, doc["token_endpoint_auth_methods_supported"]) + require.Equal(t, []interface{}{"authorization_code"}, doc["grant_types_supported"]) + require.Equal(t, []interface{}{"code"}, doc["response_types_supported"]) + require.Equal(t, []interface{}{"S256"}, doc["code_challenge_methods_supported"]) + require.NotContains(t, doc["grant_types_supported"], "refresh_token") +} + +// --- upstream 200 OK with RFC 6749 §5.2 error body ---------------------- + +// TestOAuthTokenUpstream200WithErrorBody covers the non-RFC-compliant IdP +// case where /token returns HTTP 200 OK + {"error":"invalid_grant"} (no +// access_token / id_token). Status-only checks miss this; we must surface +// it as downstream invalid_grant so the HA replay contract holds. +func TestOAuthTokenUpstream200WithErrorBody(t *testing.T) { + t.Parallel() + const ( + cimdURL = "https://demo.example.com/cimd.json" + redirectURI = "https://demo.example.com/cb" + ) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"error":"invalid_grant","error_description":"already used"}`) + })) + defer upstream.Close() + cimdServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"client_id":%q,"client_name":"D","redirect_uris":[%q],"token_endpoint_auth_method":"none"}`, cimdURL, redirectURI) + })) + defer cimdServer.Close() + + cfg := config.Config{Server: config.ServerConfig{OAuth: config.OAuthConfig{ + Enabled: true, + Mode: "forward", + Issuer: upstream.URL, + JWKSURL: upstream.URL + "/jwks", + AuthURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + ClientID: "broker", + ClientSecret: "s", + PublicAuthServerURL: "https://mcp.example.com", + SigningSecret: "regression-200err-32bytes!!!!!!!", + }}} + app := &application{ + config: cfg, + mcpServer: altinitymcp.NewClickHouseMCPServer(cfg, "test"), + cimdResolver: testResolver(t, cimdServer), + } + verifier, _ := newPKCEVerifier() + issued := oauthIssuedCode{ + ClientID: cimdURL, RedirectURI: redirectURI, Scope: "openid email", + CodeChallenge: pkceChallenge(verifier), CodeChallengeMethod: "S256", + UpstreamAuthCode: "abc", UpstreamPKCEVerifier: "uv", + ExpiresAt: time.Now().Add(60 * time.Second), + } + code, err := app.encodeAuthCode(issued) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cimdURL) + form.Set("redirect_uri", redirectURI) + form.Set("code", code) + form.Set("code_verifier", verifier) + req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + rr := httptest.NewRecorder() + app.handleOAuthTokenAuthCode(rr, req) + require.Equal(t, http.StatusBadRequest, rr.Code, "200+error body must surface as 400 invalid_grant, got body=%s", rr.Body.String()) + var body map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "invalid_grant", body["error"]) +} + +// --- helpers ------------------------------------------------------------- + +// regressionOIDCProvider is a small fake OIDC AS used by the forward-mode +// JWT validation tests. It signs id_tokens with RS256 and exposes JWKS at +// /jwks so altinitymcp.ValidateUpstreamIdentityToken can verify them. +type regressionOIDCProvider struct { + server *httptest.Server + privateKey *rsa.PrivateKey + keyID string + + tokenResp map[string]interface{} + userInfoClaims map[string]interface{} + + mu sync.Mutex + userInfoCalls int + lastUserInfoAuth string +} + +func newRegressionOIDCProvider(t *testing.T, tokenResp, userInfoClaims map[string]interface{}) *regressionOIDCProvider { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + p := ®ressionOIDCProvider{ + privateKey: priv, + keyID: "regression-key", + tokenResp: tokenResp, + userInfoClaims: userInfoClaims, + } + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + p.server = server + + mux.HandleFunc("/authorize", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(p.tokenResp)) + }) + mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + Key: &priv.PublicKey, + KeyID: p.keyID, + Use: "sig", + Algorithm: string(jose.RS256), + }}, + })) + }) + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + p.mu.Lock() + p.userInfoCalls++ + p.lastUserInfoAuth = r.Header.Get("Authorization") + p.mu.Unlock() + if p.userInfoClaims == nil { + http.Error(w, "userinfo not configured", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(p.userInfoClaims)) + }) + return p +} + +func (p *regressionOIDCProvider) issueIDToken(t *testing.T, claims map[string]interface{}) string { + t.Helper() + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: jose.JSONWebKey{ + Key: p.privateKey, + KeyID: p.keyID, + Use: "sig", + Algorithm: string(jose.RS256), + }, + }, (&jose.SignerOptions{}).WithType("JWT")) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + obj, err := signer.Sign(payload) + require.NoError(t, err) + tok, err := obj.CompactSerialize() + require.NoError(t, err) + return tok +} + +// Avoid unused-import errors if some sub-tests get gated off. +var _ = context.Background +var _ = io.Discard diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index 0ed105c..06f56dd 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -22,7 +22,17 @@ import ( "github.com/rs/zerolog/log" ) -const maxOAuthResponseBytes = 1 << 20 // 1 MB +const ( + maxOAuthResponseBytes = 1 << 20 // 1 MB cap on upstream IdP response bodies. + + // oauthUpstreamHTTPTimeout bounds the broker's outbound HTTP calls to the + // upstream IdP (`/token` exchange + `/userinfo` fetch). Mirrors + // pkg/server.oauthHTTPTimeout used for JWKS / OIDC discovery — both call + // the same set of upstream hosts and should fail-fast together. Not + // shared as one cross-package constant to keep cmd/altinity-mcp free of + // pkg/server import-loop risk. + oauthUpstreamHTTPTimeout = 10 * time.Second +) const ( defaultProtectedResourceMetadataPath = "/.well-known/oauth-protected-resource" @@ -39,14 +49,16 @@ const ( // defaultAuthCodeTTLSeconds bounds /callback → /token (the legitimate // client redeems immediately). 60 seconds per OAuth 2.1 §4.1.2 — auth // codes "should be redeemed within seconds, never minutes." - defaultAuthCodeTTLSeconds = 60 - defaultAccessTokenTTLSeconds = 60 * 60 - defaultRefreshTokenTTLSeconds = 30 * 24 * 60 * 60 + defaultAuthCodeTTLSeconds = 60 + defaultAccessTokenTTLSeconds = 60 * 60 ) +// statelessRegisteredClient is the in-memory shape parseCIMDMetadata +// returns. After DCR removal the only field anything reads is RedirectURIs +// — parseCIMDMetadata's `token_endpoint_auth_method` check happens at +// parse-time and never reaches the struct. type statelessRegisteredClient struct { - RedirectURIs []string `json:"redirect_uris"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + RedirectURIs []string `json:"redirect_uris"` } type oauthPendingAuth struct { @@ -145,17 +157,18 @@ func (a *application) oauthJWESecret() []byte { func (a *application) mustJWESecret() ([]byte, error) { secret := a.oauthJWESecret() if len(secret) == 0 { - return nil, fmt.Errorf("oauth signing_secret is required for OAuth client registration and forward-mode token wrapping") + return nil, fmt.Errorf("oauth signing_secret is required for JWE-wrapped pending-auth state and downstream auth-code minting") } return secret, nil } // oauthKidV1 is the kid header set on cmd-minted OAuth JWE artifacts -// (client_id, refresh-token). Its presence selects the HKDF-derived key on -// decryption; absence (kid="") selects the legacy SHA256(secret) key for -// backwards compat with artifacts minted before the rotation cutover. After -// the longest legacy artifact lifetime expires (refresh tokens, default 30 -// days), the legacy fallback below can be removed. +// (pending-auth and auth-code). Its presence selects the HKDF-derived key +// on decryption; absence (kid="") selects the legacy SHA256(secret) key +// for backwards compat with artifacts minted before the HKDF rotation. +// Post-#115 the longest legacy artifact in flight is the 10-minute +// pending-auth JWE — the legacy fallback can be deleted in a follow-up +// after a >10-minute rolling restart window has passed. const oauthKidV1 = "v1" // HKDF info labels for cmd-internal OAuth key derivation. Each label produces @@ -516,10 +529,6 @@ func (a *application) oauthAuthorizationServerBaseURL(r *http.Request) string { return a.schemeAndHost(r) + a.oauthPrefix(r) } -func (a *application) oauthRegistrationPath() string { - return normalizedPath(a.GetCurrentConfig().Server.OAuth.RegistrationPath, defaultRegistrationPath) -} - func (a *application) oauthAuthorizationPath() string { return normalizedPath(a.GetCurrentConfig().Server.OAuth.AuthorizationPath, defaultAuthorizationPath) } @@ -702,23 +711,6 @@ func newPKCEVerifier() (string, error) { return base64.RawURLEncoding.EncodeToString(buf), nil } -func decodeStringSlice(value interface{}) []string { - switch typed := value.(type) { - case []string: - return append([]string{}, typed...) - case []interface{}: - out := make([]string, 0, len(typed)) - for _, item := range typed { - if str, ok := item.(string); ok { - out = append(out, str) - } - } - return out - default: - return nil - } -} - func sanitizeScope(scope string) string { return strings.Join(strings.Fields(scope), " ") } @@ -769,7 +761,7 @@ func normalizeUpstreamScopeForClient(scope string) string { // oidcScopesForAdvertisement returns the subset of cfg.Scopes that altinity-mcp // will surface to MCP clients via discovery metadata (protected-resource doc, -// authorization-server metadata, openid-configuration), DCR responses, and the +// authorization-server metadata, openid-configuration) and the // WWW-Authenticate challenge. Only an explicit OIDC-identity allowlist plus // Auth0's offline_access refresh-token gate is passed through; anything else // (URI-form upstream scopes like Google's https://www.googleapis.com/auth/…, @@ -863,7 +855,7 @@ func (a *application) fetchUserInfo(accessToken string) (*altinitymcp.OAuthClaim } req.Header.Set("Authorization", "Bearer "+accessToken) - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := (&http.Client{Timeout: oauthUpstreamHTTPTimeout}).Do(req) if err != nil { return nil, err } @@ -964,6 +956,20 @@ func (a *application) handleOAuthProtectedResource(w http.ResponseWriter, r *htt _ = json.NewEncoder(w).Encode(resp) } +// handleOAuthRegisterRemoved is the tombstone handler at /oauth/register. +// DCR was removed under #115 in favour of CIMD; this responds with an +// RFC 7591 §3.2.2-shaped JSON error so DCR clients in the wild see a +// diagnosable response rather than the bare mux 404. Always 410 Gone — +// the route is permanently retired, not "endpoint unavailable". +func handleOAuthRegisterRemoved(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusGone) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "registration_not_supported", + "error_description": "Dynamic Client Registration is no longer supported; clients must use OAuth Client ID Metadata Documents (CIMD). See client_id_metadata_document_supported on /.well-known/oauth-authorization-server.", + }) +} + // oauthASMetadata returns the field set shared by RFC 8414 (oauth-authorization-server) // and OIDC Discovery (openid-configuration). Both endpoints serve the same // AS-side advertisement; OIDC adds two extra fields under gating mode (see @@ -1288,7 +1294,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re form.Set("redirect_uri", callbackURL) form.Set("code_verifier", issued.UpstreamPKCEVerifier) - upstreamResp, err := (&http.Client{Timeout: 10 * time.Second}).PostForm(tokenURL, form) + upstreamResp, err := (&http.Client{Timeout: oauthUpstreamHTTPTimeout}).PostForm(tokenURL, form) if err != nil { log.Error().Err(err).Str("token_url", tokenURL).Msg("OAuth /token: upstream code exchange transport error") writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "upstream code exchange failed") @@ -1324,10 +1330,29 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` Scope string `json:"scope"` + // Error is present on non-RFC-compliant IdPs that signal failure + // via 200 OK + RFC 6749 §5.2 error JSON. Status-only checks miss + // this; treating a non-empty Error as upstream rejection keeps the + // HA replay contract intact (downstream sees invalid_grant). + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + log.Error().Err(err).Msg("OAuth /token: upstream response not JSON") + writeOAuthTokenError(w, http.StatusBadGateway, "server_error", "upstream returned non-JSON response") + return + } + if tokenResp.Error != "" { + log.Warn(). + Int("status", upstreamResp.StatusCode). + Str("upstream_error", tokenResp.Error). + Str("client_id", truncateForLog(clientID, 80)). + Msg("OAuth /token: upstream 2xx with RFC 6749 error body — treat as invalid_grant") + writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "upstream rejected the authorization code") + return } - if err := json.Unmarshal(body, &tokenResp); err != nil || (tokenResp.AccessToken == "" && tokenResp.IDToken == "") { + if tokenResp.AccessToken == "" && tokenResp.IDToken == "" { log.Error(). - Err(err). Bool("has_access_token", tokenResp.AccessToken != ""). Bool("has_id_token", tokenResp.IDToken != ""). Msg("OAuth /token: upstream response missing usable token") @@ -1344,12 +1369,13 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re Msg("OAuth /token: upstream code exchange succeeded") // Validate the upstream identity before handing the bearer to the MCP - // client. We do NOT bind these claims into the downstream token in v1 - // (audience binding is deferred — see #115 § Non-goals); validation here - // exists purely to fail-fast on a malformed upstream response with a - // proper 502. Do not delete the underscored assignment without first - // re-introducing claim binding, or this side-effecting validation will - // look like dead code and get pruned. + // client. Claims are NOT bound into the downstream token (audience + // binding deferred per #115 § Non-goals); the validation has three + // jobs: fail-fast on a malformed upstream response with a proper 502, + // confirm the upstream id_token signature/audience for forward mode, + // and surface the id_token `exp` so we report an accurate `expires_in` + // to the MCP client below (used at the "expiresIn = identityClaims..." + // line further down). var identityClaims *altinitymcp.OAuthClaims if tokenResp.IDToken != "" { identityClaims, err = a.mcpServer.ValidateUpstreamIdentityToken(tokenResp.IDToken, cfg.Server.OAuth.ClientID) @@ -1366,8 +1392,6 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re return } } - _ = identityClaims // validation-only; intentionally unused in v1. - if tokenResp.Scope == "" { tokenResp.Scope = issued.Scope } @@ -1445,10 +1469,12 @@ func (a *application) registerOAuthHTTPRoutes(mux *http.ServeMux) { mux.HandleFunc(path, a.handleOAuthOpenIDConfiguration) } - // /oauth/register is intentionally NOT mounted: DCR was removed in - // favour of CIMD per #115. Old clients calling /oauth/register get the - // mux's default 404. The .well-known metadata no longer advertises - // registration_endpoint either. + // DCR was removed in favour of CIMD per #115. Mount /oauth/register + // with a stub that returns HTTP 410 Gone + an RFC 7591 §3.2.2-shaped + // JSON error so an in-the-wild DCR client sees a diagnosable response + // rather than the bare mux 404. + mux.HandleFunc(defaultRegistrationPath, handleOAuthRegisterRemoved) + for _, path := range uniquePaths(a.oauthAuthorizationPath(), defaultAuthorizationPath) { mux.HandleFunc(path, a.handleOAuthAuthorize) } diff --git a/cmd/altinity-mcp/oauth_server_test.go b/cmd/altinity-mcp/oauth_server_test.go index 93d9114..01fe94e 100644 --- a/cmd/altinity-mcp/oauth_server_test.go +++ b/cmd/altinity-mcp/oauth_server_test.go @@ -25,17 +25,6 @@ func decodeJWTSegment(seg string) ([]byte, error) { return base64.URLEncoding.DecodeString(seg) } -// TestOAuthJWEHKDFRoundtripAndLegacyFallback covers the v1 (HKDF) ↔ legacy -// (SHA256) compatibility surface introduced in Step 2 of the OAuth review. -// Three invariants: -// -// 1. Newly-issued artifacts emit kid="v1" in the JWE/JWS header. -// 2. v1 artifacts decrypt/verify with the matching HKDF-derived key — and -// ONLY with that key (a leak in one info-namespace doesn't compromise -// another). -// 3. Legacy artifacts (no kid, single SHA256(secret) key) still decrypt and -// verify, so existing refresh tokens / client_ids minted before the -// cutover keep working through the rotation window. func TestOAuthMCPAuthInjector(t *testing.T) { t.Parallel() @@ -382,38 +371,6 @@ func TestTruncateForLog(t *testing.T) { } } -func TestDecodeStringSlice(t *testing.T) { - t.Parallel() - t.Run("string_slice", func(t *testing.T) { - t.Parallel() - result := decodeStringSlice([]string{"a", "b"}) - require.Equal(t, []string{"a", "b"}, result) - }) - t.Run("interface_slice", func(t *testing.T) { - t.Parallel() - result := decodeStringSlice([]interface{}{"a", "b"}) - require.Equal(t, []string{"a", "b"}, result) - }) - t.Run("interface_slice_non_strings_skipped", func(t *testing.T) { - t.Parallel() - result := decodeStringSlice([]interface{}{"a", 123, "b"}) - require.Equal(t, []string{"a", "b"}, result) - }) - t.Run("nil_returns_nil", func(t *testing.T) { - t.Parallel() - require.Nil(t, decodeStringSlice(nil)) - }) - t.Run("unsupported_type_returns_nil", func(t *testing.T) { - t.Parallel() - require.Nil(t, decodeStringSlice("not-a-slice")) - }) - t.Run("empty_interface_slice", func(t *testing.T) { - t.Parallel() - result := decodeStringSlice([]interface{}{}) - require.Empty(t, result) - }) -} - func TestOAuthClaimsFromUserInfo(t *testing.T) { t.Parallel() t.Run("all_standard_fields", func(t *testing.T) { diff --git a/docs/oauth_authorization.md b/docs/oauth_authorization.md index b47a6fd..24bcc9f 100644 --- a/docs/oauth_authorization.md +++ b/docs/oauth_authorization.md @@ -1,5 +1,30 @@ # OAuth 2.0 Authorization for Altinity MCP Server +> **Updated 2026-05-15 (#115 landing):** Dynamic Client Registration (DCR) has +> been removed. Inbound MCP OAuth clients must use the spec-track replacement, +> OAuth Client ID Metadata Documents ([draft-ietf-oauth-client-id-metadata-document](https://datatracker.ietf.org/doc/draft-ietf-oauth-client-id-metadata-document/)). +> claude.ai and ChatGPT both publish CIMD documents today. The +> `/.well-known/oauth-authorization-server` document advertises +> `client_id_metadata_document_supported: true`, drops `registration_endpoint`, +> and lists `token_endpoint_auth_methods_supported: ["none"]` plus +> `grant_types_supported: ["authorization_code"]`. `/oauth/register` returns +> HTTP 410 Gone with an RFC 7591 §3.2.2-shaped JSON error. +> +> v1 issues **no downstream refresh tokens**. CIMD clients re-authorize via +> `/oauth/authorize` when the access token expires. The +> `upstream_offline_access` flag only controls whether `offline_access` is +> appended to the upstream scope (to influence the IdP's consent screen); any +> upstream refresh token returned is discarded. +> +> The HA replay model (#115 § HA replay) defers upstream authorization-code +> redemption from `/oauth/callback` to `/oauth/token` so the upstream IdP +> becomes the cross-replica replay oracle via `invalid_grant`. +> +> The rest of this document still describes the gating / forward / broker +> dichotomy and the trust model. Mentions of DCR below predate #115 and +> apply only to the upstream IdP side (Auth0 / Hydra / Keycloak), never to +> altinity-mcp itself. + This document explains how to configure OAuth 2.0 / OpenID Connect (OIDC) authentication with the Altinity MCP Server. ## Overview diff --git a/pkg/config/config.go b/pkg/config/config.go index c055237..886a25f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -150,10 +150,15 @@ type OAuthConfig struct { // Scopes is the list of OAuth scopes to request Scopes []string `json:"scopes" yaml:"scopes" flag:"oauth-scopes" env:"MCP_OAUTH_SCOPES" desc:"OAuth scopes to request"` - // UpstreamOfflineAccess opts forward mode into requesting offline_access from the upstream IdP - // and wrapping the returned refresh token in a stateless JWE handed back to the MCP client. - // Default false: forward mode behaves exactly as before (no refresh token issued, refresh grant rejected). - UpstreamOfflineAccess bool `json:"upstream_offline_access" yaml:"upstream_offline_access" flag:"oauth-upstream-offline-access" env:"MCP_OAUTH_UPSTREAM_OFFLINE_ACCESS" desc:"Forward mode: request offline_access upstream and issue JWE-wrapped refresh tokens"` + // UpstreamOfflineAccess opts forward/broker mode into appending + // `offline_access` to the scope sent upstream. Used mainly so the IdP's + // consent screen offers long-lived sessions; the upstream refresh token + // MCP receives is currently discarded. v1 issues NO downstream refresh + // tokens to CIMD clients — they re-authorize via /oauth/authorize when + // the access token expires. See #115 § Refresh-token policy. + // Default false. Effect is upstream-only; this flag does not turn on + // any downstream refresh-token issuance. + UpstreamOfflineAccess bool `json:"upstream_offline_access" yaml:"upstream_offline_access" flag:"oauth-upstream-offline-access" env:"MCP_OAUTH_UPSTREAM_OFFLINE_ACCESS" desc:"Append offline_access to the upstream scope so the IdP's consent screen offers long-lived sessions. v1 does NOT issue downstream refresh tokens regardless of this flag — clients re-authorize via /oauth/authorize."` // BrokerUpstream opts gating mode into the DCR-via-MCP broker pattern that // forward mode uses by default. When true under gating mode, altinity-mcp: @@ -195,9 +200,6 @@ type OAuthConfig struct { // Set true only when the IdP omits email_verified or the operator trusts upstream verification. AllowUnverifiedEmail bool `json:"allow_unverified_email" yaml:"allow_unverified_email" flag:"oauth-allow-unverified-email" env:"MCP_OAUTH_ALLOW_UNVERIFIED_EMAIL" desc:"Accept OAuth identities with email_verified=false (default: reject)"` - // RegistrationPath configures the relative path for dynamic client registration. - RegistrationPath string `json:"registration_path" yaml:"registration_path" flag:"oauth-registration-path" env:"MCP_OAUTH_REGISTRATION_PATH" desc:"Relative path for OAuth client registration endpoint"` - // AuthorizationPath configures the relative path for the authorization endpoint. AuthorizationPath string `json:"authorization_path" yaml:"authorization_path" flag:"oauth-authorization-path" env:"MCP_OAUTH_AUTHORIZATION_PATH" desc:"Relative path for OAuth authorization endpoint"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 502e5a1..21ea1bf 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -546,7 +546,6 @@ func TestConfigStructs(t *testing.T) { RequiredScopes: []string{"read"}, ClickHouseHeaderName: "X-Custom-Token", ClaimsToHeaders: map[string]string{"sub": "X-User", "email": "X-Email"}, - RegistrationPath: "/register", AuthorizationPath: "/authorize", CallbackPath: "/callback", TokenPath: "/token", @@ -570,7 +569,6 @@ func TestConfigStructs(t *testing.T) { require.Equal(t, "X-Custom-Token", cfg.ClickHouseHeaderName) require.Equal(t, "X-User", cfg.ClaimsToHeaders["sub"]) require.Equal(t, "X-Email", cfg.ClaimsToHeaders["email"]) - require.Equal(t, "/register", cfg.RegistrationPath) require.Equal(t, "/authorize", cfg.AuthorizationPath) require.Equal(t, "/callback", cfg.CallbackPath) require.Equal(t, "/token", cfg.TokenPath) @@ -611,7 +609,6 @@ server: client_secret: "secret-456" token_url: "https://auth.example.com/oauth/token" auth_url: "https://auth.example.com/oauth/authorize" - registration_path: "/register" authorization_path: "/authorize" callback_path: "/callback" token_path: "/token" @@ -660,7 +657,6 @@ logging: require.Equal(t, "X-Custom-Token", cfg.Server.OAuth.ClickHouseHeaderName) require.Equal(t, "X-ClickHouse-User", cfg.Server.OAuth.ClaimsToHeaders["sub"]) require.Equal(t, "X-ClickHouse-Email", cfg.Server.OAuth.ClaimsToHeaders["email"]) - require.Equal(t, "/register", cfg.Server.OAuth.RegistrationPath) require.Equal(t, "/authorize", cfg.Server.OAuth.AuthorizationPath) require.Equal(t, "/callback", cfg.Server.OAuth.CallbackPath) require.Equal(t, "/token", cfg.Server.OAuth.TokenPath)