From 2069dbc45d7371c0754a42826b19a546d578112b Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Sun, 22 Mar 2026 18:31:00 +0530 Subject: [PATCH 1/9] ref: Implementation for dynamic webhook phase 2 Signed-off-by: Sanskarzz --- docs/server/docs.go | 130 +++++++++ docs/server/swagger.json | 130 +++++++++ docs/server/swagger.yaml | 114 ++++++++ pkg/runner/config.go | 4 + pkg/runner/middleware.go | 30 +++ pkg/webhook/validating/config.go | 33 +++ pkg/webhook/validating/middleware.go | 193 ++++++++++++++ pkg/webhook/validating/middleware_test.go | 311 ++++++++++++++++++++++ 8 files changed, 945 insertions(+) create mode 100644 pkg/webhook/validating/config.go create mode 100644 pkg/webhook/validating/middleware.go create mode 100644 pkg/webhook/validating/middleware_test.go diff --git a/docs/server/docs.go b/docs/server/docs.go index a5c12c4f4c..19d407cb4e 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1103,6 +1103,14 @@ const docTemplate = `{ "upstream_swap_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth_upstreamswap.Config" }, + "validating_webhooks": { + "description": "ValidatingWebhooks contains the configuration for validating webhook middleware.", + "items": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config" + }, + "type": "array", + "uniqueItems": false + }, "volumes": { "description": "Volumes are the directory mounts to pass to the container\nFormat: \"host-path:container-path[:ro]\"", "items": { @@ -1382,6 +1390,66 @@ const docTemplate = `{ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_webhook.Config": { + "properties": { + "failure_policy": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.FailurePolicy" + }, + "hmac_secret_ref": { + "description": "HMACSecretRef is an optional reference to an HMAC secret for payload signing.", + "type": "string" + }, + "name": { + "description": "Name is a unique identifier for this webhook.", + "type": "string" + }, + "timeout": { + "$ref": "#/components/schemas/time.Duration" + }, + "tls_config": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig" + }, + "url": { + "description": "URL is the HTTPS endpoint to call.", + "type": "string" + } + }, + "type": "object" + }, + "github_com_stacklok_toolhive_pkg_webhook.FailurePolicy": { + "description": "FailurePolicy determines behavior when the webhook call fails.", + "enum": [ + "fail", + "ignore" + ], + "type": "string", + "x-enum-varnames": [ + "FailurePolicyFail", + "FailurePolicyIgnore" + ] + }, + "github_com_stacklok_toolhive_pkg_webhook.TLSConfig": { + "description": "TLSConfig holds optional TLS configuration (CA bundles, client certs).", + "properties": { + "ca_bundle_path": { + "description": "CABundlePath is the path to a CA certificate bundle for server verification.", + "type": "string" + }, + "client_cert_path": { + "description": "ClientCertPath is the path to a client certificate for mTLS.", + "type": "string" + }, + "client_key_path": { + "description": "ClientKeyPath is the path to a client key for mTLS.", + "type": "string" + }, + "insecure_skip_verify": { + "description": "InsecureSkipVerify disables server certificate verification.\nWARNING: This should only be used for development/testing.", + "type": "boolean" + } + }, + "type": "object" + }, "ignore.Config": { "description": "IgnoreConfig contains configuration for ignore processing", "properties": { @@ -2984,6 +3052,68 @@ const docTemplate = `{ }, "type": "object" }, + "time.Duration": { + "description": "Timeout is the maximum time to wait for a webhook response.", + "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000 + ], + "type": "integer", + "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour" + ] + }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index f20887b4c5..b4f2a866c1 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1096,6 +1096,14 @@ "upstream_swap_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth_upstreamswap.Config" }, + "validating_webhooks": { + "description": "ValidatingWebhooks contains the configuration for validating webhook middleware.", + "items": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config" + }, + "type": "array", + "uniqueItems": false + }, "volumes": { "description": "Volumes are the directory mounts to pass to the container\nFormat: \"host-path:container-path[:ro]\"", "items": { @@ -1375,6 +1383,66 @@ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_webhook.Config": { + "properties": { + "failure_policy": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.FailurePolicy" + }, + "hmac_secret_ref": { + "description": "HMACSecretRef is an optional reference to an HMAC secret for payload signing.", + "type": "string" + }, + "name": { + "description": "Name is a unique identifier for this webhook.", + "type": "string" + }, + "timeout": { + "$ref": "#/components/schemas/time.Duration" + }, + "tls_config": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig" + }, + "url": { + "description": "URL is the HTTPS endpoint to call.", + "type": "string" + } + }, + "type": "object" + }, + "github_com_stacklok_toolhive_pkg_webhook.FailurePolicy": { + "description": "FailurePolicy determines behavior when the webhook call fails.", + "enum": [ + "fail", + "ignore" + ], + "type": "string", + "x-enum-varnames": [ + "FailurePolicyFail", + "FailurePolicyIgnore" + ] + }, + "github_com_stacklok_toolhive_pkg_webhook.TLSConfig": { + "description": "TLSConfig holds optional TLS configuration (CA bundles, client certs).", + "properties": { + "ca_bundle_path": { + "description": "CABundlePath is the path to a CA certificate bundle for server verification.", + "type": "string" + }, + "client_cert_path": { + "description": "ClientCertPath is the path to a client certificate for mTLS.", + "type": "string" + }, + "client_key_path": { + "description": "ClientKeyPath is the path to a client key for mTLS.", + "type": "string" + }, + "insecure_skip_verify": { + "description": "InsecureSkipVerify disables server certificate verification.\nWARNING: This should only be used for development/testing.", + "type": "boolean" + } + }, + "type": "object" + }, "ignore.Config": { "description": "IgnoreConfig contains configuration for ignore processing", "properties": { @@ -2977,6 +3045,68 @@ }, "type": "object" }, + "time.Duration": { + "description": "Timeout is the maximum time to wait for a webhook response.", + "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000 + ], + "type": "integer", + "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour" + ] + }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6e8bbfef49..b015716c5f 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1044,6 +1044,13 @@ components: type: boolean upstream_swap_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_auth_upstreamswap.Config' + validating_webhooks: + description: ValidatingWebhooks contains the configuration for validating + webhook middleware. + items: + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config' + type: array + uniqueItems: false volumes: description: |- Volumes are the directory mounts to pass to the container @@ -1305,6 +1312,54 @@ components: +optional type: boolean type: object + github_com_stacklok_toolhive_pkg_webhook.Config: + properties: + failure_policy: + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.FailurePolicy' + hmac_secret_ref: + description: HMACSecretRef is an optional reference to an HMAC secret for + payload signing. + type: string + name: + description: Name is a unique identifier for this webhook. + type: string + timeout: + $ref: '#/components/schemas/time.Duration' + tls_config: + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig' + url: + description: URL is the HTTPS endpoint to call. + type: string + type: object + github_com_stacklok_toolhive_pkg_webhook.FailurePolicy: + description: FailurePolicy determines behavior when the webhook call fails. + enum: + - fail + - ignore + type: string + x-enum-varnames: + - FailurePolicyFail + - FailurePolicyIgnore + github_com_stacklok_toolhive_pkg_webhook.TLSConfig: + description: TLSConfig holds optional TLS configuration (CA bundles, client + certs). + properties: + ca_bundle_path: + description: CABundlePath is the path to a CA certificate bundle for server + verification. + type: string + client_cert_path: + description: ClientCertPath is the path to a client certificate for mTLS. + type: string + client_key_path: + description: ClientKeyPath is the path to a client key for mTLS. + type: string + insecure_skip_verify: + description: |- + InsecureSkipVerify disables server certificate verification. + WARNING: This should only be used for development/testing. + type: boolean + type: object ignore.Config: description: IgnoreConfig contains configuration for ignore processing properties: @@ -2597,6 +2652,65 @@ components: type: array uniqueItems: false type: object + time.Duration: + description: Timeout is the maximum time to wait for a webhook response. + enum: + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + type: integer + x-enum-varnames: + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour types.MiddlewareConfig: properties: parameters: diff --git a/pkg/runner/config.go b/pkg/runner/config.go index f0f14a30e0..ce029012f9 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -32,6 +32,7 @@ import ( "github.com/stacklok/toolhive/pkg/state" "github.com/stacklok/toolhive/pkg/telemetry" "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/webhook" workloadtypes "github.com/stacklok/toolhive/pkg/workloads/types" ) @@ -191,6 +192,9 @@ type RunConfig struct { // and the configuration for each middleware. MiddlewareConfigs []types.MiddlewareConfig `json:"middleware_configs,omitempty" yaml:"middleware_configs,omitempty"` + // ValidatingWebhooks contains the configuration for validating webhook middleware. + ValidatingWebhooks []webhook.Config `json:"validating_webhooks,omitempty" yaml:"validating_webhooks,omitempty"` + // existingPort is the port from an existing workload being updated (not serialized) // Used during port validation to allow reusing the same port existingPort int diff --git a/pkg/runner/middleware.go b/pkg/runner/middleware.go index 299d6c61a7..03b9cc1207 100644 --- a/pkg/runner/middleware.go +++ b/pkg/runner/middleware.go @@ -19,6 +19,7 @@ import ( headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/usagemetrics" + "github.com/stacklok/toolhive/pkg/webhook/validating" ) // GetSupportedMiddlewareFactories returns a map of supported middleware types to their factory functions @@ -37,6 +38,7 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory { audit.MiddlewareType: audit.CreateMiddleware, recovery.MiddlewareType: recovery.CreateMiddleware, headerfwd.HeaderForwardMiddlewareName: headerfwd.CreateMiddleware, + validating.MiddlewareType: validating.CreateMiddleware, } } @@ -113,6 +115,12 @@ func PopulateMiddlewareConfigs(config *RunConfig) error { } middlewareConfigs = append(middlewareConfigs, *mcpParserConfig) + // Validating Webhooks middleware (if configured) + middlewareConfigs, err = addValidatingWebhookMiddleware(middlewareConfigs, config) + if err != nil { + return err + } + // Load application config for global settings configProvider := cfg.NewDefaultProvider() appConfig := configProvider.GetConfig() @@ -197,6 +205,28 @@ func PopulateMiddlewareConfigs(config *RunConfig) error { return nil } +// addValidatingWebhookMiddleware configures the validating webhook middleware if any webhooks are defined +func addValidatingWebhookMiddleware(configs []types.MiddlewareConfig, runConfig *RunConfig) ([]types.MiddlewareConfig, error) { + if len(runConfig.ValidatingWebhooks) == 0 { + return configs, nil + } + + params := validating.FactoryMiddlewareParams{ + MiddlewareParams: validating.MiddlewareParams{ + Webhooks: runConfig.ValidatingWebhooks, + }, + ServerName: runConfig.Name, + Transport: runConfig.Transport.String(), + } + + config, err := types.NewMiddlewareConfig(validating.MiddlewareType, params) + if err != nil { + return nil, fmt.Errorf("failed to create validating webhook middleware config: %w", err) + } + + return append(configs, *config), nil +} + // addTokenExchangeMiddleware adds token exchange middleware if configured func addTokenExchangeMiddleware( middlewares []types.MiddlewareConfig, diff --git a/pkg/webhook/validating/config.go b/pkg/webhook/validating/config.go new file mode 100644 index 0000000000..b5ee504b49 --- /dev/null +++ b/pkg/webhook/validating/config.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package validating implements a validating webhook middleware for ToolHive. +// It calls external HTTP services to approve or deny MCP requests. +package validating + +import ( + "fmt" + + "github.com/stacklok/toolhive/pkg/webhook" +) + +// MiddlewareParams holds the configuration parameters for the validating webhook middleware. +type MiddlewareParams struct { + // Webhooks is the list of validating webhook configurations to call. + // Webhooks are called in configuration order; if any webhook denies the request, + // the request is rejected. All webhooks must allow the request for it to proceed. + Webhooks []webhook.Config `json:"webhooks"` +} + +// Validate checks that the MiddlewareParams are valid. +func (p *MiddlewareParams) Validate() error { + if len(p.Webhooks) == 0 { + return fmt.Errorf("validating webhook middleware requires at least one webhook") + } + for i, wh := range p.Webhooks { + if err := wh.Validate(); err != nil { + return fmt.Errorf("webhook[%d] (%q): %w", i, wh.Name, err) + } + } + return nil +} diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go new file mode 100644 index 0000000000..b5ab39cf6a --- /dev/null +++ b/pkg/webhook/validating/middleware.go @@ -0,0 +1,193 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package validating + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/google/uuid" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/webhook" +) + +// MiddlewareType is the type constant for the validating webhook middleware. +const MiddlewareType = "validating-webhook" + +// FactoryMiddlewareParams extends MiddlewareParams with context for the factory. +type FactoryMiddlewareParams struct { + MiddlewareParams + // ServerName is the name of the ToolHive instance. + ServerName string `json:"server_name"` + // Transport is the transport type (e.g., sse, stdio). + Transport string `json:"transport"` +} + +// Middleware wraps validating webhook functionality for the factory pattern. +type Middleware struct { + handler types.MiddlewareFunction +} + +// Handler returns the middleware function used by the proxy. +func (m *Middleware) Handler() types.MiddlewareFunction { + return m.handler +} + +// Close cleans up any resources used by the middleware. +func (*Middleware) Close() error { + return nil +} + +type clientExecutor struct { + client *webhook.Client + config webhook.Config +} + +// CreateMiddleware is the factory function for validating webhook middleware. +func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + var params FactoryMiddlewareParams + if err := json.Unmarshal(config.Parameters, ¶ms); err != nil { + return fmt.Errorf("failed to unmarshal validating webhook middleware parameters: %w", err) + } + + if err := params.Validate(); err != nil { + return fmt.Errorf("invalid validating webhook configuration: %w", err) + } + + // Create clients for each webhook + var executors []clientExecutor + for i, whCfg := range params.Webhooks { + client, err := webhook.NewClient(whCfg, webhook.TypeValidating, nil) // HMAC secret not yet plumbed + if err != nil { + return fmt.Errorf("failed to create client for webhook[%d] (%q): %w", i, whCfg.Name, err) + } + executors = append(executors, clientExecutor{client: client, config: whCfg}) + } + + mw := &Middleware{ + handler: createValidatingHandler(executors, params.ServerName, params.Transport), + } + runner.AddMiddleware(MiddlewareType, mw) + return nil +} + +func createValidatingHandler(executors []clientExecutor, serverName, transport string) types.MiddlewareFunction { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip if it's not a parsed MCP request (middleware runs after mcp parser) + parsedMCP := mcp.GetParsedMCPRequest(r.Context()) + if parsedMCP == nil { + next.ServeHTTP(w, r) + return + } + + // Read the request body to get the raw MCP request + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + sendErrorResponse(w, http.StatusInternalServerError, "Internal Server Error", "Failed to read request body") + return + } + // Restore the request body for downstream handlers + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // Build the webhook request payload + reqUID := uuid.New().String() + whReq := &webhook.Request{ + Version: webhook.APIVersion, + UID: reqUID, + Timestamp: time.Now().UTC(), + MCPRequest: json.RawMessage(bodyBytes), + Context: &webhook.RequestContext{ + ServerName: serverName, + SourceIP: readSourceIP(r), + Transport: transport, + }, + } + + // Add Principal if authenticated + if identity, ok := auth.IdentityFromContext(r.Context()); ok { + whReq.Principal = &webhook.Principal{ + Sub: identity.Subject, + Email: identity.Email, + Name: identity.Name, + Groups: identity.Groups, + Claims: identity.Claims, + } + } + + // Call each webhook in order + for _, exec := range executors { + whName := exec.config.Name + + resp, err := exec.client.Call(r.Context(), whReq) + if err != nil { + // Handle error based on failure policy + if exec.config.FailurePolicy == webhook.FailurePolicyIgnore { + slog.Warn("Validating webhook error ignored due to fail-open policy", + "webhook", whName, "error", err) + continue + } + + slog.Error("Validating webhook error caused request denial", + "webhook", whName, "error", err) + sendErrorResponse(w, http.StatusForbidden, "Forbidden", fmt.Sprintf("Webhook %q error: %v", whName, err)) + return + } + + if !resp.Allowed { + slog.Info("Validating webhook denied request", "webhook", whName, "reason", resp.Reason, "message", resp.Message) + + msg := resp.Message + if msg == "" { + msg = fmt.Sprintf("Webhook %q denied the request", whName) + } + + code := resp.Code + if code < 400 || code > 599 { + code = http.StatusForbidden + } + + sendErrorResponse(w, code, "Forbidden", msg) + return + } + } + + // All webhooks allowed or ignored errors + next.ServeHTTP(w, r) + }) + } +} + +func readSourceIP(r *http.Request) string { + // Let runner handle X-Forwarded-For if TrustProxyHeaders is set. + // For now, simple RemoteAddr. + return r.RemoteAddr +} + +func sendErrorResponse(w http.ResponseWriter, statusCode int, _, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + // Since we are intercepting an MCP request, we should really be returning a JSON-RPC error. + // However, if the error happens before actual execution, a standard HTTP error or a basic JSON + // with error details is typical. Here we'll follow standard HTTP error structure or JSON-RPC format. + // We'll return a JSON format that could be interpreted as a JSON-RPC error. + errResp := map[string]any{ + "jsonrpc": "2.0", + "id": nil, + "error": map[string]any{ + "code": statusCode, + "message": message, + }, + } + _ = json.NewEncoder(w).Encode(errResp) +} diff --git a/pkg/webhook/validating/middleware_test.go b/pkg/webhook/validating/middleware_test.go new file mode 100644 index 0000000000..d2c3db0697 --- /dev/null +++ b/pkg/webhook/validating/middleware_test.go @@ -0,0 +1,311 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package validating + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/webhook" +) + +//nolint:paralleltest // Shares a mock HTTP server and lastRequest state +func TestValidatingMiddleware(t *testing.T) { + // Setup a mock webhook server + var lastRequest webhook.Request + mockResponse := webhook.Response{ + Version: webhook.APIVersion, + UID: "resp-uid", + Allowed: true, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + err := json.NewDecoder(r.Body).Decode(&lastRequest) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(mockResponse) + require.NoError(t, err) + })) + defer server.Close() + + // Create middleware handler + config := []webhook.Config{ + { + Name: "test-webhook", + URL: server.URL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyFail, + TLSConfig: &webhook.TLSConfig{ + InsecureSkipVerify: true, // Need this for httptest server + }, + }, + } + + var executors []clientExecutor + for _, cfg := range config { + client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + executors = append(executors, clientExecutor{client: client, config: cfg}) + } + + mw := createValidatingHandler(executors, "test-server", "stdio") + + t.Run("Allowed Request", func(t *testing.T) { + mockResponse.Allowed = true // Server will return allowed + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + + // Add parsed MCP request and auth identity to context + parsedMCP := &mcp.ParsedMCPRequest{ + Method: "tools/call", + ID: 1, + } + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, parsedMCP) + + identity := &auth.Identity{ + Subject: "user-1", + Email: "user@example.com", + Groups: []string{"admin"}, + } + ctx = auth.WithIdentity(ctx, identity) + + req = req.WithContext(ctx) + req.RemoteAddr = "192.168.1.1:1234" + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled, "Next handler should be called for allowed request") + assert.Equal(t, http.StatusOK, rr.Code) + + // Verify the payload sent to webhook + assert.Equal(t, webhook.APIVersion, lastRequest.Version) + assert.NotEmpty(t, lastRequest.UID) + assert.NotZero(t, lastRequest.Timestamp) + assert.JSONEq(t, string(reqBody), string(lastRequest.MCPRequest)) + + require.NotNil(t, lastRequest.Context) + assert.Equal(t, "test-server", lastRequest.Context.ServerName) + assert.Equal(t, "stdio", lastRequest.Context.Transport) + assert.Equal(t, "192.168.1.1:1234", lastRequest.Context.SourceIP) + + require.NotNil(t, lastRequest.Principal) + assert.Equal(t, "user-1", lastRequest.Principal.Sub) + assert.Equal(t, "user@example.com", lastRequest.Principal.Email) + assert.Equal(t, []string{"admin"}, lastRequest.Principal.Groups) + }) + + t.Run("Denied Request", func(t *testing.T) { + mockResponse.Allowed = false + mockResponse.Message = "Custom deny message" + mockResponse.Code = http.StatusForbidden + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled, "Next handler should not be called for denied request") + assert.Equal(t, http.StatusForbidden, rr.Code) + + // The error response is a JSON-RPC format + var errResp map[string]interface{} + err := json.Unmarshal(rr.Body.Bytes(), &errResp) + require.NoError(t, err) + assert.Equal(t, "2.0", errResp["jsonrpc"]) + assert.Nil(t, errResp["id"]) + + errObj, ok := errResp["error"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, float64(http.StatusForbidden), errObj["code"]) + assert.Equal(t, "Custom deny message", errObj["message"]) + }) + + t.Run("Webhook Error - Fail Policy", func(t *testing.T) { + // Create a client pointing to a closed port to generate connection error + cfg := config[0] + cfg.URL = "http://127.0.0.1:0" + cfg.FailurePolicy = webhook.FailurePolicyFail + + failClient, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + + failMw := createValidatingHandler([]clientExecutor{{client: failClient, config: cfg}}, "test", "stdio") + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + failMw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled, "Next handler should not be called on evaluation error with fail policy") + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + + t.Run("Webhook Error - Ignore Policy", func(t *testing.T) { + // Create a client pointing to a closed port to generate connection error + cfg := config[0] + cfg.URL = "http://127.0.0.1:0" + cfg.FailurePolicy = webhook.FailurePolicyIgnore + + ignoreClient, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + + ignoreMw := createValidatingHandler([]clientExecutor{{client: ignoreClient, config: cfg}}, "test", "stdio") + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + ignoreMw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled, "Next handler should be called on evaluation error with ignore policy") + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("Skip Non-MCP Requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/health", nil) + // Missing parsed MCP request in context + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled, "Next handler should be called for non-MCP requests") + assert.Equal(t, http.StatusOK, rr.Code) + }) +} + +func TestMiddlewareParams_Validate(t *testing.T) { + t.Parallel() + tests := []struct { + name string + params MiddlewareParams + wantErr bool + }{ + { + name: "valid", + params: MiddlewareParams{Webhooks: []webhook.Config{{Name: "a", URL: "https://a", Timeout: webhook.DefaultTimeout, FailurePolicy: webhook.FailurePolicyFail}}}, + wantErr: false, + }, + { + name: "empty webhooks", + params: MiddlewareParams{}, + wantErr: true, + }, + { + name: "invalid webhook config", + params: MiddlewareParams{Webhooks: []webhook.Config{{Name: ""}}}, // Missing name + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.params.Validate() + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +type mockRunner struct { + types.MiddlewareRunner + middlewares map[string]types.Middleware +} + +func (m *mockRunner) AddMiddleware(name string, mw types.Middleware) { + if m.middlewares == nil { + m.middlewares = make(map[string]types.Middleware) + } + m.middlewares[name] = mw +} + +func TestCreateMiddleware(t *testing.T) { + t.Parallel() + runner := &mockRunner{} + + // Create valid config JSON + params := FactoryMiddlewareParams{ + MiddlewareParams: MiddlewareParams{ + Webhooks: []webhook.Config{ + { + Name: "test", + URL: "https://test.com/hook", + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyIgnore, + }, + }, + }, + ServerName: "test-server", + Transport: "stdio", + } + paramsJSON, err := json.Marshal(params) + require.NoError(t, err) + + mwConfig := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: paramsJSON, + } + + err = CreateMiddleware(mwConfig, runner) + require.NoError(t, err) + + require.Contains(t, runner.middlewares, MiddlewareType) + mw := runner.middlewares[MiddlewareType] + + // Test Handler/Close methods to get 100% coverage + require.NotNil(t, mw.Handler()) + require.NoError(t, mw.Close()) +} From 78e5190d8be4e4f80943887ad40a435d61c4603f Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Mon, 23 Mar 2026 14:56:09 +0530 Subject: [PATCH 2/9] fix: docs Signed-off-by: Sanskarzz --- docs/server/docs.go | 30 ++++++++++++++++++++++++++++++ docs/server/swagger.json | 30 ++++++++++++++++++++++++++++++ docs/server/swagger.yaml | 30 ++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+) diff --git a/docs/server/docs.go b/docs/server/docs.go index 19d407cb4e..25001b6c23 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -3071,11 +3071,26 @@ const docTemplate = `{ 1000000000, 60000000000, 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, 1, 1000, 1000000, 1000000000, 60000000000, + 3600000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, 1, 1000, 1000000, @@ -3101,11 +3116,26 @@ const docTemplate = `{ "Second", "Minute", "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "Nanosecond", "Microsecond", "Millisecond", "Second", "Minute", + "Hour", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "Nanosecond", "Microsecond", "Millisecond", diff --git a/docs/server/swagger.json b/docs/server/swagger.json index b4f2a866c1..4f9f389e68 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -3064,11 +3064,26 @@ 1000000000, 60000000000, 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, 1, 1000, 1000000, 1000000000, 60000000000, + 3600000000000, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, 1, 1000, 1000000, @@ -3094,11 +3109,26 @@ "Second", "Minute", "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "Nanosecond", "Microsecond", "Millisecond", "Second", "Minute", + "Hour", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "Nanosecond", "Microsecond", "Millisecond", diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index b015716c5f..fb7098446d 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -2671,11 +2671,26 @@ components: - 1000000000 - 60000000000 - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 - 1 - 1000 - 1000000 - 1000000000 - 60000000000 + - 3600000000000 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 - 1 - 1000 - 1000000 @@ -2700,11 +2715,26 @@ components: - Second - Minute - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour - Nanosecond - Microsecond - Millisecond - Second - Minute + - Hour + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour - Nanosecond - Microsecond - Millisecond From e4ba4de57afe2378595c9747ea011f8c7e380454 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Tue, 24 Mar 2026 01:01:47 +0530 Subject: [PATCH 3/9] fix: replaced webhook.Principle with auth.PrincipalInfo Signed-off-by: Sanskarzz --- docs/server/docs.go | 52 ----------------------- docs/server/swagger.json | 52 ----------------------- docs/server/swagger.yaml | 52 ----------------------- pkg/webhook/validating/middleware.go | 8 +--- pkg/webhook/validating/middleware_test.go | 10 +++-- 5 files changed, 7 insertions(+), 167 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index 25001b6c23..1e97637f54 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -3070,32 +3070,6 @@ const docTemplate = `{ 1000000, 1000000000, 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, 3600000000000 ], "type": "integer", @@ -3115,32 +3089,6 @@ const docTemplate = `{ "Millisecond", "Second", "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", "Hour" ] }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 4f9f389e68..7efaee2ff3 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -3063,32 +3063,6 @@ 1000000, 1000000000, 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, 3600000000000 ], "type": "integer", @@ -3108,32 +3082,6 @@ "Millisecond", "Second", "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", "Hour" ] }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index fb7098446d..25e595f026 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -2671,32 +2671,6 @@ components: - 1000000000 - 60000000000 - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 type: integer x-enum-varnames: - minDuration @@ -2715,32 +2689,6 @@ components: - Second - Minute - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour types.MiddlewareConfig: properties: parameters: diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go index b5ab39cf6a..5fdfbc3f73 100644 --- a/pkg/webhook/validating/middleware.go +++ b/pkg/webhook/validating/middleware.go @@ -115,13 +115,7 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s // Add Principal if authenticated if identity, ok := auth.IdentityFromContext(r.Context()); ok { - whReq.Principal = &webhook.Principal{ - Sub: identity.Subject, - Email: identity.Email, - Name: identity.Name, - Groups: identity.Groups, - Claims: identity.Claims, - } + whReq.Principal = identity.GetPrincipalInfo() } // Call each webhook in order diff --git a/pkg/webhook/validating/middleware_test.go b/pkg/webhook/validating/middleware_test.go index d2c3db0697..b3c2d29f14 100644 --- a/pkg/webhook/validating/middleware_test.go +++ b/pkg/webhook/validating/middleware_test.go @@ -79,9 +79,11 @@ func TestValidatingMiddleware(t *testing.T) { ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, parsedMCP) identity := &auth.Identity{ - Subject: "user-1", - Email: "user@example.com", - Groups: []string{"admin"}, + PrincipalInfo: auth.PrincipalInfo{ + Subject: "user-1", + Email: "user@example.com", + Groups: []string{"admin"}, + }, } ctx = auth.WithIdentity(ctx, identity) @@ -111,7 +113,7 @@ func TestValidatingMiddleware(t *testing.T) { assert.Equal(t, "192.168.1.1:1234", lastRequest.Context.SourceIP) require.NotNil(t, lastRequest.Principal) - assert.Equal(t, "user-1", lastRequest.Principal.Sub) + assert.Equal(t, "user-1", lastRequest.Principal.Subject) assert.Equal(t, "user@example.com", lastRequest.Principal.Email) assert.Equal(t, []string{"admin"}, lastRequest.Principal.Groups) }) From 061b76802f24773894c5162b22165babdf208864 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Tue, 24 Mar 2026 01:13:31 +0530 Subject: [PATCH 4/9] fix: docs CI error Signed-off-by: Sanskarzz --- docs/server/docs.go | 32 ++++++++++++++++++++++++++++++++ docs/server/swagger.json | 32 ++++++++++++++++++++++++++++++++ docs/server/swagger.yaml | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/docs/server/docs.go b/docs/server/docs.go index 1e97637f54..c50a526579 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -3055,6 +3055,22 @@ const docTemplate = `{ "time.Duration": { "description": "Timeout is the maximum time to wait for a webhook response.", "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, -9223372036854775808, 9223372036854775807, 1, @@ -3074,6 +3090,22 @@ const docTemplate = `{ ], "type": "integer", "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "minDuration", "maxDuration", "Nanosecond", diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 7efaee2ff3..49a94acb57 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -3048,6 +3048,22 @@ "time.Duration": { "description": "Timeout is the maximum time to wait for a webhook response.", "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, -9223372036854775808, 9223372036854775807, 1, @@ -3067,6 +3083,22 @@ ], "type": "integer", "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", "minDuration", "maxDuration", "Nanosecond", diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 25e595f026..e789629179 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -2671,6 +2671,22 @@ components: - 1000000000 - 60000000000 - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 type: integer x-enum-varnames: - minDuration @@ -2689,6 +2705,22 @@ components: - Second - Minute - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour types.MiddlewareConfig: properties: parameters: From f69978bb549d52686d59137d10cef0285229c05a Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Tue, 24 Mar 2026 02:31:45 +0530 Subject: [PATCH 5/9] fix: swag docs Signed-off-by: Sanskarzz --- docs/server/docs.go | 78 ++------------------------------------- docs/server/swagger.json | 78 ++------------------------------------- docs/server/swagger.yaml | 75 ++----------------------------------- pkg/auth/remote/config.go | 2 +- pkg/webhook/types.go | 2 +- 5 files changed, 11 insertions(+), 224 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index c50a526579..6fd9a44cf1 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -279,8 +279,7 @@ const docTemplate = `{ "type": "boolean" }, "timeout": { - "example": "5m", - "type": "string" + "type": "integer" }, "token_url": { "type": "string" @@ -1404,7 +1403,8 @@ const docTemplate = `{ "type": "string" }, "timeout": { - "$ref": "#/components/schemas/time.Duration" + "description": "Timeout is the maximum time to wait for a webhook response.", + "type": "integer" }, "tls_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig" @@ -3052,78 +3052,6 @@ const docTemplate = `{ }, "type": "object" }, - "time.Duration": { - "description": "Timeout is the maximum time to wait for a webhook response.", - "enum": [ - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000 - ], - "type": "integer", - "x-enum-varnames": [ - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour" - ] - }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 49a94acb57..1a680e251b 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -272,8 +272,7 @@ "type": "boolean" }, "timeout": { - "example": "5m", - "type": "string" + "type": "integer" }, "token_url": { "type": "string" @@ -1397,7 +1396,8 @@ "type": "string" }, "timeout": { - "$ref": "#/components/schemas/time.Duration" + "description": "Timeout is the maximum time to wait for a webhook response.", + "type": "integer" }, "tls_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig" @@ -3045,78 +3045,6 @@ }, "type": "object" }, - "time.Duration": { - "description": "Timeout is the maximum time to wait for a webhook response.", - "enum": [ - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000 - ], - "type": "integer", - "x-enum-varnames": [ - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour" - ] - }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index e789629179..b30fb7bda1 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -267,8 +267,7 @@ components: skip_browser: type: boolean timeout: - example: 5m - type: string + type: integer token_url: type: string use_pkce: @@ -1324,7 +1323,8 @@ components: description: Name is a unique identifier for this webhook. type: string timeout: - $ref: '#/components/schemas/time.Duration' + description: Timeout is the maximum time to wait for a webhook response. + type: integer tls_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.TLSConfig' url: @@ -2652,75 +2652,6 @@ components: type: array uniqueItems: false type: object - time.Duration: - description: Timeout is the maximum time to wait for a webhook response. - enum: - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - type: integer - x-enum-varnames: - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour types.MiddlewareConfig: properties: parameters: diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 049e5ab6b2..64e8a1d304 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -23,7 +23,7 @@ type Config struct { ClientSecretFile string `json:"client_secret_file,omitempty" yaml:"client_secret_file,omitempty"` Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"` SkipBrowser bool `json:"skip_browser,omitempty" yaml:"skip_browser,omitempty"` - Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty" swaggertype:"string" example:"5m"` + Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty" swaggertype:"primitive,integer"` CallbackPort int `json:"callback_port,omitempty" yaml:"callback_port,omitempty"` UsePKCE bool `json:"use_pkce" yaml:"use_pkce"` diff --git a/pkg/webhook/types.go b/pkg/webhook/types.go index c21b469360..8f7c611a06 100644 --- a/pkg/webhook/types.go +++ b/pkg/webhook/types.go @@ -66,7 +66,7 @@ type Config struct { // URL is the HTTPS endpoint to call. URL string `json:"url"` // Timeout is the maximum time to wait for a webhook response. - Timeout time.Duration `json:"timeout"` + Timeout time.Duration `json:"timeout" yaml:"timeout" swaggertype:"primitive,integer"` // FailurePolicy determines behavior when the webhook call fails. FailurePolicy FailurePolicy `json:"failure_policy"` // TLSConfig holds optional TLS configuration (CA bundles, client certs). From d44e2cb55a9ac784b82c1e0b7dbb1d900dd6a2a2 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 25 Mar 2026 01:57:15 +0530 Subject: [PATCH 6/9] fix: revert to review comments Signed-off-by: Sanskarzz --- docs/server/docs.go | 57 +++++- docs/server/swagger.json | 57 +++++- docs/server/swagger.yaml | 54 +++++- pkg/auth/remote/config.go | 4 +- pkg/webhook/types.go | 10 +- pkg/webhook/validating/middleware.go | 36 ++-- pkg/webhook/validating/middleware_test.go | 209 +++++++++++++++++++++- 7 files changed, 388 insertions(+), 39 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index 6fd9a44cf1..acb847fc8b 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -279,7 +279,7 @@ const docTemplate = `{ "type": "boolean" }, "timeout": { - "type": "integer" + "$ref": "#/components/schemas/time.Duration" }, "token_url": { "type": "string" @@ -3052,6 +3052,61 @@ const docTemplate = `{ }, "type": "object" }, + "time.Duration": { + "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000 + ], + "type": "integer", + "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour" + ] + }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 1a680e251b..4ba8a24072 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -272,7 +272,7 @@ "type": "boolean" }, "timeout": { - "type": "integer" + "$ref": "#/components/schemas/time.Duration" }, "token_url": { "type": "string" @@ -3045,6 +3045,61 @@ }, "type": "object" }, + "time.Duration": { + "enum": [ + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000, + -9223372036854775808, + 9223372036854775807, + 1, + 1000, + 1000000, + 1000000000, + 60000000000, + 3600000000000 + ], + "type": "integer", + "x-enum-varnames": [ + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour", + "minDuration", + "maxDuration", + "Nanosecond", + "Microsecond", + "Millisecond", + "Second", + "Minute", + "Hour" + ] + }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index b30fb7bda1..c538b9cde9 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -267,7 +267,7 @@ components: skip_browser: type: boolean timeout: - type: integer + $ref: '#/components/schemas/time.Duration' token_url: type: string use_pkce: @@ -2652,6 +2652,58 @@ components: type: array uniqueItems: false type: object + time.Duration: + enum: + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + - -9223372036854775808 + - 9223372036854775807 + - 1 + - 1000 + - 1000000 + - 1000000000 + - 60000000000 + - 3600000000000 + type: integer + x-enum-varnames: + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour types.MiddlewareConfig: properties: parameters: diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 64e8a1d304..c124277c18 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/stacklok/toolhive-core/registry/types" + registry "github.com/stacklok/toolhive-core/registry/types" httpval "github.com/stacklok/toolhive-core/validation/http" ) @@ -23,7 +23,7 @@ type Config struct { ClientSecretFile string `json:"client_secret_file,omitempty" yaml:"client_secret_file,omitempty"` Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"` SkipBrowser bool `json:"skip_browser,omitempty" yaml:"skip_browser,omitempty"` - Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty" swaggertype:"primitive,integer"` + Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"` CallbackPort int `json:"callback_port,omitempty" yaml:"callback_port,omitempty"` UsePKCE bool `json:"use_pkce" yaml:"use_pkce"` diff --git a/pkg/webhook/types.go b/pkg/webhook/types.go index 8f7c611a06..2011f4c491 100644 --- a/pkg/webhook/types.go +++ b/pkg/webhook/types.go @@ -62,17 +62,17 @@ type TLSConfig struct { // Config holds the configuration for a single webhook. type Config struct { // Name is a unique identifier for this webhook. - Name string `json:"name"` + Name string `json:"name" yaml:"name"` // URL is the HTTPS endpoint to call. - URL string `json:"url"` + URL string `json:"url" yaml:"url"` // Timeout is the maximum time to wait for a webhook response. Timeout time.Duration `json:"timeout" yaml:"timeout" swaggertype:"primitive,integer"` // FailurePolicy determines behavior when the webhook call fails. - FailurePolicy FailurePolicy `json:"failure_policy"` + FailurePolicy FailurePolicy `json:"failure_policy" yaml:"failure_policy"` // TLSConfig holds optional TLS configuration (CA bundles, client certs). - TLSConfig *TLSConfig `json:"tls_config,omitempty"` + TLSConfig *TLSConfig `json:"tls_config,omitempty" yaml:"tls_config,omitempty"` // HMACSecretRef is an optional reference to an HMAC secret for payload signing. - HMACSecretRef string `json:"hmac_secret_ref,omitempty"` + HMACSecretRef string `json:"hmac_secret_ref,omitempty" yaml:"hmac_secret_ref,omitempty"` } // Validate checks that the WebhookConfig has valid required fields. diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go index 5fdfbc3f73..3e73104c8d 100644 --- a/pkg/webhook/validating/middleware.go +++ b/pkg/webhook/validating/middleware.go @@ -13,6 +13,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/mcp" @@ -93,7 +94,7 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s // Read the request body to get the raw MCP request bodyBytes, err := io.ReadAll(r.Body) if err != nil { - sendErrorResponse(w, http.StatusInternalServerError, "Internal Server Error", "Failed to read request body") + sendErrorResponse(w, http.StatusInternalServerError, "Failed to read request body") return } // Restore the request body for downstream handlers @@ -133,24 +134,19 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s slog.Error("Validating webhook error caused request denial", "webhook", whName, "error", err) - sendErrorResponse(w, http.StatusForbidden, "Forbidden", fmt.Sprintf("Webhook %q error: %v", whName, err)) + sendErrorResponse(w, http.StatusForbidden, "Request denied by policy") return } if !resp.Allowed { slog.Info("Validating webhook denied request", "webhook", whName, "reason", resp.Reason, "message", resp.Message) - msg := resp.Message - if msg == "" { - msg = fmt.Sprintf("Webhook %q denied the request", whName) - } + // Prevent information leaks by ignoring the webhook's message + msg := "Request denied by policy" - code := resp.Code - if code < 400 || code > 599 { - code = http.StatusForbidden - } + code := http.StatusForbidden - sendErrorResponse(w, code, "Forbidden", msg) + sendErrorResponse(w, code, msg) return } } @@ -167,21 +163,15 @@ func readSourceIP(r *http.Request) string { return r.RemoteAddr } -func sendErrorResponse(w http.ResponseWriter, statusCode int, _, message string) { +func sendErrorResponse(w http.ResponseWriter, statusCode int, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - // Since we are intercepting an MCP request, we should really be returning a JSON-RPC error. - // However, if the error happens before actual execution, a standard HTTP error or a basic JSON - // with error details is typical. Here we'll follow standard HTTP error structure or JSON-RPC format. - // We'll return a JSON format that could be interpreted as a JSON-RPC error. - errResp := map[string]any{ - "jsonrpc": "2.0", - "id": nil, - "error": map[string]any{ - "code": statusCode, - "message": message, - }, + // Return a JSON-RPC 2.0 error so MCP clients can parse the denial. + // The HTTP status code signals the error at the transport level; the JSON-RPC body carries the detail. + errResp := &jsonrpc2.Response{ + ID: jsonrpc2.ID{}, + Error: jsonrpc2.NewError(int64(statusCode), message), } _ = json.NewEncoder(w).Encode(errResp) } diff --git a/pkg/webhook/validating/middleware_test.go b/pkg/webhook/validating/middleware_test.go index b3c2d29f14..554dcc1b9d 100644 --- a/pkg/webhook/validating/middleware_test.go +++ b/pkg/webhook/validating/middleware_test.go @@ -20,6 +20,9 @@ import ( "github.com/stacklok/toolhive/pkg/webhook" ) +// closedServerURL is a URL that will always fail to connect (port 0 is reserved/closed). +const closedServerURL = "http://127.0.0.1:0" + //nolint:paralleltest // Shares a mock HTTP server and lastRequest state func TestValidatingMiddleware(t *testing.T) { // Setup a mock webhook server @@ -118,6 +121,27 @@ func TestValidatingMiddleware(t *testing.T) { assert.Equal(t, []string{"admin"}, lastRequest.Principal.Groups) }) + t.Run("Allowed Request - No Identity", func(t *testing.T) { + mockResponse.Allowed = true + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Nil(t, lastRequest.Principal, "Principal should be nil") + }) + t.Run("Denied Request", func(t *testing.T) { mockResponse.Allowed = false mockResponse.Message = "Custom deny message" @@ -144,19 +168,39 @@ func TestValidatingMiddleware(t *testing.T) { var errResp map[string]interface{} err := json.Unmarshal(rr.Body.Bytes(), &errResp) require.NoError(t, err) - assert.Equal(t, "2.0", errResp["jsonrpc"]) - assert.Nil(t, errResp["id"]) - errObj, ok := errResp["error"].(map[string]interface{}) + errObj, ok := errResp["Error"].(map[string]interface{}) require.True(t, ok) assert.Equal(t, float64(http.StatusForbidden), errObj["code"]) - assert.Equal(t, "Custom deny message", errObj["message"]) + assert.Equal(t, "Request denied by policy", errObj["message"]) + }) + + t.Run("Denied Request - Out-of-Range Code Defaults to 403", func(t *testing.T) { + mockResponse.Allowed = false + mockResponse.Message = "blocked" + mockResponse.Code = 200 // out-of-range (not 4xx-5xx) should default to 403 + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled) + assert.Equal(t, http.StatusForbidden, rr.Code, "Out-of-range webhook code should be normalized to 403") }) t.Run("Webhook Error - Fail Policy", func(t *testing.T) { // Create a client pointing to a closed port to generate connection error cfg := config[0] - cfg.URL = "http://127.0.0.1:0" + cfg.URL = closedServerURL cfg.FailurePolicy = webhook.FailurePolicyFail failClient, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) @@ -184,7 +228,7 @@ func TestValidatingMiddleware(t *testing.T) { t.Run("Webhook Error - Ignore Policy", func(t *testing.T) { // Create a client pointing to a closed port to generate connection error cfg := config[0] - cfg.URL = "http://127.0.0.1:0" + cfg.URL = closedServerURL cfg.FailurePolicy = webhook.FailurePolicyIgnore ignoreClient, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) @@ -311,3 +355,156 @@ func TestCreateMiddleware(t *testing.T) { require.NotNil(t, mw.Handler()) require.NoError(t, mw.Close()) } + +//nolint:paralleltest // Shares a mock HTTP server and lastRequest state +func TestMultiWebhookChain(t *testing.T) { + // Setup mock webhook servers + var lastRequest1, lastRequest2 webhook.Request + mockResponse1 := webhook.Response{Version: webhook.APIVersion, UID: "resp-uid-1", Allowed: true} + mockResponse2 := webhook.Response{Version: webhook.APIVersion, UID: "resp-uid-2", Allowed: true} + + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewDecoder(r.Body).Decode(&lastRequest1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockResponse1) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewDecoder(r.Body).Decode(&lastRequest2) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockResponse2) + })) + defer server2.Close() + + // Create middleware handler with two webhooks + config := []webhook.Config{ + { + Name: "hook-1", + URL: server1.URL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyFail, + TLSConfig: &webhook.TLSConfig{InsecureSkipVerify: true}, + }, + { + Name: "hook-2", + URL: server2.URL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyFail, + TLSConfig: &webhook.TLSConfig{InsecureSkipVerify: true}, + }, + } + + var executors []clientExecutor + for _, cfg := range config { + client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + executors = append(executors, clientExecutor{client: client, config: cfg}) + } + mw := createValidatingHandler(executors, "test-server", "stdio") + + createReq := func() *http.Request { + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{}) + return req.WithContext(ctx) + } + + t.Run("Both Allow", func(t *testing.T) { + mockResponse1.Allowed = true + mockResponse2.Allowed = true + lastRequest1 = webhook.Request{} + lastRequest2 = webhook.Request{} + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, createReq()) + + assert.True(t, nextCalled, "Next handler should be called when both webhooks allow") + assert.Equal(t, http.StatusOK, rr.Code) + assert.NotEmpty(t, lastRequest1.UID, "First webhook should be called") + assert.NotEmpty(t, lastRequest2.UID, "Second webhook should be called") + }) + + t.Run("First Denies, Second Skipped", func(t *testing.T) { + mockResponse1.Allowed = false + mockResponse1.Message = "Denied by hook-1" + mockResponse2.Allowed = true // shouldn't matter + lastRequest1 = webhook.Request{} + lastRequest2 = webhook.Request{} // reset + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, createReq()) + + assert.False(t, nextCalled, "Next handler should not be called") + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.NotEmpty(t, lastRequest1.UID, "First webhook should be called") + assert.Empty(t, lastRequest2.UID, "Second webhook should NOT be called") + + // Verify error response + var errResp map[string]interface{} + _ = json.Unmarshal(rr.Body.Bytes(), &errResp) + errObj := errResp["Error"].(map[string]interface{}) + assert.Equal(t, "Request denied by policy", errObj["message"]) + }) + + t.Run("First Allows, Second Denies", func(t *testing.T) { + mockResponse1.Allowed = true + mockResponse2.Allowed = false + mockResponse2.Message = "Denied by hook-2" + lastRequest1 = webhook.Request{} + lastRequest2 = webhook.Request{} + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, createReq()) + + assert.False(t, nextCalled, "Next handler should not be called") + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.NotEmpty(t, lastRequest1.UID, "First webhook should be called") + assert.NotEmpty(t, lastRequest2.UID, "Second webhook should be called") + + // Verify error response + var errResp map[string]interface{} + _ = json.Unmarshal(rr.Body.Bytes(), &errResp) + errObj := errResp["Error"].(map[string]interface{}) + assert.Equal(t, "Request denied by policy", errObj["message"]) + }) + + t.Run("Mixed Failure Policies: Err Ignore -> Allow", func(t *testing.T) { + // Clone configs, set hook-1 to fail-open (ignore) and use bad URL + cfg1 := config[0] + cfg1.FailurePolicy = webhook.FailurePolicyIgnore + cfg1.URL = closedServerURL // Force connection error + client1, _ := webhook.NewClient(cfg1, webhook.TypeValidating, nil) + + cfg2 := config[1] + client2, _ := webhook.NewClient(cfg2, webhook.TypeValidating, nil) + + mixedExecutors := []clientExecutor{ + {client: client1, config: cfg1}, + {client: client2, config: cfg2}, + } + mixedMw := createValidatingHandler(mixedExecutors, "test-server", "stdio") + + mockResponse2.Allowed = true + lastRequest2 = webhook.Request{} + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mixedMw(nextHandler).ServeHTTP(rr, createReq()) + + assert.True(t, nextCalled, "Next handler should be called because error on first is ignored, and second allows") + assert.Equal(t, http.StatusOK, rr.Code) + assert.NotEmpty(t, lastRequest2.UID, "Second webhook should be called") + }) +} From cae9883c45ea773cf439965e195961dfce3a6a24 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 25 Mar 2026 02:06:58 +0530 Subject: [PATCH 7/9] fix: add docs Signed-off-by: Sanskarzz --- docs/server/docs.go | 16 ---------------- docs/server/swagger.json | 16 ---------------- docs/server/swagger.yaml | 16 ---------------- 3 files changed, 48 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index acb847fc8b..c8448a4819 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -3069,14 +3069,6 @@ const docTemplate = `{ 1000000, 1000000000, 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, 3600000000000 ], "type": "integer", @@ -3096,14 +3088,6 @@ const docTemplate = `{ "Millisecond", "Second", "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", "Hour" ] }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 4ba8a24072..986983dae2 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -3062,14 +3062,6 @@ 1000000, 1000000000, 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, 3600000000000 ], "type": "integer", @@ -3089,14 +3081,6 @@ "Millisecond", "Second", "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", "Hour" ] }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index c538b9cde9..3af79773af 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -2670,14 +2670,6 @@ components: - 1000000000 - 60000000000 - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 type: integer x-enum-varnames: - minDuration @@ -2696,14 +2688,6 @@ components: - Second - Minute - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour types.MiddlewareConfig: properties: parameters: From 744b66e833bb3bc78137c515bcd90ac07829fa88 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 25 Mar 2026 03:27:35 +0530 Subject: [PATCH 8/9] fix: remove changes bz fixed in upstream Signed-off-by: Sanskarzz --- pkg/auth/remote/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index c124277c18..049e5ab6b2 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -11,7 +11,7 @@ import ( "strings" "time" - registry "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive-core/registry/types" httpval "github.com/stacklok/toolhive-core/validation/http" ) @@ -23,7 +23,7 @@ type Config struct { ClientSecretFile string `json:"client_secret_file,omitempty" yaml:"client_secret_file,omitempty"` Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"` SkipBrowser bool `json:"skip_browser,omitempty" yaml:"skip_browser,omitempty"` - Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"` + Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty" swaggertype:"string" example:"5m"` CallbackPort int `json:"callback_port,omitempty" yaml:"callback_port,omitempty"` UsePKCE bool `json:"use_pkce" yaml:"use_pkce"` From f43688a08c9157227acc234ea452f55cae31ab6a Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 25 Mar 2026 18:51:54 +0530 Subject: [PATCH 9/9] fix: used covertToJSONRPC2ID func Signed-off-by: Sanskarzz --- docs/server/docs.go | 42 ++-------------------------- docs/server/swagger.json | 42 ++-------------------------- docs/server/swagger.yaml | 39 ++------------------------ pkg/webhook/validating/middleware.go | 35 +++++++++++++++++++---- 4 files changed, 36 insertions(+), 122 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index c8448a4819..279699bcc8 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -279,7 +279,8 @@ const docTemplate = `{ "type": "boolean" }, "timeout": { - "$ref": "#/components/schemas/time.Duration" + "example": "5m", + "type": "string" }, "token_url": { "type": "string" @@ -3052,45 +3053,6 @@ const docTemplate = `{ }, "type": "object" }, - "time.Duration": { - "enum": [ - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000 - ], - "type": "integer", - "x-enum-varnames": [ - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour" - ] - }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 986983dae2..0c8828fa27 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -272,7 +272,8 @@ "type": "boolean" }, "timeout": { - "$ref": "#/components/schemas/time.Duration" + "example": "5m", + "type": "string" }, "token_url": { "type": "string" @@ -3045,45 +3046,6 @@ }, "type": "object" }, - "time.Duration": { - "enum": [ - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000, - -9223372036854775808, - 9223372036854775807, - 1, - 1000, - 1000000, - 1000000000, - 60000000000, - 3600000000000 - ], - "type": "integer", - "x-enum-varnames": [ - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour", - "minDuration", - "maxDuration", - "Nanosecond", - "Microsecond", - "Millisecond", - "Second", - "Minute", - "Hour" - ] - }, "types.MiddlewareConfig": { "properties": { "parameters": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 3af79773af..f9cbfb3e7c 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -267,7 +267,8 @@ components: skip_browser: type: boolean timeout: - $ref: '#/components/schemas/time.Duration' + example: 5m + type: string token_url: type: string use_pkce: @@ -2652,42 +2653,6 @@ components: type: array uniqueItems: false type: object - time.Duration: - enum: - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - - -9223372036854775808 - - 9223372036854775807 - - 1 - - 1000 - - 1000000 - - 1000000000 - - 60000000000 - - 3600000000000 - type: integer - x-enum-varnames: - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour - - minDuration - - maxDuration - - Nanosecond - - Microsecond - - Millisecond - - Second - - Minute - - Hour types.MiddlewareConfig: properties: parameters: diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go index 3e73104c8d..74999e11f7 100644 --- a/pkg/webhook/validating/middleware.go +++ b/pkg/webhook/validating/middleware.go @@ -94,7 +94,7 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s // Read the request body to get the raw MCP request bodyBytes, err := io.ReadAll(r.Body) if err != nil { - sendErrorResponse(w, http.StatusInternalServerError, "Failed to read request body") + sendErrorResponse(w, http.StatusInternalServerError, "Failed to read request body", parsedMCP.ID) return } // Restore the request body for downstream handlers @@ -134,7 +134,7 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s slog.Error("Validating webhook error caused request denial", "webhook", whName, "error", err) - sendErrorResponse(w, http.StatusForbidden, "Request denied by policy") + sendErrorResponse(w, http.StatusForbidden, "Request denied by policy", parsedMCP.ID) return } @@ -146,7 +146,7 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s code := http.StatusForbidden - sendErrorResponse(w, code, msg) + sendErrorResponse(w, code, msg, parsedMCP.ID) return } } @@ -163,14 +163,39 @@ func readSourceIP(r *http.Request) string { return r.RemoteAddr } -func sendErrorResponse(w http.ResponseWriter, statusCode int, message string) { +func convertToJSONRPC2ID(id interface{}) (jsonrpc2.ID, error) { + if id == nil { + return jsonrpc2.ID{}, nil + } + + switch v := id.(type) { + case string: + return jsonrpc2.StringID(v), nil + case int: + return jsonrpc2.Int64ID(int64(v)), nil + case int64: + return jsonrpc2.Int64ID(v), nil + case float64: + // JSON numbers are often unmarshaled as float64 + return jsonrpc2.Int64ID(int64(v)), nil + default: + return jsonrpc2.ID{}, fmt.Errorf("unsupported ID type: %T", id) + } +} + +func sendErrorResponse(w http.ResponseWriter, statusCode int, message string, msgID interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) + id, err := convertToJSONRPC2ID(msgID) + if err != nil { + id = jsonrpc2.ID{} // Use empty ID if conversion fails + } + // Return a JSON-RPC 2.0 error so MCP clients can parse the denial. // The HTTP status code signals the error at the transport level; the JSON-RPC body carries the detail. errResp := &jsonrpc2.Response{ - ID: jsonrpc2.ID{}, + ID: id, Error: jsonrpc2.NewError(int64(statusCode), message), } _ = json.NewEncoder(w).Encode(errResp)