diff --git a/pkg/context/token.go b/pkg/context/token.go index beddb02b2..97091a922 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -6,27 +6,37 @@ import ( "github.com/github/github-mcp-server/pkg/utils" ) -// tokenCtxKey is a context key for authentication token information -type tokenCtx string - -var tokenCtxKey tokenCtx = "tokenctx" +type tokenCtxKey struct{} type TokenInfo struct { - Token string - TokenType utils.TokenType - ScopesFetched bool - Scopes []string + Token string + TokenType utils.TokenType } // WithTokenInfo adds TokenInfo to the context func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context { - return context.WithValue(ctx, tokenCtxKey, tokenInfo) + return context.WithValue(ctx, tokenCtxKey{}, tokenInfo) } // GetTokenInfo retrieves the authentication token from the context func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { - if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok { + if tokenInfo, ok := ctx.Value(tokenCtxKey{}).(*TokenInfo); ok { return tokenInfo, true } return nil, false } + +type tokenScopesKey struct{} + +// WithTokenScopes adds token scopes to the context +func WithTokenScopes(ctx context.Context, scopes []string) context.Context { + return context.WithValue(ctx, tokenScopesKey{}, scopes) +} + +// GetTokenScopes retrieves token scopes from the context +func GetTokenScopes(ctx context.Context) ([]string, bool) { + if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok { + return scopes, true + } + return nil, false +} diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 875d54bbb..3c6c5302e 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -2,6 +2,7 @@ package http import ( "context" + "errors" "log/slog" "net/http" @@ -178,6 +179,14 @@ func withInsiders(next http.Handler) http.Handler { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { inv, err := h.inventoryFactoryFunc(r) if err != nil { + if errors.Is(err, inventory.ErrUnknownTools) { + w.WriteHeader(http.StatusBadRequest) + if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil { + h.logger.Error("failed to write response", "error", writeErr) + } + return + } + w.WriteHeader(http.StatusInternalServerError) return } @@ -278,8 +287,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { - if tokenInfo.ScopesFetched { - return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes)) + // Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them. + existingScopes, ok := ghcontext.GetTokenScopes(ctx) + if ok { + return b.WithFilter(github.CreateToolScopeFilter(existingScopes)) } scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) diff --git a/pkg/http/middleware/pat_scope.go b/pkg/http/middleware/pat_scope.go index 8b77b3d32..bb1efdc01 100644 --- a/pkg/http/middleware/pat_scope.go +++ b/pkg/http/middleware/pat_scope.go @@ -26,6 +26,13 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + existingScopes, ok := ghcontext.GetTokenScopes(ctx) + if ok { + logger.Debug("using existing scopes from context", "scopes", existingScopes) + next.ServeHTTP(w, r) + return + } + scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) if err != nil { logger.Warn("failed to fetch PAT scopes", "error", err) @@ -33,11 +40,8 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu return } - tokenInfo.Scopes = scopesList - tokenInfo.ScopesFetched = true - // Store fetched scopes in context for downstream use - ctx := ghcontext.WithTokenInfo(ctx, tokenInfo) + ctx = ghcontext.WithTokenScopes(ctx, scopesList) next.ServeHTTP(w, r.WithContext(ctx)) return diff --git a/pkg/http/middleware/pat_scope_test.go b/pkg/http/middleware/pat_scope_test.go index eb472bcf1..0607b8cf2 100644 --- a/pkg/http/middleware/pat_scope_test.go +++ b/pkg/http/middleware/pat_scope_test.go @@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var capturedTokenInfo *ghcontext.TokenInfo + var capturedScopes []string + var scopesFound bool var nextHandlerCalled bool nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextHandlerCalled = true - capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context()) w.WriteHeader(http.StatusOK) }) @@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) { assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch") - if tt.expectNextHandlerCalled && tt.tokenInfo != nil { - require.NotNil(t, capturedTokenInfo, "expected token info in context") - assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched) - assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes) + if tt.expectNextHandlerCalled { + assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch") + assert.Equal(t, tt.expectedScopes, capturedScopes) } }) } @@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { logger := slog.Default() var capturedTokenInfo *ghcontext.TokenInfo + var capturedScopes []string + var scopesFound bool nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context()) w.WriteHeader(http.StatusOK) }) @@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { require.NotNil(t, capturedTokenInfo) assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token) assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType) - assert.True(t, capturedTokenInfo.ScopesFetched) - assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes) + assert.True(t, scopesFound) + assert.Equal(t, []string{"repo", "user"}, capturedScopes) } diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 526797241..1a86bf93c 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -94,17 +94,19 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter return } - // Get OAuth scopes from GitHub API - activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) - if err != nil { - next.ServeHTTP(w, r) - return + // Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present. + // This allows Remote Server to pass scope info to avoid redundant GitHub API calls. + activeScopes, ok := ghcontext.GetTokenScopes(ctx) + if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") { + activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + next.ServeHTTP(w, r) + return + } } // Store active scopes in context for downstream use - tokenInfo.Scopes = activeScopes - tokenInfo.ScopesFetched = true - ctx = ghcontext.WithTokenInfo(ctx, tokenInfo) + ctx = ghcontext.WithTokenScopes(ctx, activeScopes) r = r.WithContext(ctx) // Check if user has the required scopes diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index c362ea201..012bbabef 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -13,6 +13,16 @@ import ( func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Check if token info already exists in context, if it does, skip extraction. + // In remote setup, we may have already extracted token info earlier. + if _, ok := ghcontext.GetTokenInfo(ctx); ok { + // Token info already exists in context, skip extraction + next.ServeHTTP(w, r) + return + } + tokenType, token, err := utils.ParseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec @@ -25,7 +35,6 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl return } - ctx := r.Context() ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ Token: token, TokenType: tokenType, diff --git a/pkg/inventory/builder.go b/pkg/inventory/builder.go index 35ccd5932..6d2f080aa 100644 --- a/pkg/inventory/builder.go +++ b/pkg/inventory/builder.go @@ -2,12 +2,18 @@ package inventory import ( "context" + "errors" "fmt" "maps" "slices" "strings" ) +var ( + // ErrUnknownTools is returned when tools specified via WithTools() are not recognized. + ErrUnknownTools = errors.New("unknown tools specified in WithTools") +) + // ToolFilter is a function that determines if a tool should be included. // Returns true if the tool should be included, false to exclude it. type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) @@ -219,7 +225,7 @@ func (b *Builder) Build() (*Inventory, error) { // Error out if there are unrecognized tools if len(unrecognizedTools) > 0 { - return nil, fmt.Errorf("unrecognized tools: %s", strings.Join(unrecognizedTools, ", ")) + return nil, fmt.Errorf("%w: %s", ErrUnknownTools, strings.Join(unrecognizedTools, ", ")) } }