From 799e80e1233dfc4ae07b46de8a4fffa235cfdc4e Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 22:39:32 +0000 Subject: [PATCH 01/11] Reorganize incoming auth factory into subfolder to prevent an import cycle Move incoming authentication factory from pkg/vmcp/auth/ to pkg/vmcp/auth/factory/ subfolder to improve code organization. This separates factory code from core authentication types and middleware. --- pkg/vmcp/auth/{incoming_factory.go => factory/incoming.go} | 0 .../auth/{incoming_factory_test.go => factory/incoming_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename pkg/vmcp/auth/{incoming_factory.go => factory/incoming.go} (100%) rename pkg/vmcp/auth/{incoming_factory_test.go => factory/incoming_test.go} (100%) diff --git a/pkg/vmcp/auth/incoming_factory.go b/pkg/vmcp/auth/factory/incoming.go similarity index 100% rename from pkg/vmcp/auth/incoming_factory.go rename to pkg/vmcp/auth/factory/incoming.go diff --git a/pkg/vmcp/auth/incoming_factory_test.go b/pkg/vmcp/auth/factory/incoming_test.go similarity index 100% rename from pkg/vmcp/auth/incoming_factory_test.go rename to pkg/vmcp/auth/factory/incoming_test.go From 8bb674d4c87263146d7f65fc99604a1b94749d66 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 13:31:15 +0000 Subject: [PATCH 02/11] Refactor outgoing auth to separate registry from strategy Rename OutgoingAuthenticator to OutgoingAuthRegistry to better reflect its responsibility as a strategy registry rather than an authenticator. The interface now focuses solely on strategy management (registration and retrieval), while authentication is performed directly by Strategy implementations. This separation improves performance by eliminating indirection in the hot path (per-request authentication) and clarifies the single responsibility of each component: the registry manages strategies, strategies perform authentication. --- pkg/vmcp/auth/auth.go | 45 +- pkg/vmcp/auth/outgoing_authenticator.go | 130 ------ pkg/vmcp/auth/outgoing_authenticator_test.go | 455 ------------------- pkg/vmcp/auth/outgoing_registry.go | 103 +++++ pkg/vmcp/auth/outgoing_registry_test.go | 263 +++++++++++ pkg/vmcp/doc.go | 14 +- pkg/vmcp/types.go | 2 +- 7 files changed, 404 insertions(+), 608 deletions(-) delete mode 100644 pkg/vmcp/auth/outgoing_authenticator.go delete mode 100644 pkg/vmcp/auth/outgoing_authenticator_test.go create mode 100644 pkg/vmcp/auth/outgoing_registry.go create mode 100644 pkg/vmcp/auth/outgoing_registry_test.go diff --git a/pkg/vmcp/auth/auth.go b/pkg/vmcp/auth/auth.go index 76f9626eb..455b6e71c 100644 --- a/pkg/vmcp/auth/auth.go +++ b/pkg/vmcp/auth/auth.go @@ -1,7 +1,7 @@ // Package auth provides authentication for Virtual MCP Server. // // This package defines: -// - OutgoingAuthenticator: Authenticates vMCP to backend servers +// - OutgoingAuthRegistry: Registry for managing backend authentication strategies // - Strategy: Pluggable authentication strategies for backends // // Incoming authentication uses pkg/auth middleware (OIDC, local, anonymous) @@ -17,24 +17,39 @@ import ( "github.com/stacklok/toolhive/pkg/auth" ) -// OutgoingAuthenticator handles authentication to backend MCP servers. -// This is responsible for obtaining and injecting appropriate credentials -// for each backend based on its authentication strategy. +// OutgoingAuthRegistry manages authentication strategies for outgoing requests to backend MCP servers. +// This is a registry that stores and retrieves Strategy implementations. // -// The specific authentication strategies and their behavior will be defined -// during implementation based on the design decisions documented in the -// Virtual MCP Server proposal. -type OutgoingAuthenticator interface { - // AuthenticateRequest adds authentication to an outgoing backend request. - // The strategy and metadata are provided in the BackendTarget. - AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error - - // GetStrategy returns the authentication strategy handler for a given strategy name. - // This enables extensibility - new strategies can be registered. +// The registry supports dynamic strategy registration, allowing custom authentication +// strategies to be added at runtime. Once registered, strategies can be retrieved +// by name and used to authenticate requests to backends. +// +// Responsibilities: +// - Maintain registry of available strategies +// - Retrieve strategies by name +// - Register new strategies dynamically +// +// This registry does NOT perform authentication itself. Authentication is performed +// by Strategy implementations retrieved from this registry. +// +// Usage Pattern: +// 1. Register strategies during application initialization +// 2. Resolve strategy once at client creation time (cold path) +// 3. Call strategy.Authenticate() directly per-request (hot path) +// +// Thread-safety: Implementations must be safe for concurrent access. +type OutgoingAuthRegistry interface { + // GetStrategy retrieves an authentication strategy by name. + // Returns an error if the strategy is not found. GetStrategy(name string) (Strategy, error) // RegisterStrategy registers a new authentication strategy. - // This allows custom auth strategies to be added at runtime. + // The strategy name must match the name returned by strategy.Name(). + // Returns an error if: + // - name is empty + // - strategy is nil + // - a strategy with the same name is already registered + // - strategy.Name() does not match the registration name RegisterStrategy(name string, strategy Strategy) error } diff --git a/pkg/vmcp/auth/outgoing_authenticator.go b/pkg/vmcp/auth/outgoing_authenticator.go deleted file mode 100644 index 6498f68dd..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator.go +++ /dev/null @@ -1,130 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "net/http" - "sync" -) - -// DefaultOutgoingAuthenticator is a thread-safe implementation of OutgoingAuthenticator -// that maintains a registry of authentication strategies. -// -// Thread-safety: Safe for concurrent calls to RegisterStrategy and AuthenticateRequest. -// Strategy implementations must be thread-safe as they are called concurrently. -// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. -// -// This authenticator supports dynamic registration of strategies and dispatches -// authentication requests to the appropriate strategy based on the strategy name. -// -// Example usage: -// -// auth := NewDefaultOutgoingAuthenticator() -// auth.RegisterStrategy("bearer", NewBearerStrategy()) -// err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) -type DefaultOutgoingAuthenticator struct { - strategies map[string]Strategy - mu sync.RWMutex -} - -// NewDefaultOutgoingAuthenticator creates a new DefaultOutgoingAuthenticator -// with an empty strategy registry. -// -// Strategies must be registered using RegisterStrategy before they can be used -// for authentication. -func NewDefaultOutgoingAuthenticator() *DefaultOutgoingAuthenticator { - return &DefaultOutgoingAuthenticator{ - strategies: make(map[string]Strategy), - } -} - -// RegisterStrategy registers a new authentication strategy. -// -// This method is thread-safe and validates that: -// - name is not empty -// - strategy is not nil -// - no strategy is already registered with the same name -// -// Parameters: -// - name: The unique identifier for this strategy -// - strategy: The Strategy implementation to register -// -// Returns an error if validation fails or a strategy with the same name -// already exists. -func (a *DefaultOutgoingAuthenticator) RegisterStrategy(name string, strategy Strategy) error { - if name == "" { - return errors.New("strategy name cannot be empty") - } - if strategy == nil { - return errors.New("strategy cannot be nil") - } - - a.mu.Lock() - defer a.mu.Unlock() - - if _, exists := a.strategies[name]; exists { - return fmt.Errorf("strategy %q is already registered", name) - } - - a.strategies[name] = strategy - return nil -} - -// GetStrategy retrieves an authentication strategy by name. -// -// This method is thread-safe for concurrent reads. It returns the strategy -// if found, or an error if no strategy is registered with the given name. -// -// Parameters: -// - name: The identifier of the strategy to retrieve -// -// Returns: -// - Strategy: The registered strategy -// - error: An error if the strategy is not found -func (a *DefaultOutgoingAuthenticator) GetStrategy(name string) (Strategy, error) { - a.mu.RLock() - defer a.mu.RUnlock() - - strategy, exists := a.strategies[name] - if !exists { - return nil, fmt.Errorf("strategy %q not found", name) - } - - return strategy, nil -} - -// AuthenticateRequest adds authentication to an outgoing backend request. -// -// This method retrieves the specified strategy and delegates authentication -// to it. The strategy modifies the request by adding appropriate headers, -// tokens, or other authentication artifacts. -// -// Parameters: -// - ctx: Request context (may contain identity for pass-through auth) -// - req: The HTTP request to authenticate -// - strategyName: The name of the strategy to use -// - metadata: Strategy-specific configuration -// -// Returns an error if: -// - The strategy is not found -// - The metadata validation fails -// - The strategy's Authenticate method fails -func (a *DefaultOutgoingAuthenticator) AuthenticateRequest( - ctx context.Context, - req *http.Request, - strategyName string, - metadata map[string]any, -) error { - strategy, err := a.GetStrategy(strategyName) - if err != nil { - return err - } - - // Validate metadata before using it - if err := strategy.Validate(metadata); err != nil { - return fmt.Errorf("invalid metadata for strategy %q: %w", strategyName, err) - } - - return strategy.Authenticate(ctx, req, metadata) -} diff --git a/pkg/vmcp/auth/outgoing_authenticator_test.go b/pkg/vmcp/auth/outgoing_authenticator_test.go deleted file mode 100644 index 43073bc7d..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator_test.go +++ /dev/null @@ -1,455 +0,0 @@ -package auth - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" -) - -type testContextKey struct{} - -var testKey = testContextKey{} - -func TestDefaultOutgoingAuthenticator_RegisterStrategy(t *testing.T) { - t.Parallel() - t.Run("register valid strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - - err := auth.RegisterStrategy("bearer", strategy) - - require.NoError(t, err) - // Verify strategy was registered - retrieved, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("register empty name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("", strategy) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy name cannot be empty") - }) - - t.Run("register nil strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - err := auth.RegisterStrategy("bearer", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy cannot be nil") - }) - - t.Run("register duplicate name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy1 := mocks.NewMockStrategy(ctrl) - strategy1.EXPECT().Name().Return("bearer").AnyTimes() - strategy2 := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("bearer", strategy1) - require.NoError(t, err) - - err = auth.RegisterStrategy("bearer", strategy2) - assert.Error(t, err) - assert.Contains(t, err.Error(), "already registered") - assert.Contains(t, err.Error(), "bearer") - }) - - t.Run("register multiple different strategies succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - bearer := mocks.NewMockStrategy(ctrl) - bearer.EXPECT().Name().Return("bearer").AnyTimes() - basic := mocks.NewMockStrategy(ctrl) - basic.EXPECT().Name().Return("basic").AnyTimes() - apiKey := mocks.NewMockStrategy(ctrl) - apiKey.EXPECT().Name().Return("api-key").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("bearer", bearer)) - require.NoError(t, auth.RegisterStrategy("basic", basic)) - require.NoError(t, auth.RegisterStrategy("api-key", apiKey)) - - // Verify all strategies are registered - s1, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, bearer, s1) - - s2, err := auth.GetStrategy("basic") - require.NoError(t, err) - assert.Equal(t, basic, s2) - - s3, err := auth.GetStrategy("api-key") - require.NoError(t, err) - assert.Equal(t, apiKey, s3) - }) -} - -func TestDefaultOutgoingAuthenticator_GetStrategy(t *testing.T) { - t.Parallel() - t.Run("get existing strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - retrieved, err := auth.GetStrategy("bearer") - - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("get non-existent strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("non-existent") - - assert.Error(t, err) - assert.Nil(t, retrieved) - assert.Contains(t, err.Error(), "not found") - assert.Contains(t, err.Error(), "non-existent") - }) - - t.Run("get from empty registry fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("bearer") - - assert.Error(t, err) - assert.Nil(t, retrieved) - }) -} - -func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) { - t.Parallel() - t.Run("authenticates with valid strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - // Add a header to verify the request was modified - req.Header.Set("Authorization", "Bearer token123") - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"token": "token123"} - err := auth.AuthenticateRequest(context.Background(), req, "bearer", metadata) - - require.NoError(t, err) - assert.Equal(t, "Bearer token123", req.Header.Get("Authorization")) - }) - - t.Run("fails with non-existent strategy", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - - err := auth.AuthenticateRequest(context.Background(), req, "non-existent", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error from strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategyErr := errors.New("authentication failed") - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).Return(strategyErr) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - - assert.Error(t, err) - assert.Equal(t, strategyErr, err) - }) - - t.Run("passes context and metadata to strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - var receivedCtx context.Context - var receivedMetadata map[string]any - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, _ *http.Request, metadata map[string]any) error { - receivedCtx = ctx - receivedMetadata = metadata - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - ctx := context.WithValue(context.Background(), testKey, "test-value") - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{ - "token": "abc123", - "scopes": []string{"read", "write"}, - } - - err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) - - require.NoError(t, err) - assert.NotNil(t, receivedCtx) - assert.Equal(t, "test-value", receivedCtx.Value(testKey)) - assert.Equal(t, metadata, receivedMetadata) - }) - - t.Run("validates metadata before authentication", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("test-strategy").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("test-strategy", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"invalid": "data"} - - // Expect Validate to be called and return error - strategy.EXPECT(). - Validate(metadata). - Return(errors.New("invalid metadata")) - - // Authenticate should NOT be called if validation fails - // (no EXPECT for Authenticate) - - err := auth.AuthenticateRequest(context.Background(), req, "test-strategy", metadata) - - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid metadata for strategy") - assert.Contains(t, err.Error(), "test-strategy") - }) -} - -func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) { - t.Parallel() - t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Register multiple strategies - strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} - for _, name := range strategies { - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return(name).AnyTimes() - require.NoError(t, auth.RegisterStrategy(name, strategy)) - } - - // Test concurrent reads with -race detector - const numGoroutines = 100 - const numOperations = 1000 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines*numOperations) - - for i := 0; i < numGoroutines; i++ { - go func(_ int) { - defer wg.Done() - for j := 0; j < numOperations; j++ { - // Rotate through strategies - strategyName := strategies[j%len(strategies)] - strategy, err := auth.GetStrategy(strategyName) - if err != nil { - errs <- err - return - } - if strategy.Name() != strategyName { - errs <- errors.New("strategy name mismatch") - return - } - } - }(i) - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent access produced errors: %v", collectedErrors) - } - }) - - t.Run("concurrent AuthenticateRequest calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Counter to verify all authentications happen - var authCount int64 - var authMu sync.Mutex - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil).AnyTimes() - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - authMu.Lock() - authCount++ - authMu.Unlock() - req.Header.Set("Authorization", "Bearer test") - return nil - }, - ).AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - const numGoroutines = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - if err != nil { - errs <- err - } - }() - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent AuthenticateRequest produced errors: %v", collectedErrors) - } - - // Verify all authentications completed - assert.Equal(t, int64(numGoroutines), authCount) - }) - - t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - const numRegister = 50 - const numGet = 50 - - var wg sync.WaitGroup - wg.Add(numRegister + numGet) - - errs := make(chan error, numRegister+numGet) - - // Goroutines registering strategies - for i := 0; i < numRegister; i++ { - go func(id int) { - defer wg.Done() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("strategy").AnyTimes() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - err := auth.RegisterStrategy(strategyName, strategy) - if err != nil { - errs <- err - } - }(i) - } - - // Goroutines reading strategies (will mostly fail, but shouldn't race) - for i := 0; i < numGet; i++ { - go func(id int) { - defer wg.Done() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - // GetStrategy may return error if not registered yet, that's OK - _, _ = auth.GetStrategy(strategyName) - }(i) - } - - wg.Wait() - close(errs) - - // Check for unexpected errors (registration errors are not expected) - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) - } - }) -} diff --git a/pkg/vmcp/auth/outgoing_registry.go b/pkg/vmcp/auth/outgoing_registry.go new file mode 100644 index 000000000..04f2513a3 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry.go @@ -0,0 +1,103 @@ +package auth + +import ( + "errors" + "fmt" + "sync" +) + +// DefaultOutgoingAuthRegistry is a thread-safe implementation of OutgoingAuthRegistry +// that maintains a registry of authentication strategies. +// +// Thread-safety: Safe for concurrent calls to RegisterStrategy and GetStrategy. +// Strategy implementations must be thread-safe as they are called concurrently. +// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. +// +// This registry supports dynamic registration of strategies and retrieval by name. +// It does not perform authentication itself - that is done by the Strategy implementations. +// +// Example usage: +// +// registry := NewDefaultOutgoingAuthRegistry() +// registry.RegisterStrategy("header_injection", NewHeaderInjectionStrategy()) +// strategy, err := registry.GetStrategy("header_injection") +// if err == nil { +// err = strategy.Authenticate(ctx, req, metadata) +// } +type DefaultOutgoingAuthRegistry struct { + strategies map[string]Strategy + mu sync.RWMutex +} + +// NewDefaultOutgoingAuthRegistry creates a new DefaultOutgoingAuthRegistry +// with an empty strategy registry. +// +// Strategies must be registered using RegisterStrategy before they can be used +// for authentication. +func NewDefaultOutgoingAuthRegistry() *DefaultOutgoingAuthRegistry { + return &DefaultOutgoingAuthRegistry{ + strategies: make(map[string]Strategy), + } +} + +// RegisterStrategy registers a new authentication strategy. +// +// This method is thread-safe and validates that: +// - name is not empty +// - strategy is not nil +// - strategy.Name() matches the registration name +// - no strategy is already registered with the same name +// +// Parameters: +// - name: The unique identifier for this strategy +// - strategy: The Strategy implementation to register +// +// Returns an error if validation fails or a strategy with the same name +// already exists. +func (r *DefaultOutgoingAuthRegistry) RegisterStrategy(name string, strategy Strategy) error { + if name == "" { + return errors.New("strategy name cannot be empty") + } + if strategy == nil { + return errors.New("strategy cannot be nil") + } + + // Validate that strategy name matches registration name + if name != strategy.Name() { + return fmt.Errorf("strategy name mismatch: registered as %q but strategy.Name() returns %q", + name, strategy.Name()) + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.strategies[name]; exists { + return fmt.Errorf("strategy %q is already registered", name) + } + + r.strategies[name] = strategy + return nil +} + +// GetStrategy retrieves an authentication strategy by name. +// +// This method is thread-safe for concurrent reads. It returns the strategy +// if found, or an error if no strategy is registered with the given name. +// +// Parameters: +// - name: The identifier of the strategy to retrieve +// +// Returns: +// - Strategy: The registered strategy +// - error: An error if the strategy is not found +func (r *DefaultOutgoingAuthRegistry) GetStrategy(name string) (Strategy, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + strategy, exists := r.strategies[name] + if !exists { + return nil, fmt.Errorf("strategy %q not found", name) + } + + return strategy, nil +} diff --git a/pkg/vmcp/auth/outgoing_registry_test.go b/pkg/vmcp/auth/outgoing_registry_test.go new file mode 100644 index 000000000..3d2e8a495 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry_test.go @@ -0,0 +1,263 @@ +package auth + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" +) + +func TestDefaultOutgoingAuthRegistry_RegisterStrategy(t *testing.T) { + t.Parallel() + t.Run("register valid strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy) + + require.NoError(t, err) + // Verify strategy was registered + retrieved, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("register empty name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + + err := registry.RegisterStrategy("", strategy) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy name cannot be empty") + }) + + t.Run("register nil strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + err := registry.RegisterStrategy("bearer", nil) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy cannot be nil") + }) + + t.Run("register duplicate name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy1 := mocks.NewMockStrategy(ctrl) + strategy1.EXPECT().Name().Return("bearer").AnyTimes() + strategy2 := mocks.NewMockStrategy(ctrl) + strategy2.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy1) + require.NoError(t, err) + + err = registry.RegisterStrategy("bearer", strategy2) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already registered") + assert.Contains(t, err.Error(), "bearer") + }) + + t.Run("register multiple different strategies succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + bearer := mocks.NewMockStrategy(ctrl) + bearer.EXPECT().Name().Return("bearer").AnyTimes() + basic := mocks.NewMockStrategy(ctrl) + basic.EXPECT().Name().Return("basic").AnyTimes() + apiKey := mocks.NewMockStrategy(ctrl) + apiKey.EXPECT().Name().Return("api-key").AnyTimes() + + require.NoError(t, registry.RegisterStrategy("bearer", bearer)) + require.NoError(t, registry.RegisterStrategy("basic", basic)) + require.NoError(t, registry.RegisterStrategy("api-key", apiKey)) + + // Verify all strategies are registered + s1, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, bearer, s1) + + s2, err := registry.GetStrategy("basic") + require.NoError(t, err) + assert.Equal(t, basic, s2) + + s3, err := registry.GetStrategy("api-key") + require.NoError(t, err) + assert.Equal(t, apiKey, s3) + }) +} + +func TestDefaultOutgoingAuthRegistry_GetStrategy(t *testing.T) { + t.Parallel() + t.Run("get existing strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + require.NoError(t, registry.RegisterStrategy("bearer", strategy)) + + retrieved, err := registry.GetStrategy("bearer") + + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("get non-existent strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("non-existent") + + assert.Error(t, err) + assert.Nil(t, retrieved) + assert.Contains(t, err.Error(), "not found") + assert.Contains(t, err.Error(), "non-existent") + }) + + t.Run("get from empty registry fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("bearer") + + assert.Error(t, err) + assert.Nil(t, retrieved) + }) +} + +func TestDefaultOutgoingAuthRegistry_ConcurrentAccess(t *testing.T) { + t.Parallel() + t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + // Register multiple strategies + strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} + for _, name := range strategies { + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(name).AnyTimes() + require.NoError(t, registry.RegisterStrategy(name, strategy)) + } + + // Test concurrent reads with -race detector + const numGoroutines = 100 + const numOperations = 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + errs := make(chan error, numGoroutines*numOperations) + + for i := 0; i < numGoroutines; i++ { + go func(_ int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + // Rotate through strategies + strategyName := strategies[j%len(strategies)] + strategy, err := registry.GetStrategy(strategyName) + if err != nil { + errs <- err + return + } + if strategy.Name() != strategyName { + errs <- errors.New("strategy name mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for errors + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent access produced errors: %v", collectedErrors) + } + }) + + t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + const numRegister = 50 + const numGet = 50 + + var wg sync.WaitGroup + wg.Add(numRegister + numGet) + + errs := make(chan error, numRegister+numGet) + + // Goroutines registering strategies + for i := 0; i < numRegister; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(strategyName).AnyTimes() + err := registry.RegisterStrategy(strategyName, strategy) + if err != nil { + errs <- err + } + }(i) + } + + // Goroutines reading strategies (will mostly fail, but shouldn't race) + for i := 0; i < numGet; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + // GetStrategy may return error if not registered yet, that's OK + _, _ = registry.GetStrategy(strategyName) + }(i) + } + + wg.Wait() + close(errs) + + // Check for unexpected errors (registration errors are not expected) + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) + } + }) +} diff --git a/pkg/vmcp/doc.go b/pkg/vmcp/doc.go index 246b03d2c..f81f8561b 100644 --- a/pkg/vmcp/doc.go +++ b/pkg/vmcp/doc.go @@ -83,12 +83,11 @@ // Middleware() func(http.Handler) http.Handler // } // -// OutgoingAuthenticator (pkg/vmcp/auth): +// OutgoingAuthRegistry (pkg/vmcp/auth): // -// type OutgoingAuthenticator interface { -// AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error -// GetStrategy(name string) (AuthStrategy, error) -// RegisterStrategy(name string, strategy AuthStrategy) error +// type OutgoingAuthRegistry interface { +// GetStrategy(name string) (Strategy, error) +// RegisterStrategy(name string, strategy Strategy) error // } // // # Design Principles @@ -137,9 +136,10 @@ // // Route to backend // target, err := rtr.RouteTool(ctx, toolName) // -// // Authenticate to backend +// // Authenticate to backend (resolve strategy and call it) // backendReq := createBackendRequest(...) -// err = outAuth.AuthenticateRequest(ctx, backendReq, target.AuthStrategy, target.AuthMetadata) +// strategy, err := outAuth.GetStrategy(target.AuthStrategy) +// err = strategy.Authenticate(ctx, backendReq, target.AuthMetadata) // // // Forward request and return response // // ... diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index 118e2082a..f46383fdf 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -24,7 +24,7 @@ type BackendTarget struct { TransportType string // AuthStrategy identifies the authentication strategy for this backend. - // The actual authentication is handled by OutgoingAuthenticator interface. + // The actual authentication is handled by OutgoingAuthRegistry interface. // Examples: "pass_through", "token_exchange", "client_credentials", "oauth_proxy" AuthStrategy string From 571bdfca18594deaefef7bddeff26a9d2137f2b1 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:15:16 +0000 Subject: [PATCH 03/11] Add factory package to resolve auth import cycle Introduces pkg/vmcp/auth/factory to break the circular dependency between pkg/vmcp/auth and pkg/vmcp/auth/strategies. The import cycle occurred because: - auth package needed to import strategies to instantiate them - strategies package imported auth for Identity and context helpers The factory package sits at the composition layer and can import both auth (for interfaces) and strategies (for implementations) without creating cycles. --- pkg/vmcp/auth/factory/outgoing.go | 166 ++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 pkg/vmcp/auth/factory/outgoing.go diff --git a/pkg/vmcp/auth/factory/outgoing.go b/pkg/vmcp/auth/factory/outgoing.go new file mode 100644 index 000000000..1c7cf7254 --- /dev/null +++ b/pkg/vmcp/auth/factory/outgoing.go @@ -0,0 +1,166 @@ +// Copyright 2025 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package factory provides factory functions for creating vMCP authentication components. +package factory + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// NewOutgoingAuthRegistry creates an OutgoingAuthRegistry from configuration. +// It registers all strategies found in the configuration (both default and backend-specific). +// +// The factory ALWAYS registers the "unauthenticated" strategy as a default fallback, +// ensuring that backends without explicit authentication configuration can function. +// This makes empty/nil configuration safe: the registry will have at least one +// usable strategy. +// +// Strategy Registration: +// - "unauthenticated" is always registered (default fallback) +// - Additional strategies are registered based on configuration +// - Each strategy is instantiated once and shared across backends +// - Strategies are stateless (except token_exchange which has internal caching) +// +// Parameters: +// - ctx: Context for any initialization that requires it +// - cfg: The outgoing authentication configuration (may be nil) +// +// Returns: +// - auth.OutgoingAuthRegistry: Configured registry with registered strategies +// - error: Any error during strategy initialization or registration +func NewOutgoingAuthRegistry(_ context.Context, cfg *config.OutgoingAuthConfig) (auth.OutgoingAuthRegistry, error) { + registry := auth.NewDefaultOutgoingAuthRegistry() + + // ALWAYS register the unauthenticated strategy as the default fallback. + if err := registerUnauthenticatedStrategy(registry); err != nil { + return nil, err + } + + // Handle nil config gracefully - return registry with unauthenticated strategy + if cfg == nil { + return registry, nil + } + + // Validate configuration structure + if err := validateConfig(cfg); err != nil { + return nil, err + } + + // Collect and register all unique strategy types from configuration + strategyTypes := collectStrategyTypes(cfg) + if err := registerStrategies(registry, strategyTypes); err != nil { + return nil, err + } + + return registry, nil +} + +// registerUnauthenticatedStrategy registers the default unauthenticated strategy. +func registerUnauthenticatedStrategy(registry auth.OutgoingAuthRegistry) error { + unauthStrategy := strategies.NewUnauthenticatedStrategy() + if err := registry.RegisterStrategy("unauthenticated", unauthStrategy); err != nil { + return fmt.Errorf("failed to register default unauthenticated strategy: %w", err) + } + return nil +} + +// validateConfig validates the configuration structure. +func validateConfig(cfg *config.OutgoingAuthConfig) error { + if cfg.Default != nil && strings.TrimSpace(cfg.Default.Type) == "" { + return fmt.Errorf("default auth strategy type cannot be empty") + } + + for backendID, backendCfg := range cfg.Backends { + if backendCfg != nil && strings.TrimSpace(backendCfg.Type) == "" { + return fmt.Errorf("backend %q has empty auth strategy type", backendID) + } + } + + return nil +} + +// collectStrategyTypes collects all unique strategy types from configuration. +func collectStrategyTypes(cfg *config.OutgoingAuthConfig) map[string]struct{} { + strategyTypes := make(map[string]struct{}) + + // Add default strategy type if present + if cfg.Default != nil && cfg.Default.Type != "" { + strategyTypes[cfg.Default.Type] = struct{}{} + } + + // Add all backend strategy types + for _, backendCfg := range cfg.Backends { + if backendCfg != nil && backendCfg.Type != "" { + strategyTypes[backendCfg.Type] = struct{}{} + } + } + + return strategyTypes +} + +// registerStrategies instantiates and registers each unique strategy type. +func registerStrategies(registry auth.OutgoingAuthRegistry, strategyTypes map[string]struct{}) error { + for strategyType := range strategyTypes { + // Skip "unauthenticated" - already registered + if strategyType == "unauthenticated" { + continue + } + + strategy, err := createStrategy(strategyType) + if err != nil { + return fmt.Errorf("failed to create strategy %q: %w", strategyType, err) + } + + if err := registry.RegisterStrategy(strategyType, strategy); err != nil { + return fmt.Errorf("failed to register strategy %q: %w", strategyType, err) + } + } + + return nil +} + +// createStrategy instantiates a strategy based on its type. +// +// Each strategy instance is stateless (except token_exchange which has internal caching). +// This function validates that the strategy type is not empty and returns an appropriate +// error for unknown strategy types. +// +// Parameters: +// - strategyType: The type identifier of the strategy to create +// +// Returns: +// - auth.Strategy: The instantiated strategy +// - error: Any error during strategy creation or validation +func createStrategy(strategyType string) (auth.Strategy, error) { + // Validate strategy type is not empty + if strings.TrimSpace(strategyType) == "" { + return nil, fmt.Errorf("strategy type cannot be empty") + } + + switch strategyType { + case "header_injection": + return strategies.NewHeaderInjectionStrategy(), nil + case "unauthenticated": + return strategies.NewUnauthenticatedStrategy(), nil + default: + return nil, fmt.Errorf("unknown strategy type: %s", strategyType) + } +} From f6b287aa69975ffe5e71411b4b75e8175d3b1461 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 13:31:46 +0000 Subject: [PATCH 04/11] Integrate authentication registry into HTTP backend client Refactors HTTPBackendClient to accept an OutgoingAuthRegistry and apply authentication strategies to all backend requests via a new authRoundTripper middleware. Authentication is now resolved and validated once at client creation time rather than per-request, improving performance and enabling early error detection for misconfigurations. The authRoundTripper clones requests to preserve immutability before applying authentication, ensuring thread-safety and preventing unintended side effects. --- pkg/vmcp/client/client.go | 135 +++++- pkg/vmcp/client/client_test.go | 426 +++++++++++++++++- .../client/mocks/mock_outgoing_registry.go | 70 +++ 3 files changed, 607 insertions(+), 24 deletions(-) create mode 100644 pkg/vmcp/client/mocks/mock_outgoing_registry.go diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index aaaf9cc59..aadc1dae4 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" ) const ( @@ -44,14 +45,30 @@ type httpBackendClient struct { // clientFactory creates MCP clients for backends. // Abstracted as a function to enable testing with mock clients. clientFactory func(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) + + // registry manages authentication strategies for outgoing requests to backend MCP servers. + // Must not be nil - use UnauthenticatedStrategy for no authentication. + registry auth.OutgoingAuthRegistry } // NewHTTPBackendClient creates a new HTTP-based backend client. // This client supports streamable-HTTP and SSE transports. -func NewHTTPBackendClient() vmcp.BackendClient { - return &httpBackendClient{ - clientFactory: defaultClientFactory, +// +// The registry parameter manages authentication strategies for outgoing requests to backend MCP servers. +// It must not be nil. To disable authentication, use a registry configured with the +// "unauthenticated" strategy. +// +// Returns an error if registry is nil. +func NewHTTPBackendClient(registry auth.OutgoingAuthRegistry) (vmcp.BackendClient, error) { + if registry == nil { + return nil, fmt.Errorf("registry cannot be nil; use UnauthenticatedStrategy for no authentication") + } + + c := &httpBackendClient{ + registry: registry, } + c.clientFactory = c.defaultClientFactory + return c, nil } // roundTripperFunc is a function adapter for http.RoundTripper. @@ -62,29 +79,103 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } +// authRoundTripper is an http.RoundTripper that adds authentication to backend requests. +// The authentication strategy and metadata are pre-resolved and validated at client creation time, +// eliminating per-request lookups and validation overhead. +type authRoundTripper struct { + base http.RoundTripper + authStrategy auth.Strategy + authMetadata map[string]any + target *vmcp.BackendTarget +} + +// RoundTrip implements http.RoundTripper by adding authentication headers to requests. +// The authentication strategy was pre-resolved and validated at client creation time, +// so this method simply applies the authentication without any lookups or validation. +func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid modifying the original + reqClone := req.Clone(req.Context()) + + // Apply pre-resolved authentication strategy + if err := a.authStrategy.Authenticate(reqClone.Context(), reqClone, a.authMetadata); err != nil { + return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) + } + + logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) + + return a.base.RoundTrip(reqClone) +} + +// resolveAuthStrategy resolves the authentication strategy for a backend target. +// It handles defaulting to "unauthenticated" when no strategy is specified. +// This method should be called once at client creation time to enable fail-fast +// behavior for invalid authentication configurations. +func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (auth.Strategy, error) { + strategyName := target.AuthStrategy + + // Default to unauthenticated if not specified + if strategyName == "" { + strategyName = "unauthenticated" + } + + // Resolve strategy from registry + strategy, err := h.registry.GetStrategy(strategyName) + if err != nil { + return nil, fmt.Errorf("authentication strategy %q not found: %w", strategyName, err) + } + + return strategy, nil +} + // defaultClientFactory creates mark3labs MCP clients for different transport types. -func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { - // Create HTTP client with response size limits for DoS protection +func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { + // Build transport chain: size limit → authentication → HTTP + var baseTransport http.RoundTripper = http.DefaultTransport + + // Resolve authentication strategy ONCE at client creation time + authStrategy, err := h.resolveAuthStrategy(target) + if err != nil { + return nil, fmt.Errorf("failed to resolve authentication for backend %s: %w", + target.WorkloadID, err) + } + + // Validate metadata ONCE at client creation time + if err := authStrategy.Validate(target.AuthMetadata); err != nil { + return nil, fmt.Errorf("invalid authentication configuration for backend %s: %w", + target.WorkloadID, err) + } + + // Add authentication layer with pre-resolved strategy + baseTransport = &authRoundTripper{ + base: baseTransport, + authStrategy: authStrategy, + authMetadata: target.AuthMetadata, + target: target, + } + + // Add size limit layer for DoS protection + sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := baseTransport.RoundTrip(req) + if err != nil { + return nil, err + } + // Wrap response body with size limit + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxResponseSize), + Closer: resp.Body, + } + return resp, nil + }) + + // Create HTTP client with configured transport chain httpClient := &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resp, err := http.DefaultTransport.RoundTrip(req) - if err != nil { - return nil, err - } - // Wrap response body with size limit - resp.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.LimitReader(resp.Body, maxResponseSize), - Closer: resp.Body, - } - return resp, nil - }), + Transport: sizeLimitedTransport, } var c *client.Client - var err error switch target.TransportType { case "streamable-http", "streamable": @@ -93,8 +184,6 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli transport.WithHTTPTimeout(0), transport.WithContinuousListening(), transport.WithHTTPBasicClient(httpClient), - // TODO: Add authentication header injection via WithHTTPHeaderFunc - // This will be implemented when we add OutgoingAuthenticator support ) if err != nil { return nil, fmt.Errorf("failed to create streamable-http client: %w", err) diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 2a7619cb0..4e1c38837 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -1,15 +1,23 @@ package client +//go:generate mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry + import ( "context" "errors" + "net/http" + "net/http/httptest" "testing" "github.com/mark3labs/mcp-go/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" + authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" ) func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) { @@ -76,7 +84,16 @@ func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) { TransportType: tc.transportType, } - _, err := defaultClientFactory(context.Background(), target) + // Create authenticator with unauthenticated strategy for testing + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + err := mockRegistry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + + backendClient, err := NewHTTPBackendClient(mockRegistry) + require.NoError(t, err) + httpClient := backendClient.(*httpBackendClient) + + _, err = httpClient.defaultClientFactory(context.Background(), target) require.Error(t, err) assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport) @@ -189,3 +206,410 @@ func TestInitializeClient_ErrorHandling(t *testing.T) { assert.NotNil(t, initializeClient) }) } + +// mockRoundTripper is a test implementation of http.RoundTripper that captures requests +type mockRoundTripper struct { + capturedReq *http.Request + response *http.Response + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.capturedReq = req + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestAuthRoundTripper_RoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupStrategy func(*gomock.Controller) auth.Strategy + baseTransportResp *http.Response + baseTransportErr error + expectError bool + errorContains string + checkRequest func(t *testing.T, originalReq, capturedReq *http.Request) + checkBaseTransport func(t *testing.T, baseTransport *mockRoundTripper) + }{ + { + name: "successful authentication adds headers and forwards request", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Simulate adding auth header + req.Header.Set("Authorization", "Bearer test-token") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should not be modified + assert.Empty(t, originalReq.Header.Get("Authorization")) + // Captured request should have auth header + assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "unauthenticated strategy skips authentication", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unauthenticated", + AuthMetadata: nil, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("unauthenticated"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + gomock.Nil(), + ). + DoAndReturn(func(_ context.Context, _ *http.Request, _ map[string]any) error { + // UnauthenticatedStrategy does nothing + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Neither request should have auth headers + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "authentication failure returns error without calling base transport", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(errors.New("auth failed")) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: true, + errorContains: "authentication failed for backend backend-1", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should NOT have been called + assert.Nil(t, baseTransport.capturedReq) + }, + }, + { + name: "base transport error propagates after successful auth", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(nil) + return mockStrategy + }, + baseTransportErr: errors.New("connection refused"), + expectError: true, + errorContains: "connection refused", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "request immutability - original request unchanged", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Modify the cloned request + req.Header.Set("Authorization", "Bearer modified-token") + req.Header.Set("X-Custom-Header", "custom-value") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should be completely unmodified + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, originalReq.Header.Get("X-Custom-Header")) + + // Captured (cloned) request should have modifications + assert.Equal(t, "Bearer modified-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "custom-value", capturedReq.Header.Get("X-Custom-Header")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + // Setup mock strategy + var mockStrategy auth.Strategy + if tt.setupStrategy != nil { + mockStrategy = tt.setupStrategy(ctrl) + } + + // Setup mock base transport + baseTransport := &mockRoundTripper{ + response: tt.baseTransportResp, + err: tt.baseTransportErr, + } + + // Create authRoundTripper with pre-resolved strategy + authRT := &authRoundTripper{ + base: baseTransport, + authStrategy: mockStrategy, + authMetadata: tt.target.AuthMetadata, + target: tt.target, + } + + // Create test request + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + ctx := context.Background() + req = req.WithContext(ctx) + + // Execute RoundTrip + resp, err := authRT.RoundTrip(req) + + // Check error expectations + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + } + + // Check request modifications if specified + if tt.checkRequest != nil { + tt.checkRequest(t, req, baseTransport.capturedReq) + } + + // Check base transport calls if specified + if tt.checkBaseTransport != nil { + tt.checkBaseTransport(t, baseTransport) + } + }) + } +} + +func TestNewHTTPBackendClient_NilRegistry(t *testing.T) { + t.Parallel() + + t.Run("returns error when registry is nil", func(t *testing.T) { + t.Parallel() + + client, err := NewHTTPBackendClient(nil) + + require.Error(t, err) + assert.Nil(t, client) + assert.Contains(t, err.Error(), "registry cannot be nil") + assert.Contains(t, err.Error(), "UnauthenticatedStrategy") + }) + + t.Run("succeeds with valid registry", func(t *testing.T) { + t.Parallel() + + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + client, err := NewHTTPBackendClient(mockRegistry) + + require.NoError(t, err) + assert.NotNil(t, client) + }) +} + +func TestResolveAuthStrategy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupRegistry func() auth.OutgoingAuthRegistry + expectError bool + errorContains string + checkStrategy func(t *testing.T, strategy auth.Strategy) + }{ + { + name: "defaults to unauthenticated when strategy is empty", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "unauthenticated", strategy.Name()) + }, + }, + { + name: "resolves explicitly configured strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "header_injection", + AuthMetadata: map[string]any{"header_name": "X-API-Key", "api_key": "test-key"}, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("header_injection", strategies.NewHeaderInjectionStrategy()) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "header_injection", strategy.Name()) + }, + }, + { + name: "returns error for unknown strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unknown_strategy", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: true, + errorContains: "authentication strategy \"unknown_strategy\" not found", + }, + { + name: "returns error when unauthenticated strategy not registered", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", // Empty strategy defaults to unauthenticated + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + // Don't register unauthenticated strategy + return auth.NewDefaultOutgoingAuthRegistry() + }, + expectError: true, + errorContains: "authentication strategy \"unauthenticated\" not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + registry := tt.setupRegistry() + backendClient, err := NewHTTPBackendClient(registry) + require.NoError(t, err) + + httpClient := backendClient.(*httpBackendClient) + + // Call resolveAuthStrategy + strategy, err := httpClient.resolveAuthStrategy(tt.target) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, strategy) + } else { + require.NoError(t, err) + assert.NotNil(t, strategy) + if tt.checkStrategy != nil { + tt.checkStrategy(t, strategy) + } + } + }) + } +} diff --git a/pkg/vmcp/client/mocks/mock_outgoing_registry.go b/pkg/vmcp/client/mocks/mock_outgoing_registry.go new file mode 100644 index 000000000..e18e65e05 --- /dev/null +++ b/pkg/vmcp/client/mocks/mock_outgoing_registry.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/vmcp/auth (interfaces: OutgoingAuthRegistry) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + auth "github.com/stacklok/toolhive/pkg/vmcp/auth" + gomock "go.uber.org/mock/gomock" +) + +// MockOutgoingAuthRegistry is a mock of OutgoingAuthRegistry interface. +type MockOutgoingAuthRegistry struct { + ctrl *gomock.Controller + recorder *MockOutgoingAuthRegistryMockRecorder + isgomock struct{} +} + +// MockOutgoingAuthRegistryMockRecorder is the mock recorder for MockOutgoingAuthRegistry. +type MockOutgoingAuthRegistryMockRecorder struct { + mock *MockOutgoingAuthRegistry +} + +// NewMockOutgoingAuthRegistry creates a new mock instance. +func NewMockOutgoingAuthRegistry(ctrl *gomock.Controller) *MockOutgoingAuthRegistry { + mock := &MockOutgoingAuthRegistry{ctrl: ctrl} + mock.recorder = &MockOutgoingAuthRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOutgoingAuthRegistry) EXPECT() *MockOutgoingAuthRegistryMockRecorder { + return m.recorder +} + +// GetStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) GetStrategy(name string) (auth.Strategy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStrategy", name) + ret0, _ := ret[0].(auth.Strategy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStrategy indicates an expected call of GetStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) GetStrategy(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).GetStrategy), name) +} + +// RegisterStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) RegisterStrategy(name string, strategy auth.Strategy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterStrategy", name, strategy) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterStrategy indicates an expected call of RegisterStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) RegisterStrategy(name, strategy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).RegisterStrategy), name, strategy) +} From dec109a9f0d70d3296b056ba23ee532313ff0635 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:16:26 +0000 Subject: [PATCH 05/11] Apply auth configuration in backend discoverer The CLI backend discoverer now accepts authentication configuration and applies it to discovered backends during the discovery process. This change enables per-backend authentication by: - Adding authConfig parameter to NewCLIBackendDiscoverer constructor - Implementing resolveAuthConfig() to select backend-specific or default authentication settings with proper precedence - Populating Backend.AuthStrategy and Backend.AuthMetadata fields during backend creation Authentication configuration follows this precedence: 1. Backend-specific configuration (cfg.Backends[backendID]) 2. Default configuration (cfg.Default) 3. No authentication (if neither is configured) The populated authentication fields are later consumed when converting Backend instances to BackendTarget for use by the HTTP client's authRoundTripper. --- pkg/vmcp/aggregator/cli_discoverer.go | 45 +++++++++++++++++++++- pkg/vmcp/aggregator/cli_discoverer_test.go | 18 ++++----- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go index b96350b53..c1dec6b41 100644 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -8,6 +8,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -16,14 +17,23 @@ import ( type cliBackendDiscoverer struct { workloadsManager workloads.Manager groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig } // NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. // It discovers workloads from Docker/Podman containers managed by ToolHive. -func NewCLIBackendDiscoverer(workloadsManager workloads.Manager, groupsManager groups.Manager) BackendDiscoverer { +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewCLIBackendDiscoverer( + workloadsManager workloads.Manager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { return &cliBackendDiscoverer{ workloadsManager: workloadsManager, groupsManager: groupsManager, + authConfig: authConfig, } } @@ -92,6 +102,16 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ Metadata: make(map[string]string), } + // Apply authentication configuration if provided + if d.authConfig != nil { + authStrategy, authMetadata := d.resolveAuthConfig(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + } + // Copy user labels to metadata first for k, v := range workload.Labels { backend.Metadata[k] = v @@ -116,6 +136,29 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ return backends, nil } +// resolveAuthConfig determines the authentication strategy and metadata for a backend. +// It checks for backend-specific configuration first, then falls back to default. +func (d *cliBackendDiscoverer) resolveAuthConfig(backendID string) (string, map[string]any) { + if d.authConfig == nil { + return "", nil + } + + // Check for backend-specific configuration + if strategy, exists := d.authConfig.Backends[backendID]; exists && strategy != nil { + logger.Debugf("Using backend-specific auth strategy for %s: %s", backendID, strategy.Type) + return strategy.Type, strategy.Metadata + } + + // Fall back to default configuration + if d.authConfig.Default != nil { + logger.Debugf("Using default auth strategy for %s: %s", backendID, d.authConfig.Default.Type) + return d.authConfig.Default.Type, d.authConfig.Default.Metadata + } + + // No authentication configured + return "", nil +} + // mapWorkloadStatusToHealth converts a workload status to a backend health status. func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { switch status { diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go index 19e1de944..9c3402fad 100644 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ b/pkg/vmcp/aggregator/cli_discoverer_test.go @@ -45,7 +45,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -108,7 +108,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -133,7 +133,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -150,7 +150,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -168,7 +168,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -187,7 +187,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -214,7 +214,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). Return(core.Workload{}, errors.New("workload query failed")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) From 24dc69d666f2b6f61c51f943cc151b8477bfb6c1 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:17:14 +0000 Subject: [PATCH 06/11] Complete outgoing authentication integration in serve command Finalizes the end-to-end authentication flow by connecting the authentication factory, backend discoverer, and HTTP client in the serve command. This enables vMCP proxy to authenticate requests to downstream MCP servers using configured authentication strategies. The serve command now: - Creates outgoing authenticator from configuration using the factory - Provides authentication config to backend discoverer for setup - Supplies authenticator to HTTP client for request signing - Uses factory for incoming authentication middleware (consistency) This completes the authentication architecture where configuration flows through the factory to create strategies that are applied by the client's round tripper to outgoing requests. Also simplifies redundant type annotation in client variable declaration for consistency with Go style conventions. --- cmd/vmcp/app/commands.go | 18 ++++++++++++++---- pkg/vmcp/client/client.go | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 96007a152..dc209c81f 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -12,7 +12,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/config" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" @@ -213,8 +213,15 @@ func runServe(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create groups manager: %w", err) } + // Create outgoing authentication registry from configuration + logger.Info("Initializing outgoing authentication") + outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, cfg.OutgoingAuth) + if err != nil { + return fmt.Errorf("failed to create outgoing authentication registry: %w", err) + } + // Create backend discoverer - discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager) + discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) // Discover backends from the configured group logger.Infof("Discovering backends in group: %s", cfg.GroupRef) @@ -230,7 +237,10 @@ func runServe(cmd *cobra.Command, _ []string) error { logger.Infof("Discovered %d backends", len(backends)) // Create backend client - backendClient := vmcpclient.NewHTTPBackendClient() + backendClient, err := vmcpclient.NewHTTPBackendClient(outgoingRegistry) + if err != nil { + return fmt.Errorf("failed to create backend client: %w", err) + } // Create conflict resolver based on configuration // Use the factory method that handles all strategies @@ -264,7 +274,7 @@ func runServe(cmd *cobra.Command, _ []string) error { // Setup authentication middleware logger.Infof("Setting up incoming authentication (type: %s)", cfg.IncomingAuth.Type) - authMiddleware, authInfoHandler, err := vmcpauth.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) + authMiddleware, authInfoHandler, err := factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index aadc1dae4..cd83cd061 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -130,7 +130,7 @@ func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (aut // defaultClientFactory creates mark3labs MCP clients for different transport types. func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { // Build transport chain: size limit → authentication → HTTP - var baseTransport http.RoundTripper = http.DefaultTransport + var baseTransport = http.DefaultTransport // Resolve authentication strategy ONCE at client creation time authStrategy, err := h.resolveAuthStrategy(target) From 1538dd05bbc401ab9984941b3b455647b9bdf7d7 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 23:21:20 +0000 Subject: [PATCH 07/11] Add explicit unauthenticated strategy for vMCP Replace the pattern of passing nil authenticators with an explicit UnauthenticatedStrategy that implements the Strategy interface as a no-op. This makes the intent clear in configuration and improves type safety by eliminating nil checks. The strategy is appropriate for backends on trusted networks or where authentication is handled at the network layer. Configuration now explicitly declares "strategy: unauthenticated" instead of relying on implicit nil behavior. --- pkg/vmcp/auth/strategies/unauthenticated.go | 72 +++++++ .../auth/strategies/unauthenticated_test.go | 196 ++++++++++++++++++ 2 files changed, 268 insertions(+) create mode 100644 pkg/vmcp/auth/strategies/unauthenticated.go create mode 100644 pkg/vmcp/auth/strategies/unauthenticated_test.go diff --git a/pkg/vmcp/auth/strategies/unauthenticated.go b/pkg/vmcp/auth/strategies/unauthenticated.go new file mode 100644 index 000000000..454495c52 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated.go @@ -0,0 +1,72 @@ +package strategies + +import ( + "context" + "net/http" +) + +// UnauthenticatedStrategy is a no-op authentication strategy that performs no authentication. +// This strategy is used when a backend MCP server requires no authentication. +// +// Unlike passing a nil authenticator (which is now an error), this strategy makes +// the intent explicit: "this backend intentionally has no authentication". +// +// The strategy performs no modifications to requests and validates all metadata. +// +// This is appropriate when: +// - The backend MCP server is on a trusted network (e.g., localhost) +// - The backend has no authentication requirements +// - Authentication is handled by network-level security (e.g., VPC, firewall) +// +// Security Warning: Only use this strategy when you are certain the backend +// requires no authentication. For production deployments, prefer explicit +// authentication strategies (pass_through, header_injection, token_exchange). +// +// Configuration: No metadata required, but any metadata is accepted and ignored. +// +// Example configuration: +// +// backends: +// local-backend: +// strategy: "unauthenticated" +type UnauthenticatedStrategy struct{} + +// NewUnauthenticatedStrategy creates a new UnauthenticatedStrategy instance. +func NewUnauthenticatedStrategy() *UnauthenticatedStrategy { + return &UnauthenticatedStrategy{} +} + +// Name returns the strategy identifier. +func (*UnauthenticatedStrategy) Name() string { + return "unauthenticated" +} + +// Authenticate performs no authentication and returns immediately. +// +// This method: +// 1. Does not modify the request in any way +// 2. Always returns nil (success) +// +// Parameters: +// - ctx: Request context (unused) +// - req: The HTTP request (not modified) +// - metadata: Strategy-specific configuration (ignored) +// +// Returns nil (always succeeds). +func (*UnauthenticatedStrategy) Authenticate(_ context.Context, _ *http.Request, _ map[string]any) error { + // No-op: intentionally does nothing + return nil +} + +// Validate checks if the strategy configuration is valid. +// +// UnauthenticatedStrategy accepts any metadata (including nil or empty), +// so this always returns nil. +// +// This permissive validation allows the strategy to be used without +// configuration or with arbitrary configuration that may be present +// for documentation purposes. +func (*UnauthenticatedStrategy) Validate(_ map[string]any) error { + // No-op: accepts any metadata + return nil +} diff --git a/pkg/vmcp/auth/strategies/unauthenticated_test.go b/pkg/vmcp/auth/strategies/unauthenticated_test.go new file mode 100644 index 000000000..43ee62bb2 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated_test.go @@ -0,0 +1,196 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnauthenticatedStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + assert.Equal(t, "unauthenticated", strategy.Name()) +} + +func TestUnauthenticatedStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + setupRequest func() *http.Request + checkRequest func(t *testing.T, req *http.Request) + }{ + { + name: "does not modify request with no metadata", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Custom-Header", "original-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "original-value", req.Header.Get("X-Custom-Header")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "does not modify request with metadata present", + metadata: map[string]any{ + "some_key": "some_value", + "count": 42, + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Existing", "existing-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "existing-value", req.Header.Get("X-Existing")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "preserves existing Authorization header", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("Authorization", "Bearer existing-token") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Should not modify existing Authorization header + assert.Equal(t, "Bearer existing-token", req.Header.Get("Authorization")) + }, + }, + { + name: "works with empty request", + metadata: nil, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Request should have no auth headers + assert.Empty(t, req.Header.Get("Authorization")) + // Headers should be empty or minimal + assert.LessOrEqual(t, len(req.Header), 1) // May have Host header + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + req := tt.setupRequest() + ctx := context.Background() + + err := strategy.Authenticate(ctx, req, tt.metadata) + + require.NoError(t, err) + tt.checkRequest(t, req) + }) + } +} + +func TestUnauthenticatedStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + }{ + { + name: "accepts nil metadata", + metadata: nil, + }, + { + name: "accepts empty metadata", + metadata: map[string]any{}, + }, + { + name: "accepts arbitrary metadata", + metadata: map[string]any{ + "key1": "value1", + "key2": 42, + "key3": []string{"a", "b", "c"}, + "nested": map[string]any{"inner": "value"}, + }, + }, + { + name: "accepts metadata with typical auth fields", + metadata: map[string]any{ + "token_url": "https://example.com/token", + "client_id": "client-123", + "header_name": "X-API-Key", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + err := strategy.Validate(tt.metadata) + + require.NoError(t, err) + }) + } +} + +func TestUnauthenticatedStrategy_IntegrationBehavior(t *testing.T) { + t.Parallel() + + t.Run("strategy can be called multiple times safely", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Call multiple times with different requests + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + require.NoError(t, err) + assert.Empty(t, req.Header.Get("Authorization")) + } + }) + + t.Run("strategy is safe for concurrent use", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Run authentication concurrently + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + assert.NoError(t, err) + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + }) +} From 3131412adacd41b934b38e929b1424ab12e479f0 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 22:54:36 +0000 Subject: [PATCH 08/11] Implement HeaderInjection authentication strategy Add HeaderInjectionStrategy for injecting static header values into backend requests. This general-purpose strategy supports any HTTP header with any static value, enabling flexible authentication schemes like API keys, bearer tokens, and custom auth headers. The strategy extracts header_name and api_key from metadata configuration and validates them to prevent CRLF injection attacks using pkg/validation functions. Validation occurs at configuration time for fail-fast behavior. Changes: - Add HeaderInjectionStrategy implementation with Authenticate/Validate - Include comprehensive test coverage (408 test lines) - Use ValidateHTTPHeaderName/Value for security checks - Prepared for future secret reference resolution --- pkg/vmcp/auth/strategies/header_injection.go | 113 +++++ .../auth/strategies/header_injection_test.go | 408 ++++++++++++++++++ 2 files changed, 521 insertions(+) create mode 100644 pkg/vmcp/auth/strategies/header_injection.go create mode 100644 pkg/vmcp/auth/strategies/header_injection_test.go diff --git a/pkg/vmcp/auth/strategies/header_injection.go b/pkg/vmcp/auth/strategies/header_injection.go new file mode 100644 index 000000000..07fccc084 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection.go @@ -0,0 +1,113 @@ +// Package strategies provides authentication strategy implementations for Virtual MCP Server. +package strategies + +import ( + "context" + "fmt" + "net/http" + + "github.com/stacklok/toolhive/pkg/validation" +) + +// HeaderInjectionStrategy injects a static header value into request headers. +// This is a general-purpose strategy that can inject any header with any value, +// commonly used for API keys, bearer tokens, or custom authentication headers. +// +// The strategy extracts the header name and value from the metadata +// configuration and injects them into the backend request headers. +// +// Required metadata fields: +// - header_name: The HTTP header name to use (e.g., "X-API-Key", "Authorization") +// - api_key: The header value to inject (can be an API key, token, or any value) +// +// This strategy is appropriate when: +// - The backend requires a static header value for authentication +// - The header value is stored securely in the vMCP configuration +// - No dynamic token exchange or user-specific authentication is required +// +// Future enhancements may include: +// - Secret reference resolution (e.g., ${SECRET_REF:...}) +// - Support for multiple header formats (e.g., "Bearer ") +// - Value rotation and refresh mechanisms +type HeaderInjectionStrategy struct{} + +// NewHeaderInjectionStrategy creates a new HeaderInjectionStrategy instance. +func NewHeaderInjectionStrategy() *HeaderInjectionStrategy { + return &HeaderInjectionStrategy{} +} + +// Name returns the strategy identifier. +func (*HeaderInjectionStrategy) Name() string { + return "header_injection" +} + +// Authenticate injects the header value from metadata into the request header. +// +// This method: +// 1. Validates that header_name and api_key are present in metadata +// 2. Sets the specified header with the provided value +// +// Parameters: +// - ctx: Request context (currently unused, reserved for future secret resolution) +// - req: The HTTP request to authenticate +// - metadata: Strategy-specific configuration containing header_name and api_key +// +// Returns an error if: +// - header_name is missing or empty +// - api_key is missing or empty +func (*HeaderInjectionStrategy) Authenticate(_ context.Context, req *http.Request, metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // TODO: Future enhancement - resolve secret references + // if strings.HasPrefix(apiKey, "${SECRET_REF:") { + // apiKey, err = s.secretResolver.Resolve(ctx, apiKey) + // if err != nil { + // return fmt.Errorf("failed to resolve secret reference: %w", err) + // } + // } + + req.Header.Set(headerName, apiKey) + return nil +} + +// Validate checks if the required metadata fields are present and valid. +// +// This method verifies that: +// - header_name is present and non-empty +// - api_key is present and non-empty +// - header_name is a valid HTTP header name (prevents CRLF injection) +// - api_key is a valid HTTP header value (prevents CRLF injection) +// +// This validation is typically called during configuration parsing to fail fast +// if the strategy is misconfigured. +func (*HeaderInjectionStrategy) Validate(metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // Validate header name to prevent injection attacks + if err := validation.ValidateHTTPHeaderName(headerName); err != nil { + return fmt.Errorf("invalid header_name: %w", err) + } + + // Validate API key value to prevent injection attacks + if err := validation.ValidateHTTPHeaderValue(apiKey); err != nil { + return fmt.Errorf("invalid api_key: %w", err) + } + + return nil +} diff --git a/pkg/vmcp/auth/strategies/header_injection_test.go b/pkg/vmcp/auth/strategies/header_injection_test.go new file mode 100644 index 000000000..537fd3d86 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection_test.go @@ -0,0 +1,408 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeaderInjectionStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + assert.Equal(t, "header_injection", strategy.Name()) +} + +func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + checkHeader func(t *testing.T, req *http.Request) + }{ + { + name: "sets X-API-Key header correctly", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key-123", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "secret-key-123", req.Header.Get("X-API-Key")) + }, + }, + { + name: "sets Authorization header with API key", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "ApiKey my-secret-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "ApiKey my-secret-key", req.Header.Get("Authorization")) + }, + }, + { + name: "sets custom header name", + metadata: map[string]any{ + "header_name": "X-Custom-Auth-Token", + "api_key": "custom-token-value", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "custom-token-value", req.Header.Get("X-Custom-Auth-Token")) + }, + }, + { + name: "handles complex API key values", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles API key with special characters", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", req.Header.Get("X-API-Key")) + }, + }, + { + name: "ignores additional metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "my-key", + "extra_field": "ignored", + "another": 123, + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "my-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty string", + metadata: map[string]any{ + "header_name": "", + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are missing", + metadata: map[string]any{ + "unrelated": "field", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "overwrites existing header value", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "new-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // Verify the new key was set (old-key was already set before Authenticate) + assert.Equal(t, "new-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles very long API keys", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": string(make([]byte, 10000)) + "very-long-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + expected := string(make([]byte, 10000)) + "very-long-key" + assert.Equal(t, expected, req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles case-sensitive header names", + metadata: map[string]any{ + "header_name": "x-api-key", // lowercase + "api_key": "my-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // HTTP headers are case-insensitive, but Go normalizes them + assert.Equal(t, "my-key", req.Header.Get("x-api-key")) + assert.Equal(t, "my-key", req.Header.Get("X-Api-Key")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Special setup for the "overwrites existing header value" test + if tt.name == "overwrites existing header value" { + req.Header.Set("X-API-Key", "old-key") + } + + err := strategy.Authenticate(ctx, req, tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + return + } + + require.NoError(t, err) + if tt.checkHeader != nil { + tt.checkHeader(t, req) + } + }) + } +} + +func TestHeaderInjectionStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + }{ + { + name: "valid metadata with all required fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + }, + expectError: false, + }, + { + name: "valid with extra metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + "extra": "ignored", + "count": 123, + }, + expectError: false, + }, + { + name: "valid with different header name", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "Bearer token", + }, + expectError: false, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty", + metadata: map[string]any{ + "header_name": "", + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is a boolean", + metadata: map[string]any{ + "header_name": true, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is a map", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": map[string]any{"nested": "value"}, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are wrong type", + metadata: map[string]any{ + "header_name": 123, + "api_key": false, + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error for whitespace in header_name", + metadata: map[string]any{ + "header_name": "X-Custom Header", + "api_key": "key", + }, + expectError: true, + errorContains: "invalid header_name", + }, + { + name: "accepts unicode in api_key", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-unicode-日本語", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + err := strategy.Validate(tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} From c117b5d5a8f7a99326d64bcd2bb6a0c5df4da02f Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 14:22:22 +0000 Subject: [PATCH 09/11] Fix package name in factory directory Change package declaration from 'auth' to 'factory' in incoming.go and incoming_test.go to match outgoing.go and prevent typecheck error. All files in pkg/vmcp/auth/factory/ must use package factory. --- pkg/vmcp/auth/factory/incoming.go | 2 +- pkg/vmcp/auth/factory/incoming_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/vmcp/auth/factory/incoming.go b/pkg/vmcp/auth/factory/incoming.go index 479876d2d..edb09a6cd 100644 --- a/pkg/vmcp/auth/factory/incoming.go +++ b/pkg/vmcp/auth/factory/incoming.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" diff --git a/pkg/vmcp/auth/factory/incoming_test.go b/pkg/vmcp/auth/factory/incoming_test.go index e3f7a22cb..10bc65344 100644 --- a/pkg/vmcp/auth/factory/incoming_test.go +++ b/pkg/vmcp/auth/factory/incoming_test.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" From d9623f257d8a39bf238399df53569335bbde9c32 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:43:28 +0000 Subject: [PATCH 10/11] Initial plan From 3e9c50092d87df33622c06d277a430d3df69371d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:47:46 +0000 Subject: [PATCH 11/11] Add test for strategy name mismatch validation Co-authored-by: jhrozek <715522+jhrozek@users.noreply.github.com> --- pkg/vmcp/auth/outgoing_registry_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pkg/vmcp/auth/outgoing_registry_test.go b/pkg/vmcp/auth/outgoing_registry_test.go index 3d2e8a495..0c87b48d8 100644 --- a/pkg/vmcp/auth/outgoing_registry_test.go +++ b/pkg/vmcp/auth/outgoing_registry_test.go @@ -56,6 +56,23 @@ func TestDefaultOutgoingAuthRegistry_RegisterStrategy(t *testing.T) { assert.Contains(t, err.Error(), "strategy cannot be nil") }) + t.Run("register strategy name mismatch fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("actual_name").AnyTimes() + + err := registry.RegisterStrategy("different_name", strategy) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy name mismatch") + assert.Contains(t, err.Error(), "different_name") + assert.Contains(t, err.Error(), "actual_name") + }) + t.Run("register duplicate name fails", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t)