Skip to content
Merged
30 changes: 20 additions & 10 deletions pkg/context/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,37 @@ import (
"github.com/github/github-mcp-server/pkg/utils"
)

// tokenCtxKey is a context key for authentication token information
type tokenCtx string

var tokenCtxKey tokenCtx = "tokenctx"
type tokenCtxKey struct{}

type TokenInfo struct {
Token string
TokenType utils.TokenType
ScopesFetched bool
Scopes []string
Token string
TokenType utils.TokenType
}

// WithTokenInfo adds TokenInfo to the context
func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context {
return context.WithValue(ctx, tokenCtxKey, tokenInfo)
return context.WithValue(ctx, tokenCtxKey{}, tokenInfo)
}

// GetTokenInfo retrieves the authentication token from the context
func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok {
if tokenInfo, ok := ctx.Value(tokenCtxKey{}).(*TokenInfo); ok {
return tokenInfo, true
}
return nil, false
}

type tokenScopesKey struct{}

// WithTokenScopes adds token scopes to the context
func WithTokenScopes(ctx context.Context, scopes []string) context.Context {
return context.WithValue(ctx, tokenScopesKey{}, scopes)
}

// GetTokenScopes retrieves token scopes from the context
func GetTokenScopes(ctx context.Context) ([]string, bool) {
if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok {
return scopes, true
}
return nil, false
}
15 changes: 13 additions & 2 deletions pkg/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"context"
"errors"
"log/slog"
"net/http"

Expand Down Expand Up @@ -178,6 +179,14 @@ func withInsiders(next http.Handler) http.Handler {
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
inv, err := h.inventoryFactoryFunc(r)
if err != nil {
if errors.Is(err, inventory.ErrUnknownTools) {
w.WriteHeader(http.StatusBadRequest)
if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
h.logger.Error("failed to write response", "error", writeErr)
}
return
}

w.WriteHeader(http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -278,8 +287,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
// Fine-grained PATs and other token types don't support this, so we skip filtering.
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
if tokenInfo.ScopesFetched {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, was there a reason we needed this flag at all?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to handle the scenario where we had a token info from copilot API, but it was a token with no scopes. But having it combined with the TokenInfo meant the ordering was a little awkward.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also implied now, if ok is false from scopes, ok := ghcontext.GetTokenScopes(ctx)

return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes))
// Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them.
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
if ok {
return b.WithFilter(github.CreateToolScopeFilter(existingScopes))
}

scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)
Expand Down
12 changes: 8 additions & 4 deletions pkg/http/middleware/pat_scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,22 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
// Fine-grained PATs and other token types don't support this, so we skip filtering.
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
if ok {
logger.Debug("using existing scopes from context", "scopes", existingScopes)
next.ServeHTTP(w, r)
return
}

scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
if err != nil {
logger.Warn("failed to fetch PAT scopes", "error", err)
next.ServeHTTP(w, r)
return
}

tokenInfo.Scopes = scopesList
tokenInfo.ScopesFetched = true

// Store fetched scopes in context for downstream use
ctx := ghcontext.WithTokenInfo(ctx, tokenInfo)
ctx = ghcontext.WithTokenScopes(ctx, scopesList)

next.ServeHTTP(w, r.WithContext(ctx))
return
Expand Down
19 changes: 11 additions & 8 deletions pkg/http/middleware/pat_scope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedTokenInfo *ghcontext.TokenInfo
var capturedScopes []string
var scopesFound bool
var nextHandlerCalled bool

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextHandlerCalled = true
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
w.WriteHeader(http.StatusOK)
})

Expand All @@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) {

assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch")

if tt.expectNextHandlerCalled && tt.tokenInfo != nil {
require.NotNil(t, capturedTokenInfo, "expected token info in context")
assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched)
assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes)
if tt.expectNextHandlerCalled {
assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch")
assert.Equal(t, tt.expectedScopes, capturedScopes)
}
})
}
Expand All @@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
logger := slog.Default()

var capturedTokenInfo *ghcontext.TokenInfo
var capturedScopes []string
var scopesFound bool

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
w.WriteHeader(http.StatusOK)
})

Expand All @@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
require.NotNil(t, capturedTokenInfo)
assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token)
assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType)
assert.True(t, capturedTokenInfo.ScopesFetched)
assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes)
assert.True(t, scopesFound)
assert.Equal(t, []string{"repo", "user"}, capturedScopes)
}
18 changes: 10 additions & 8 deletions pkg/http/middleware/scope_challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,19 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
return
}

// Get OAuth scopes from GitHub API
activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
if err != nil {
next.ServeHTTP(w, r)
return
// Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present.
// This allows Remote Server to pass scope info to avoid redundant GitHub API calls.
activeScopes, ok := ghcontext.GetTokenScopes(ctx)
if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") {
activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
if err != nil {
next.ServeHTTP(w, r)
return
}
}

// Store active scopes in context for downstream use
tokenInfo.Scopes = activeScopes
tokenInfo.ScopesFetched = true
ctx = ghcontext.WithTokenInfo(ctx, tokenInfo)
ctx = ghcontext.WithTokenScopes(ctx, activeScopes)
r = r.WithContext(ctx)
Comment on lines 97 to 110
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior reads scopes from context to avoid redundant GitHub API calls, but there’s no unit test coverage for (1) scopes already present in context skipping FetchTokenScopes, and (2) scopes being stored back into context for downstream use. Adding a table-driven test with a mock fetcher would prevent regressions here.

Copilot generated this review using guidance from repository custom instructions.

// Check if user has the required scopes
Expand Down
11 changes: 10 additions & 1 deletion pkg/http/middleware/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ import (
func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Check if token info already exists in context, if it does, skip extraction.
// In remote setup, we may have already extracted token info earlier.
if _, ok := ghcontext.GetTokenInfo(ctx); ok {
// Token info already exists in context, skip extraction
next.ServeHTTP(w, r)
return
}
Comment on lines +18 to +24
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This middleware now short-circuits when TokenInfo is already present in context, but token_test.go doesn’t cover the new branch. Add a test that pre-populates TokenInfo in the request context and asserts the middleware does not overwrite it and does not require an Authorization header.

Copilot generated this review using guidance from repository custom instructions.

tokenType, token, err := utils.ParseAuthorizationHeader(r)
if err != nil {
// For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec
Expand All @@ -25,7 +35,6 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl
return
}

ctx := r.Context()
ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{
Token: token,
TokenType: tokenType,
Expand Down
8 changes: 7 additions & 1 deletion pkg/inventory/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@ package inventory

import (
"context"
"errors"
"fmt"
"maps"
"slices"
"strings"
)

var (
// ErrUnknownTools is returned when tools specified via WithTools() are not recognized.
ErrUnknownTools = errors.New("unknown tools specified in WithTools")
)

// ToolFilter is a function that determines if a tool should be included.
// Returns true if the tool should be included, false to exclude it.
type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error)
Expand Down Expand Up @@ -219,7 +225,7 @@ func (b *Builder) Build() (*Inventory, error) {

// Error out if there are unrecognized tools
if len(unrecognizedTools) > 0 {
return nil, fmt.Errorf("unrecognized tools: %s", strings.Join(unrecognizedTools, ", "))
return nil, fmt.Errorf("%w: %s", ErrUnknownTools, strings.Join(unrecognizedTools, ", "))
}
}

Expand Down